dgl.knn_graph

dgl.knn_graph(x, k)

Construct a graph from a set of points according to k-nearest neighbors (KNN) and return it.

The function transforms the coordinates/features of a point set into a directed homogeneous graph. The coordinates of the point set are specified as a matrix whose rows correspond to points and whose columns correspond to coordinate/feature dimensions.

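The nodes of the returned graph correspond to the points, and the predecessors of each point are its k nearest neighbors measured by Euclidean distance. Because a point is at distance zero from itself, it normally counts as one of its own k nearest neighbors, so the graph typically has a self-loop on every node, as the examples below show.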
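To make the neighbor rule concrete, here is a minimal brute-force sketch in plain Python. It is not DGL's implementation; the function name `knn_edges` is hypothetical. Two further assumptions belong to the sketch, not to DGL: ties in distance are broken by the smaller point index, and edges are sorted by source and then destination only so that the result lines up with the order printed in the examples below.

```python
import math


def knn_edges(points, k):
    """Hypothetical brute-force KNN graph: return (src, dst) edge lists.

    For every destination node j, the k points closest to points[j]
    (by Euclidean distance) become its predecessors, i.e. edges i -> j.
    A point is at distance 0 from itself, so (unless there are duplicate
    points) every node gets a self-loop. Ties are broken by the smaller
    index, which is an assumption of this sketch.
    """
    n = len(points)
    edges = []
    for j in range(n):
        # Rank all candidate sources by (distance to point j, index).
        ranked = sorted(range(n), key=lambda i: (math.dist(points[i], points[j]), i))
        for i in ranked[:k]:
            edges.append((i, j))  # i is one of the k nearest neighbors of j
    edges.sort()  # order by source, then destination
    src = [s for s, _ in edges]
    dst = [d for _, d in edges]
    return src, dst


points = [[0.0, 0.0, 1.0],
          [1.0, 0.5, 0.5],
          [0.5, 0.2, 0.2],
          [0.3, 0.2, 0.4]]
print(knn_edges(points, 2))
# ([0, 1, 2, 2, 2, 3, 3, 3], [0, 1, 1, 2, 3, 0, 2, 3])
```

The O(N^2 log N) ranking is only meant to spell out the definition; it reproduces the edge lists of the 2D example below for the same four points.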

If x is a 3D tensor, each submatrix x[i] is transformed into a separate KNN graph. DGL then composes these graphs into one large graph in which no edge connects points from different submatrices, so node j of the i-th graph receives ID i * x.shape[1] + j.
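The composition step amounts to shifting node IDs by a per-graph offset. The sketch below assumes the offset is i * N, with N the number of points per set, which is consistent with the 3D example on this page (the second point set occupies nodes 4 to 7). The helper name `compose_disjoint` is hypothetical. The local edge lists for the second point set were obtained by subtracting 4 from that example's output.

```python
def compose_disjoint(edge_lists, nodes_per_graph):
    """Hypothetical helper: merge per-point-set edge lists into one graph.

    Node j of the i-th point set becomes global node i * nodes_per_graph + j,
    so the node ranges are disjoint and no edge links two point sets.
    """
    all_src, all_dst = [], []
    for i, (src, dst) in enumerate(edge_lists):
        offset = i * nodes_per_graph
        all_src.extend(s + offset for s in src)
        all_dst.extend(d + offset for d in dst)
    return all_src, all_dst


# Local (per-point-set) KNN edges for the two 4-point sets, k = 2.
g1 = ([0, 1, 2, 2, 2, 3, 3, 3], [0, 1, 1, 2, 3, 0, 2, 3])
g2 = ([0, 1, 1, 1, 2, 2, 3, 3], [0, 1, 2, 3, 0, 2, 1, 3])
src, dst = compose_disjoint([g1, g2], nodes_per_graph=4)
print(src)  # [0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]
print(dst)  # [0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]
```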

Parameters
  • x (Tensor) –

    The point coordinates. It can be either on CPU or GPU.

    • If x is 2D, x[i] corresponds to the i-th node in the KNN graph.

    • If x is 3D, x[i] corresponds to the i-th KNN graph and x[i][j] corresponds to the j-th node in the i-th KNN graph.

  • k (int) – The number of nearest neighbors per node.
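Since every node receives exactly k incoming edges, the size of the result follows directly from the shape of x and from k. The sketch below computes it. The function name `knn_graph_size` is hypothetical, and rejecting k larger than the number of points per set is an assumption of the sketch (each node needs k distinct neighbors), not documented DGL behavior.

```python
def knn_graph_size(shape, k):
    """Hypothetical helper: (num_nodes, num_edges) of the KNN graph.

    shape is (N, D) for a single point set or (B, N, D) for B point sets.
    Every node receives exactly k incoming edges, one per nearest neighbor.
    """
    if len(shape) == 2:
        num_sets, num_points = 1, shape[0]
    elif len(shape) == 3:
        num_sets, num_points = shape[0], shape[1]
    else:
        raise ValueError("x must be a 2D or 3D tensor")
    if not 1 <= k <= num_points:
        # Assumption: k cannot exceed the number of points in one set.
        raise ValueError("k must be between 1 and the number of points per set")
    num_nodes = num_sets * num_points
    return num_nodes, num_nodes * k


print(knn_graph_size((4, 3), 2))     # (4, 8): the 2D example below
print(knn_graph_size((2, 4, 3), 2))  # (8, 16): the 3D example below
```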

Returns

The constructed graph. The node IDs are in the same order as x.

The returned graph is on CPU, regardless of the context of input x.

Return type

DGLGraph

Examples

The following examples use the PyTorch backend.

>>> import dgl
>>> import torch

When x is a 2D tensor, a single KNN graph is constructed.

>>> x = torch.tensor([[0.0, 0.0, 1.0],
...                   [1.0, 0.5, 0.5],
...                   [0.5, 0.2, 0.2],
...                   [0.3, 0.2, 0.4]])
>>> knn_g = dgl.knn_graph(x, 2)  # Each node has two predecessors
>>> knn_g.edges()
(tensor([0, 1, 2, 2, 2, 3, 3, 3]), tensor([0, 1, 1, 2, 3, 0, 2, 3]))
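The two tensors returned by edges() are the source and destination IDs, so edge i runs from src[i] to dst[i]. Grouping that output by destination shows each node's predecessors directly. The plain-Python sketch below does this on the printed values; the helper name `predecessors_by_node` is hypothetical.

```python
from collections import defaultdict


def predecessors_by_node(src, dst):
    """Hypothetical helper: map each node to the sorted list of its predecessors."""
    preds = defaultdict(list)
    for s, d in zip(src, dst):
        preds[d].append(s)
    return {node: sorted(p) for node, p in sorted(preds.items())}


# Edge lists printed by knn_g.edges() in the 2D example.
src = [0, 1, 2, 2, 2, 3, 3, 3]
dst = [0, 1, 1, 2, 3, 0, 2, 3]
print(predecessors_by_node(src, dst))
# {0: [0, 3], 1: [1, 2], 2: [2, 3], 3: [2, 3]}
```

Every node has exactly two predecessors, and one of them is the node itself (the self-loop). Within DGL, knn_g.predecessors(v) gives the same set for a single node v.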

When x is a 3D tensor, DGL constructs multiple KNN graphs and then composes them into a graph of multiple connected components.

>>> x1 = torch.tensor([[0.0, 0.0, 1.0],
...                    [1.0, 0.5, 0.5],
...                    [0.5, 0.2, 0.2],
...                    [0.3, 0.2, 0.4]])
>>> x2 = torch.tensor([[0.0, 1.0, 1.0],
...                    [0.3, 0.3, 0.3],
...                    [0.4, 0.4, 1.0],
...                    [0.3, 0.8, 0.2]])
>>> x = torch.stack([x1, x2], dim=0)
>>> knn_g = dgl.knn_graph(x, 2)  # Each node has two predecessors
>>> knn_g.edges()
(tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]),
 tensor([0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]))
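To check which point set each node came from, the global IDs can be mapped back with divmod, assuming the i * N numbering that this output exhibits (N = 4 here). The helper `split_by_point_set` below is hypothetical. It also raises an error if any edge crossed point sets, which should never happen for this graph.

```python
def split_by_point_set(src, dst, nodes_per_graph):
    """Hypothetical helper: split a composed edge list into local edge lists.

    Global node g belongs to point set g // nodes_per_graph and has local
    index g % nodes_per_graph. Raises ValueError if an edge crosses sets.
    """
    per_set = {}
    for s, d in zip(src, dst):
        s_set, s_local = divmod(s, nodes_per_graph)
        d_set, d_local = divmod(d, nodes_per_graph)
        if s_set != d_set:
            raise ValueError(f"edge {s}->{d} crosses point sets")
        local_src, local_dst = per_set.setdefault(s_set, ([], []))
        local_src.append(s_local)
        local_dst.append(d_local)
    return per_set


# Edge lists printed by knn_g.edges() in the 3D example.
src = [0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]
dst = [0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]
parts = split_by_point_set(src, dst, nodes_per_graph=4)
print(parts[0])  # ([0, 1, 2, 2, 2, 3, 3, 3], [0, 1, 1, 2, 3, 0, 2, 3])
print(parts[1])  # ([0, 1, 1, 1, 2, 2, 3, 3], [0, 1, 2, 3, 0, 2, 1, 3])
```

Because x1 holds the same points as the 2D example, its local edges (parts[0]) match that example's output exactly.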