class dgl.nn.pytorch.factory.SegmentedKNNGraph(k)

Bases: torch.nn.modules.module.Module

Layer that transforms one point set into a graph, or a batch of point sets with different numbers of points into a union of those graphs.

If a batch of point sets is provided, then the point \(j\) in the point set \(i\) is mapped to graph node ID: \(\sum_{p<i} |V_p| + j\), where \(|V_p|\) means the number of points in the point set \(p\).
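
The node ID formula above can be illustrated with a small pure-Python sketch. The helper name `global_node_id` is hypothetical and not part of DGL; it only mirrors the arithmetic \(\sum_{p<i} |V_p| + j\).

```python
from itertools import accumulate


def global_node_id(segs, i, j):
    """Map point j of point set i to its node ID in the batched graph.

    `segs` lists the number of points in each point set.
    This is a hypothetical helper for illustration, not a DGL function.
    """
    if not 0 <= j < segs[i]:
        raise IndexError("point index out of range for this point set")
    # offsets[i] is the total number of points in all point sets before set i.
    offsets = [0, *accumulate(segs)]
    return offsets[i] + j


segs = [3, 3, 2]
print(global_node_id(segs, 0, 2))  # 2
print(global_node_id(segs, 1, 0))  # 3
print(global_node_id(segs, 2, 1))  # 7
```

With point sets of sizes 3, 3 and 2, the first point of the second set becomes node 3, and the last point of the third set becomes node 7.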

The predecessors of each node are the k-nearest neighbors of the corresponding point.


k (int) – The number of neighbors.


The nearest neighbors found for a node include the node itself.


The following example uses the PyTorch backend.

>>> import torch
>>> from dgl.nn.pytorch.factory import SegmentedKNNGraph
>>> kg = SegmentedKNNGraph(2)
>>> x = torch.tensor([[0,1],
...                   [1,2],
...                   [1,3],
...                   [100, 101],
...                   [101, 102],
...                   [50, 50],
...                   [24,25],
...                   [25,24]])
>>> g = kg(x, [3,3,2])
>>> print(g.edges())
(tensor([0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7]),
 tensor([0, 0, 1, 2, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 6, 7]))
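
The edge list above can be cross-checked with a brute-force sketch in pure Python. This is not how DGL computes the graph; the function `segmented_knn_edges` is a hypothetical reference helper that assumes Euclidean distance, restricts the neighbor search to each point's own point set, and lets every point count as its own neighbor.

```python
import heapq


def segmented_knn_edges(points, segs, k):
    """Return sorted (src, dst) edges of a segmented k-NN graph.

    Each node's predecessors are its k nearest neighbors inside its own
    point set, the node itself included. Hypothetical helper, not DGL code.
    """
    if sum(segs) != len(points):
        raise ValueError("segs must sum to the number of points")
    if any(n < k for n in segs):
        raise ValueError("every point set needs at least k points")

    edges = []
    start = 0
    for n in segs:
        ids = range(start, start + n)  # global node IDs of this point set
        for dst in ids:
            def sq_dist(src, dst=dst):
                # Squared Euclidean distance gives the same ranking as the
                # distance itself, so the square root is skipped.
                return sum((a - b) ** 2 for a, b in zip(points[src], points[dst]))

            # The k closest points become the predecessors of dst.
            for src in heapq.nsmallest(k, ids, key=sq_dist):
                edges.append((src, dst))
        start += n
    return sorted(edges)


x = [[0, 1], [1, 2], [1, 3],
     [100, 101], [101, 102], [50, 50],
     [24, 25], [25, 24]]
src, dst = zip(*segmented_knn_edges(x, [3, 3, 2], k=2))
print(list(src))
print(list(dst))
```

For this input the sketch yields the same source and destination lists as the example. Note that node 5 (the point `[50, 50]`) takes node 3 as its second neighbor even though it is far away, because neighbors are never taken from another point set.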
forward(x, segs, algorithm='bruteforce-blas', dist='euclidean')

Forward computation.

  • x (Tensor) – \((M, D)\) where \(M\) means the total number of points in all point sets, and \(D\) means the size of features.

  • segs (iterable of int) – \((N)\) integers where \(N\) means the number of point sets. Each integer is the number of points in one point set. The integers must sum to \(M\), and each integer must be \(\ge k\).

  • algorithm (str, optional) –

    Algorithm used to compute the k-nearest neighbors.

    • 'bruteforce-blas' first computes the distance matrix using the BLAS matrix multiplication operation provided by the backend framework, then uses a top-k algorithm to select the k-nearest neighbors. This method is fast when the point set is small but has \(O(N^2)\) memory complexity, where \(N\) is the number of points.

    • 'bruteforce' computes distances pair by pair and selects the k-nearest neighbors directly during distance computation. This method is slower than 'bruteforce-blas' but has lower memory overhead (\(O(Nk)\), where \(N\) is the number of points and \(k\) is the number of nearest neighbors per node) because it does not need to store all distances.

    • 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce' but uses CUDA shared memory as a buffer. This method is faster than 'bruteforce' when the dimension of the input points is not large.

    • 'kd-tree' uses the kd-tree algorithm (CPU only). This method is suitable for low-dimensional data (e.g. 3D point clouds).

    • 'nn-descent' is an approximate approach from the paper "Efficient k-nearest neighbor graph construction for generic similarity measures". This method searches for nearest neighbor candidates among "neighbors' neighbors".

    (default: 'bruteforce-blas')

  • dist (str, optional) –

    The distance metric used to compute the distance between points. It can be one of the following metrics:

    • 'euclidean': Use Euclidean distance (L2 norm), \(\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}\).

    • 'cosine': Use cosine distance.

    (default: 'euclidean')
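
To make the difference between the two `dist` options concrete, here is a minimal pure-Python sketch. It assumes the common definition of cosine distance as one minus cosine similarity; DGL's internal computation may differ in detail, and the function names are hypothetical.

```python
import math


def euclidean_distance(x, y):
    """L2 distance: sqrt(sum_i (x_i - y_i)^2)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def cosine_distance(x, y):
    """1 - cosine similarity (assumed definition; inputs must be non-zero)."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / (norm_x * norm_y)


query = [1.0, 0.0]
same_direction = [10.0, 0.0]  # far away, but points the same way
orthogonal = [0.0, 1.0]       # close by, but points a different way

print(euclidean_distance(query, same_direction))  # 9.0
print(euclidean_distance(query, orthogonal))
print(cosine_distance(query, same_direction))     # 0.0
print(cosine_distance(query, orthogonal))         # 1.0
```

The two metrics can rank neighbors differently: under Euclidean distance `orthogonal` is the closer point, while under cosine distance `same_direction` is, since cosine distance ignores vector magnitude.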


A DGLGraph without features.

Return type