dgl.segmented_knn_graph
dgl.segmented_knn_graph(x, k, segs)

Construct multiple graphs from multiple sets of points according to k-nearest-neighbor (KNN) and return them composed into a single graph.
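Compared with dgl.knn_graph(), this function allows multiple point sets of different sizes. The points from different sets are stored contiguously in the x tensor, and segs specifies the number of points in each point set. The function constructs a KNN graph for each point set, where the predecessors of each point are its k nearest neighbors measured by Euclidean distance. Because a point's distance to itself is zero, each point is normally among its own k nearest neighbors, so the graph contains self-loops (as in the example below). DGL then composes all the KNN graphs into a single graph with multiple connected components; no edge connects points from different sets.

Parameters
- x (Tensor): Coordinates/features of the points, one point per row, with the rows of each point set stored contiguously.
- k (int): The number of nearest neighbors per point.
- segs (list[int]): The number of points in each point set. The numbers must sum to the number of rows in x.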
Returns
The graph. The node IDs are in the same order as the rows of x. The returned graph is on CPU, regardless of the device of the input x.

Return type
DGLGraph
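As a rough illustration of the semantics described above, the following pure-Python sketch computes the same kind of edge set without DGL or PyTorch. It is not DGL's implementation: the helper name `segmented_knn_edges`, the brute-force distance ranking, tie-breaking by the smaller node ID, rejecting k larger than a point set, and sorting the edges by (source, destination) are all assumptions made for illustration. Under those assumptions it reproduces the edges printed in the example on this page.

```python
import math


def segmented_knn_edges(points, k, segs):
    """Hypothetical helper (not part of DGL): edges of a segmented KNN graph.

    points: list of coordinate tuples, stored contiguously set by set.
    k: number of nearest neighbors per point (the point itself counts as one).
    segs: number of points in each set; must sum to len(points).
    Returns (src, dst) lists, where src[i] is a predecessor of dst[i],
    sorted by (src, dst) for readability (an assumption of this sketch).
    """
    if sum(segs) != len(points):
        raise ValueError("sum(segs) must equal the number of points")
    edges = []
    offset = 0
    for size in segs:
        if k > size:
            # This sketch rejects k larger than a point set.
            raise ValueError("k must not exceed the size of any point set")
        ids = range(offset, offset + size)  # global IDs of this point set
        for dst in ids:
            # Rank only points from the same set by Euclidean distance to dst;
            # ties are broken by the smaller node ID (an assumption).
            ranked = sorted(ids, key=lambda j: (math.dist(points[j], points[dst]), j))
            for src in ranked[:k]:
                edges.append((src, dst))  # neighbor -> point
        offset += size
    edges.sort()
    return [s for s, _ in edges], [d for _, d in edges]


x1 = [(0.0, 0.5, 0.2), (0.1, 0.3, 0.2), (0.4, 0.2, 0.2)]
x2 = [(0.3, 0.2, 0.1), (0.5, 0.2, 0.3), (0.1, 0.1, 0.2), (0.6, 0.3, 0.3)]
src, dst = segmented_knn_edges(x1 + x2, k=2, segs=[len(x1), len(x2)])
print(src)  # [0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]
print(dst)  # [0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]
```

The brute-force ranking costs O(n^2) distance computations per point set, which keeps the sketch easy to check by hand but would be slow for large point clouds.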
Examples
The following examples use the PyTorch backend.
>>> import dgl
>>> import torch
In the example below, the first point set has three points and the second point set has four points.
>>> # Features/coordinates of the first point set
>>> x1 = torch.tensor([[0.0, 0.5, 0.2],
...                    [0.1, 0.3, 0.2],
...                    [0.4, 0.2, 0.2]])
>>> # Features/coordinates of the second point set
>>> x2 = torch.tensor([[0.3, 0.2, 0.1],
...                    [0.5, 0.2, 0.3],
...                    [0.1, 0.1, 0.2],
...                    [0.6, 0.3, 0.3]])
>>> x = torch.cat([x1, x2], dim=0)
>>> segs = [x1.shape[0], x2.shape[0]]
>>> knn_g = dgl.segmented_knn_graph(x, 2, segs)
>>> knn_g.edges()
(tensor([0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
 tensor([0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]))
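Because node IDs follow the row order of x and each point set occupies a contiguous block of rows, a node ID can be mapped back to its point set with a prefix sum over segs. The plain-Python sketch below uses that mapping to check two properties of the printed edge lists: every node has exactly k predecessors, and no edge joins points from different sets. The helpers `locate_node` and `check_segmented_knn` are hypothetical, not DGL functions, and the check assumes every point set has at least k points. With DGL installed, the same lists could be obtained from `knn_g.edges()` by calling `.tolist()` on each returned tensor.

```python
from bisect import bisect_right
from collections import Counter
from itertools import accumulate


def locate_node(node_id, segs):
    """Hypothetical helper: map a global node ID to (set index, local index)."""
    starts = [0] + list(accumulate(segs))[:-1]  # e.g. [3, 4] -> [0, 3]
    if not 0 <= node_id < sum(segs):
        raise IndexError("node_id out of range")
    seg = bisect_right(starts, node_id) - 1  # last set starting at or before node_id
    return seg, node_id - starts[seg]


def check_segmented_knn(src, dst, k, segs):
    """Hypothetical helper: True if every node has k predecessors from its own set."""
    in_deg = Counter(dst)  # number of predecessors per destination node
    if any(in_deg[v] != k for v in range(sum(segs))):
        return False
    return all(locate_node(s, segs)[0] == locate_node(d, segs)[0]
               for s, d in zip(src, dst))


# Edge lists copied from the knn_g.edges() output above.
src = [0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]
dst = [0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]
print(check_segmented_knn(src, dst, k=2, segs=[3, 4]))  # True
print(locate_node(5, [3, 4]))  # (1, 2): node 5 is the third point of the second set
```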