# dgl.segmented_knn_graph

dgl.segmented_knn_graph(x, k, segs, algorithm='topk')

Construct a k-nearest-neighbor (KNN) graph for each of several point sets and return them combined into a single graph.

Compared with dgl.knn_graph(), this function allows multiple point sets of different sizes. The points of all sets are stored contiguously in the x tensor, and segs gives the number of points in each set. The function constructs a KNN graph for each point set, in which the predecessors of each point are its k nearest neighbors within the same set, measured by Euclidean distance. DGL then combines all KNN graphs into a single graph; no edges connect points from different sets, so each point set forms its own disjoint subgraph. In the example below, each point is among its own k nearest neighbors (at distance zero), so every node has a self-loop.
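To make these semantics concrete, here is a brute-force, plain-Python sketch. It is not DGL's implementation (which operates on tensors and supports the 'topk' and 'kd-tree' algorithms), and the helper name `segmented_knn_edges` is hypothetical. It assumes, as the example output shows, that each point counts among its own k nearest neighbors. Run on the same seven points as the example, it reproduces the edge lists printed by `knn_g.edges()`. `math.dist` requires Python 3.8 or later.

```python
import math


def segmented_knn_edges(x, k, segs):
    """Brute-force reference for the edges of a segmented KNN graph (hypothetical helper).

    x    -- list of points (lists of floats), stored set by set
    k    -- number of nearest neighbors per point
    segs -- number of points in each set; must sum to len(x)
    Returns (src, dst): src[e] is one of the k nearest neighbors of dst[e].
    """
    if sum(segs) != len(x):
        raise ValueError("segs must sum to the number of points in x")
    edges = []
    offset = 0  # global ID of the first point in the current set
    for n in segs:
        members = range(offset, offset + n)
        for i in members:
            # Rank every point of the same set (including i itself) by
            # Euclidean distance to point i; other sets are never considered.
            ranked = sorted(members, key=lambda j: math.dist(x[i], x[j]))
            for j in ranked[:k]:
                edges.append((j, i))  # neighbor j is a predecessor of i
        offset += n
    edges.sort()  # order by (src, dst) so it can be compared with the example
    return [s for s, _ in edges], [d for _, d in edges]


points = [[0.0, 0.5, 0.2], [0.1, 0.3, 0.2], [0.4, 0.2, 0.2],  # first set
          [0.3, 0.2, 0.1], [0.5, 0.2, 0.3], [0.1, 0.1, 0.2],  # second set
          [0.6, 0.3, 0.3]]
src, dst = segmented_knn_edges(points, 2, [3, 4])
print(src)  # [0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]
print(dst)  # [0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]
```

Because `sorted` is stable, this sketch breaks distance ties in favor of the lower node ID; DGL's tie-breaking may differ.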

Parameters
• x (Tensor) – Coordinates/features of points. Must be 2D, with one row per point. It can be on either CPU or GPU.

• k (int) – The number of nearest neighbors per node.

• segs (list[int]) – Number of points in each point set. The numbers in segs must sum up to the number of rows in x.

• algorithm (str, optional) –

Algorithm used to compute the k-nearest neighbors.

• 'topk' uses a top-k selection (quick-select or sorting, depending on the backend implementation).

• 'kd-tree' uses a k-d tree (CPU only).

(default: 'topk')
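The 'topk' option only needs the k smallest distances for each point, which a selection algorithm can find without fully sorting them. As a rough illustration of that idea (a plain-Python sketch, not DGL's or the backend's kernel), the hypothetical helper `k_smallest_quickselect` finds the indices of the k smallest distances with randomized quick-select and agrees with a full sort.

```python
import random


def k_smallest_quickselect(dists, k, seed=0):
    """Return the indices of the k smallest values in dists, in no particular order.

    Randomized quick-select with a Lomuto partition: expected O(n) work,
    versus O(n log n) for sorting all n distances. (Hypothetical helper.)
    """
    idx = list(range(len(dists)))
    if k <= 0:
        return []
    if k >= len(idx):
        return idx
    rng = random.Random(seed)
    lo, hi = 0, len(idx) - 1
    target = k - 1  # final position of the k-th smallest value
    while lo < hi:
        # Pick a random pivot and move it to the end of the active range.
        p = rng.randint(lo, hi)
        idx[p], idx[hi] = idx[hi], idx[p]
        pivot = dists[idx[hi]]
        store = lo
        # Move everything strictly smaller than the pivot to the front.
        for i in range(lo, hi):
            if dists[idx[i]] < pivot:
                idx[i], idx[store] = idx[store], idx[i]
                store += 1
        # Put the pivot into its final sorted position.
        idx[store], idx[hi] = idx[hi], idx[store]
        if store == target:
            break
        if store < target:
            lo = store + 1  # k-th smallest lies to the right of the pivot
        else:
            hi = store - 1  # k-th smallest lies to the left of the pivot
    return idx[:k]


# Squared distances from point 1 of the first example set to points 0, 1, 2.
dists = [0.05, 0.0, 0.10]
print(sorted(k_smallest_quickselect(dists, 2)))  # [0, 1]
# The same neighbors via a full sort:
print(sorted(sorted(range(len(dists)), key=dists.__getitem__)[:2]))  # [0, 1]
```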

Returns

The combined KNN graph. Node i corresponds to row i of x, so the node IDs follow the same order as the points in x.

If using the 'topk' algorithm, the returned graph is on the same device as the input x. Otherwise (with 'kd-tree'), the returned graph is on CPU, regardless of the device of the input x.
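Because the node IDs follow the row order of x, each point set occupies a contiguous range of IDs. The hypothetical helpers below (plain Python, not part of DGL) compute the first ID of each set from segs and map a global node ID back to its set and its position within that set.

```python
from bisect import bisect_right
from itertools import accumulate


def segment_offsets(segs):
    """Global ID of the first point in each set, e.g. [3, 4] -> [0, 3]."""
    return [0] + list(accumulate(segs))[:-1]


def locate(node_id, segs):
    """Map a global node ID to (set index, index within that set)."""
    if not 0 <= node_id < sum(segs):
        raise IndexError("node_id out of range")
    offsets = segment_offsets(segs)
    # The owning set is the last one whose first ID is <= node_id.
    s = bisect_right(offsets, node_id) - 1
    return s, node_id - offsets[s]


segs = [3, 4]  # as in the example below
print(segment_offsets(segs))  # [0, 3]
print(locate(4, segs))        # (1, 1): second point of the second set
```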

Return type

DGLGraph

Examples

The following examples use the PyTorch backend.

>>> import dgl
>>> import torch


In the example below, the first point set has three points and the second point set has four points.

>>> # Features/coordinates of the first point set
>>> x1 = torch.tensor([[0.0, 0.5, 0.2],
...                    [0.1, 0.3, 0.2],
...                    [0.4, 0.2, 0.2]])
>>> # Features/coordinates of the second point set
>>> x2 = torch.tensor([[0.3, 0.2, 0.1],
...                    [0.5, 0.2, 0.3],
...                    [0.1, 0.1, 0.2],
...                    [0.6, 0.3, 0.3]])
>>> x = torch.cat([x1, x2], dim=0)
>>> segs = [x1.shape[0], x2.shape[0]]
>>> knn_g = dgl.segmented_knn_graph(x, 2, segs)
>>> knn_g.edges()
(tensor([0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
 tensor([0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]))
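To read this output, group the source nodes by destination: the sources of the edges that point to node v are v's k nearest neighbors. The plain-Python sketch below does this on the edge lists copied from the output above and checks that no edge joins the two point sets. On the graph itself, `knn_g.predecessors(v)` should give the same nodes for a single node v, possibly in a different order.

```python
from collections import defaultdict

# Edge lists copied from the knn_g.edges() output above.
src = [0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]
dst = [0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]
segs = [3, 4]

# Predecessors of each node = its k nearest neighbors (itself included).
preds = defaultdict(list)
for s, d in zip(src, dst):
    preds[d].append(s)

for node in sorted(preds):
    print(node, sorted(preds[node]))
# 0 [0, 1]
# 1 [0, 1]
# 2 [1, 2]
# 3 [3, 5]
# 4 [4, 6]
# 5 [3, 5]
# 6 [4, 6]

# No edge joins the two point sets: nodes 0-2 form the first set, 3-6 the second.
assert all((s < segs[0]) == (d < segs[0]) for s, d in zip(src, dst))
```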