Source code for dgl.nn.pytorch.factory

"""Modules that transforms between graphs and between graph and tensors."""
import torch.nn as nn
from ...transform import knn_graph, segmented_knn_graph

def pairwise_squared_distance(x):
    '''
    x : (n_samples, n_points, dims)
    return : (n_samples, n_points, n_points)
    '''
    x2s = (x * x).sum(-1, keepdim=True)
    return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)


[docs]class KNNGraph(nn.Module): r"""Layer that transforms one point set into a graph, or a batch of point sets with the same number of points into a union of those graphs. If a batch of point set is provided, then the point :math:`j` in point set :math:`i` is mapped to graph node ID :math:`i \times M + j`, where :math:`M` is the number of nodes in each point set. The predecessors of each node are the k-nearest neighbors of the corresponding point. Parameters ---------- k : int The number of neighbors """ def __init__(self, k): super(KNNGraph, self).__init__() self.k = k #pylint: disable=invalid-name
[docs] def forward(self, x): """Forward computation. Parameters ---------- x : Tensor :math:`(M, D)` or :math:`(N, M, D)` where :math:`N` means the number of point sets, :math:`M` means the number of points in each point set, and :math:`D` means the size of features. Returns ------- DGLGraph A DGLGraph with no features. """ return knn_graph(x, self.k)
[docs]class SegmentedKNNGraph(nn.Module): r"""Layer that transforms one point set into a graph, or a batch of point sets with different number of points into a union of those graphs. If a batch of point set is provided, then the point :math:`j` in point set :math:`i` is mapped to graph node ID :math:`\sum_{p<i} |V_p| + j`, where :math:`|V_p|` means the number of points in point set :math:`p`. The predecessors of each node are the k-nearest neighbors of the corresponding point. Parameters ---------- k : int The number of neighbors """ def __init__(self, k): super(SegmentedKNNGraph, self).__init__() self.k = k #pylint: disable=invalid-name
[docs] def forward(self, x, segs): """Forward computation. Parameters ---------- x : Tensor :math:`(M, D)` where :math:`M` means the total number of points in all point sets. segs : iterable of int :math:`(N)` integers where :math:`N` means the number of point sets. The elements must sum up to :math:`M`. Returns ------- DGLGraph A DGLGraph with no features. """ return segmented_knn_graph(x, self.k, segs)