dgl.knn_graph
dgl.knn_graph(x, k)
Construct a graph from a set of points according to k-nearest-neighbor (KNN) and return it.
The function transforms the coordinates/features of a point set into a directed homogeneous graph. The coordinates of the point set are specified as a matrix whose rows correspond to points and whose columns correspond to coordinate/feature dimensions.
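The nodes of the returned graph correspond to the points, and the predecessors of each point are its k nearest neighbors as measured by Euclidean distance. A point is at distance zero from itself, so it normally counts as one of its own k nearest neighbors; in the examples below, every node has a self-loop.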
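The neighbor rule can be made concrete with a brute-force sketch in plain Python (3.8+ for `math.dist`). This is not DGL's implementation. The function name `knn_predecessors` is hypothetical, ties are broken by the smaller point index (an assumption; DGL's tie-breaking may differ), and edges are sorted by (source, destination) only so they can be compared with the 2D example output.

```python
import math


def knn_predecessors(points, k):
    """Return (src, dst) edge lists of a KNN graph over one point set.

    Hypothetical helper, not part of DGL. Each node j receives an edge from
    each of its k nearest points under Euclidean distance. A point is at
    distance 0 from itself, so, barring duplicate points, it is among its
    own neighbors (a self-loop).
    """
    n = len(points)
    edges = []
    for j in range(n):  # j is the destination node
        # (distance, index) pairs; sorting breaks ties by the smaller index
        dists = sorted((math.dist(points[i], points[j]), i) for i in range(n))
        edges.extend((i, j) for _, i in dists[:k])
    edges.sort()  # order by (src, dst) to compare with the documented output
    src = [s for s, _ in edges]
    dst = [d for _, d in edges]
    return src, dst


x = [[0.0, 0.0, 1.0],
     [1.0, 0.5, 0.5],
     [0.5, 0.2, 0.2],
     [0.3, 0.2, 0.4]]
src, dst = knn_predecessors(x, 2)
print(src)  # [0, 1, 2, 2, 2, 3, 3, 3]
print(dst)  # [0, 1, 1, 2, 3, 0, 2, 3]
```

Every node ID appears exactly k times in `dst`, which is the "each node has k predecessors" property stated above.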
If x is a 3D tensor, each submatrix x[i] is transformed into a separate graph. DGL then composes the graphs into one large graph with multiple connected components.
Parameters
x (Tensor) – The point coordinates. It can be either on CPU or GPU.
If x is 2D, x[i] corresponds to the i-th node in the KNN graph.
If x is 3D, x[i] corresponds to the i-th KNN graph and x[i][j] corresponds to the j-th node in the i-th KNN graph.
k (int) – The number of nearest neighbors per node.
Returns
The constructed graph. The node IDs are in the same order as x; for a 3D x, x[i][j] becomes node i * N + j, where N = x.shape[1]. The returned graph is on CPU, regardless of the device of the input x.
Return type
DGLGraph
Examples
The following examples use the PyTorch backend.
>>> import dgl
>>> import torch
When x is a 2D tensor, a single KNN graph is constructed.
>>> x = torch.tensor([[0.0, 0.0, 1.0],
...                   [1.0, 0.5, 0.5],
...                   [0.5, 0.2, 0.2],
...                   [0.3, 0.2, 0.4]])
>>> knn_g = dgl.knn_graph(x, 2)  # Each node has two predecessors
>>> knn_g.edges()
(tensor([0, 1, 2, 2, 2, 3, 3, 3]), tensor([0, 1, 1, 2, 3, 0, 2, 3]))
When x is a 3D tensor, DGL constructs multiple KNN graphs and then composes them into a graph of multiple connected components.
>>> x1 = torch.tensor([[0.0, 0.0, 1.0],
...                    [1.0, 0.5, 0.5],
...                    [0.5, 0.2, 0.2],
...                    [0.3, 0.2, 0.4]])
>>> x2 = torch.tensor([[0.0, 1.0, 1.0],
...                    [0.3, 0.3, 0.3],
...                    [0.4, 0.4, 1.0],
...                    [0.3, 0.8, 0.2]])
>>> x = torch.stack([x1, x2], dim=0)
>>> knn_g = dgl.knn_graph(x, 2)  # Each node has two predecessors
>>> knn_g.edges()
(tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]),
 tensor([0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]))
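The node IDs in the 3D example follow from an offset: point x[i][j] becomes node i * N + j, where N is the number of points per set, and neighbors are searched only within the same set. The plain-Python sketch below reproduces that edge list. It is not DGL's implementation. The name `batched_knn_predecessors` is hypothetical, ties are broken by the smaller index (an assumption), and edges are sorted by (source, destination) within each component only to match the printed output.

```python
import math


def batched_knn_predecessors(point_sets, k):
    """Return (src, dst) edge lists for the 3D case.

    Hypothetical helper, not part of DGL. Each point set gets its own KNN
    graph; point j of set i becomes node i * n + j, where n is the shared
    number of points per set.
    """
    n = len(point_sets[0])
    src, dst = [], []
    for i, points in enumerate(point_sets):
        offset = i * n  # first node ID of the i-th component
        edges = []
        for j in range(n):
            # (distance, index) pairs within this point set only
            dists = sorted((math.dist(points[a], points[j]), a) for a in range(n))
            edges.extend((a + offset, j + offset) for _, a in dists[:k])
        edges.sort()  # order by (src, dst) within the component
        src.extend(s for s, _ in edges)
        dst.extend(d for _, d in edges)
    return src, dst


x1 = [[0.0, 0.0, 1.0], [1.0, 0.5, 0.5], [0.5, 0.2, 0.2], [0.3, 0.2, 0.4]]
x2 = [[0.0, 1.0, 1.0], [0.3, 0.3, 0.3], [0.4, 0.4, 1.0], [0.3, 0.8, 0.2]]
src, dst = batched_knn_predecessors([x1, x2], 2)
print(src)  # [0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]
print(dst)  # [0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]
```

No edge links two different point sets, which is why the composed graph splits into separate connected components.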