# Source code for dgl.geometry.fps

"""Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name

from .. import backend as F

from ..base import DGLError
from .capi import _farthest_point_sampler

# Public API exported by ``from <this module> import *``.
__all__ = [
    "farthest_point_sampler",
]


def farthest_point_sampler(pos, npoints, start_idx=None):
    """Farthest Point Sampler without the need to compute all pairs of distance.

    In each batch, the algorithm starts with the sample index specified by ``start_idx``.
    Then for each point, we maintain the minimum to-sample distance.
    Finally, we pick the point with the maximum such distance.
    This process will be repeated for ``npoints`` - 1 times.

    Parameters
    ----------
    pos : tensor
        The positional tensor of shape (B, N, C)
    npoints : int
        The number of points to sample in each batch.
    start_idx : int, optional
        If given, the index of the starting point used in every batch;
        otherwise a starting point is drawn uniformly at random from all
        ``N`` points, independently for each batch. (default: None)

    Returns
    -------
    tensor of shape (B, npoints)
        The sampled indices (int64) in each batch.

    Raises
    ------
    DGLError
        If ``start_idx`` is given and is not in ``[0, N)``.

    Examples
    --------
    The following example uses PyTorch backend. The printed indices vary
    between runs because the start points are chosen at random.

    >>> import torch
    >>> from dgl.geometry import farthest_point_sampler
    >>> x = torch.rand((2, 10, 3))
    >>> point_idx = farthest_point_sampler(x, 2)
    >>> print(point_idx)
    tensor([[5, 6],
            [7, 8]])
    """
    ctx = F.context(pos)
    B, N, C = pos.shape
    # Flatten the batch so the native kernel receives one (B * N, C) buffer;
    # the batch size B is passed separately so it can recover per-batch ranges.
    pos = pos.reshape(-1, C)
    # Per-point working buffer handed to the native kernel.
    dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx)
    if start_idx is None:
        # The backend ``randint`` upper bound is exclusive (as with
        # ``torch.randint``), so ``high=N`` makes every point, including the
        # last one, eligible as a start point, and keeps N == 1 valid.
        start_idx = F.randint(shape=(B, ), dtype=F.int64, ctx=ctx, low=0, high=N)
    else:
        if start_idx >= N or start_idx < 0:
            raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
                N, start_idx))
        # Use the same start point for every batch.
        start_idx = F.full_1d(B, start_idx, dtype=F.int64, ctx=ctx)
    result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx)
    # The native kernel writes the sampled indices into ``result`` in place.
    _farthest_point_sampler(pos, B, npoints, dist, start_idx, result)
    return result.reshape(B, npoints)