# Source code for dgl.geometry.fps

"""Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name

from .. import backend as F

from ..base import DGLError
from .capi import _farthest_point_sampler

__all__ = ['farthest_point_sampler']

[docs]def farthest_point_sampler(pos, npoints, start_idx=None):
"""Farthest Point Sampler without the need to compute all pairs of distance.

In each batch, the algorithm starts with the sample index specified by start_idx.
Then for each point, we maintain the minimum to-sample distance.
Finally, we pick the point with the maximum such distance.
This process will be repeated for sample_points - 1 times.

Parameters
----------
pos : tensor
The positional tensor of shape (B, N, C)
npoints : int
The number of points to sample in each batch.
start_idx : int, optional
If given, appoint the index of the starting point,
otherwise randomly select a point as the start point.
(default: None)

Returns
-------
tensor of shape (B, npoints)
The sampled indices in each batch.

Examples
--------
The following exmaple uses PyTorch backend.

>>> import torch
>>> from dgl.geometry import farthest_point_sampler
>>> x = torch.rand((2, 10, 3))
>>> point_idx = farthest_point_sampler(x, 2)
>>> print(point_idx)
tensor([[5, 6],
[7, 8]])
"""
ctx = F.context(pos)
B, N, C = pos.shape
pos = pos.reshape(-1, C)
dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx)
if start_idx is None:
start_idx = F.randint(shape=(B, ), dtype=F.int64,
ctx=ctx, low=0, high=N-1)
else:
if start_idx >= N or start_idx < 0:
raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
N, start_idx))
start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx)
result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx)
_farthest_point_sampler(pos, B, npoints, dist, start_idx, result)
return result.reshape(B, npoints)