# Source code for dgl.nn.pytorch.factory

"""Modules that transforms between graphs and between graph and tensors."""
import torch.nn as nn
from ...transform import knn_graph, segmented_knn_graph

def pairwise_squared_distance(x):
'''
x : (n_samples, n_points, dims)
return : (n_samples, n_points, n_points)
'''
x2s = (x * x).sum(-1, keepdim=True)
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)

[docs]class KNNGraph(nn.Module):
r"""

Description
-----------
Layer that transforms one point set into a graph, or a batch of
point sets with the same number of points into a union of those graphs.

The KNNGraph is implemented in the following steps:

1. Compute an NxN matrix of pairwise distance for all points.
2. Pick the k points with the smallest distance for each point as their k-nearest neighbors.
3. Construct a graph with edges to each point as a node from its k-nearest neighbors.

The overall computational complexity is :math:O(N^2(logN + D).

If a batch of point sets is provided, the point :math:j in point
set :math:i is mapped to graph node ID: :math:i \times M + j, where
:math:M is the number of nodes in each point set.

The predecessors of each node are the k-nearest neighbors of the
corresponding point.

Parameters
----------
k : int
The number of neighbors.

Notes
-----
The nearest neighbors found for a node include the node itself.

Examples
--------
The following example uses PyTorch backend.

>>> import torch
>>> from dgl.nn.pytorch.factory import KNNGraph
>>>
>>> kg = KNNGraph(2)
>>> x = torch.tensor([[0,1],
[1,2],
[1,3],
[100, 101],
[101, 102],
[50, 50]])
>>> g = kg(x)
>>> print(g.edges())
(tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]),
tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
"""
def __init__(self, k):
super(KNNGraph, self).__init__()
self.k = k

#pylint: disable=invalid-name
[docs]    def forward(self, x, algorithm='bruteforce-blas', dist='euclidean'):
r"""

Forward computation.

Parameters
----------
x : Tensor
:math:(M, D) or :math:(N, M, D) where :math:N means the
number of point sets, :math:M means the number of points in
each point set, and :math:D means the size of features.
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.

* 'bruteforce-blas' will first compute the distance matrix
using BLAS matrix multiplication operation provided by
backend frameworks. Then use topk algorithm to get
k-nearest neighbors. This method is fast when the point
set is small but has :math:O(N^2) memory complexity where
:math:N is the number of points.

* 'bruteforce' will compute distances pair by pair and
directly select the k-nearest neighbors during distance
computation. This method is slower than 'bruteforce-blas'
but has less memory overhead (i.e., :math:O(Nk) where :math:N
is the number of points, :math:k is the number of nearest
neighbors per node) since we do not need to store all distances.

* 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'
but use shared memory in CUDA devices for buffer. This method is
faster than 'bruteforce' when the dimension of input points
is not large. This method is only available on CUDA device.

* 'kd-tree' will use the kd-tree algorithm (CPU only).
This method is suitable for low-dimensional data (e.g. 3D
point clouds)

* 'nn-descent' is a approximate approach from paper
Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>_. This method
will search for nearest neighbor candidates in "neighbors' neighbors".

(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm)
:math:\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}.
* 'cosine': Use cosine distance.
(default: 'euclidean')

Returns
-------
DGLGraph
A DGLGraph without features.
"""
return knn_graph(x, self.k, algorithm=algorithm, dist=dist)

[docs]class SegmentedKNNGraph(nn.Module):
r"""

Description
-----------
Layer that transforms one point set into a graph, or a batch of
point sets with different number of points into a union of those graphs.

If a batch of point sets is provided, then the point :math:j in the point
set :math:i is mapped to graph node ID:
:math:\sum_{p<i} |V_p| + j, where :math:|V_p| means the number of
points in the point set :math:p.

The predecessors of each node are the k-nearest neighbors of the
corresponding point.

Parameters
----------
k : int
The number of neighbors.

Notes
-----
The nearest neighbors found for a node include the node itself.

Examples
--------
The following example uses PyTorch backend.

>>> import torch
>>> from dgl.nn.pytorch.factory import SegmentedKNNGraph
>>>
>>> kg = SegmentedKNNGraph(2)
>>> x = torch.tensor([[0,1],
...                   [1,2],
...                   [1,3],
...                   [100, 101],
...                   [101, 102],
...                   [50, 50],
...                   [24,25],
...                   [25,24]])
>>> g = kg(x, [3,3,2])
>>> print(g.edges())
(tensor([0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7]),
tensor([0, 0, 1, 2, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 6, 7]))
>>>

"""
def __init__(self, k):
super(SegmentedKNNGraph, self).__init__()
self.k = k

#pylint: disable=invalid-name
[docs]    def forward(self, x, segs, algorithm='bruteforce-blas', dist='euclidean'):
r"""Forward computation.

Parameters
----------
x : Tensor
:math:(M, D) where :math:M means the total number of points
in all point sets, and :math:D means the size of features.
segs : iterable of int
:math:(N) integers where :math:N means the number of point
sets.  The number of elements must sum up to :math:M. And any
:math:N should :math:\ge k
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.

* 'bruteforce-blas' will first compute the distance matrix
using BLAS matrix multiplication operation provided by
backend frameworks. Then use topk algorithm to get
k-nearest neighbors. This method is fast when the point
set is small but has :math:O(N^2) memory complexity where
:math:N is the number of points.

* 'bruteforce' will compute distances pair by pair and
directly select the k-nearest neighbors during distance
computation. This method is slower than 'bruteforce-blas'
but has less memory overhead (i.e., :math:O(Nk) where :math:N
is the number of points, :math:k is the number of nearest
neighbors per node) since we do not need to store all distances.

* 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'
but use shared memory in CUDA devices for buffer. This method is
faster than 'bruteforce' when the dimension of input points
is not large. This method is only available on CUDA device.

* 'kd-tree' will use the kd-tree algorithm (CPU only).
This method is suitable for low-dimensional data (e.g. 3D
point clouds)

* 'nn-descent' is a approximate approach from paper
Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>_. This method
will search for nearest neighbor candidates in "neighbors' neighbors".

(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm)
:math:\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}.
* 'cosine': Use cosine distance.
(default: 'euclidean')

Returns
-------
DGLGraph
A DGLGraph without features.
"""

return segmented_knn_graph(x, self.k, segs, algorithm=algorithm, dist=dist)