# Source code for dgl.sampling.pinsage

"""PinSAGE sampler & related functions and classes"""

import numpy as np
from .._ffi.function import _init_api

from .. import backend as F
from .. import convert
from .randomwalks import random_walk
from .. import utils

def _select_pinsage_neighbors(src, dst, num_samples_per_node, k):
    """Determine the neighbors for PinSAGE algorithm from the given random walk traces.

    This is fusing to_simple(), select_topk(), and counting the number of occurrences
    together in a single native call.

    Parameters
    ----------
    src : Tensor
        Nodes visited by the random walks (as built by ``RandomWalkNeighborSampler``).
    dst : Tensor
        For each entry of ``src``, the seed node its walk started from.
    num_samples_per_node : int
        Number of (src, dst) entries produced for each seed node.
    k : int
        Number of neighbors to select for each seed node.

    Returns
    -------
    tuple of Tensor
        ``(src, dst, counts)`` produced by the native routine, converted back to
        framework tensors.
    """
    # Hand the framework tensors to the C side as DGL NDArrays.
    native_src, native_dst, native_counts = _CAPI_DGLSamplingSelectPinSageNeighbors(
        F.to_dgl_nd(src), F.to_dgl_nd(dst), num_samples_per_node, k)
    # Convert the three native outputs back, preserving their order.
    return tuple(F.from_dgl_nd(arr) for arr in (native_src, native_dst, native_counts))

class RandomWalkNeighborSampler(object):
    """PinSage-like neighbor sampler extended to any heterogeneous graphs.

    Given a heterogeneous graph and a list of nodes, this callable will generate a homogeneous
    graph where the neighbors of each given node are the most commonly visited nodes of the
    same type by multiple random walks starting from that given node.  Each random walk consists
    of multiple metapath-based traversals, with a probability of termination after each traversal.

    The edges of the returned homogeneous graph will connect to the given nodes from their most
    commonly visited nodes, with a feature indicating the number of visits.

    The metapath must have the same beginning and ending node type to make the algorithm work.

    This is a generalization of PinSAGE sampler which only works on bidirectional bipartite
    graphs.

    Parameters
    ----------
    G : DGLGraph
        The graph.
    num_traversals : int
        The maximum number of metapath-based traversals for a single random walk.

        Usually considered a hyperparameter.
    termination_prob : float
        Termination probability after each metapath-based traversal.

        Usually considered a hyperparameter.
    num_random_walks : int
        Number of random walks to try for each given node.

        Usually considered a hyperparameter.
    num_neighbors : int
        Number of neighbors (or most commonly visited nodes) to select for each given node.
    metapath : list[str] or list[tuple[str, str, str]], optional
        The metapath.

        If not given, DGL assumes that the graph is homogeneous and the metapath consists
        of one step over the single edge type.
    weight_column : str, default "weights"
        The name of the edge feature to be stored on the returned graph with the number of
        visits.

    Raises
    ------
    ValueError
        If ``metapath`` is omitted while the graph has more than one node type or edge
        type, or if the metapath does not start and end at the same node type.

    Examples
    --------
    See examples in :any:`PinSAGESampler`.
    """
    def __init__(self, G, num_traversals, termination_prob,
                 num_random_walks, num_neighbors, metapath=None, weight_column='weights'):
        self.G = G
        self.weight_column = weight_column
        self.num_random_walks = num_random_walks
        self.num_neighbors = num_neighbors
        self.num_traversals = num_traversals

        if metapath is None:
            # Without an explicit metapath we can only walk over the graph's single
            # edge type, so more than one node/edge type (heterogeneous) is an error.
            if len(G.ntypes) > 1 or len(G.etypes) > 1:
                raise ValueError('Metapath must be specified if the graph is heterogeneous.')
            metapath = [G.canonical_etypes[0]]
        start_ntype = G.to_canonical_etype(metapath[0])[0]
        end_ntype = G.to_canonical_etype(metapath[-1])[-1]
        if start_ntype != end_ntype:
            raise ValueError('The metapath must start and end at the same node type.')
        self.ntype = start_ntype

        self.metapath_hops = len(metapath)
        self.metapath = metapath
        # A single random walk repeats the metapath num_traversals times.
        self.full_metapath = metapath * num_traversals
        # One entry per step of full_metapath.  Only the entries at the first step of
        # the 2nd, 3rd, ... traversal are set, so a walk can only stop at the boundary
        # between two complete traversals (never mid-metapath, never before the first
        # traversal).  This relies on random_walk treating restart_prob[i] as the
        # probability of stopping before step i.
        restart_prob = np.zeros(self.metapath_hops * num_traversals)
        restart_prob[self.metapath_hops::self.metapath_hops] = termination_prob
        restart_prob = F.tensor(restart_prob, dtype=F.float32)
        self.restart_prob = F.copy_to(restart_prob, G.device)

    # pylint: disable=no-member
    def __call__(self, seed_nodes):
        """
        Parameters
        ----------
        seed_nodes : Tensor
            A tensor of given node IDs of node type ``ntype`` to generate neighbors from.
            The node type ``ntype`` is the beginning and ending node type of the given
            metapath.

            It must be on CPU and have the same dtype as the ID type of the graph.

        Returns
        -------
        g : DGLGraph
            A homogeneous graph constructed by selecting neighbors for each given node
            according to the algorithm above.  The returned graph is on CPU.
        """
        seed_nodes = utils.prepare_tensor(self.G, seed_nodes, 'seed_nodes')
        # Keep restart_prob on the same device as the seeds it is used with.
        self.restart_prob = F.copy_to(self.restart_prob, F.context(seed_nodes))

        # Launch num_random_walks walks from every seed node.
        seed_nodes = F.repeat(seed_nodes, self.num_random_walks, 0)
        paths, _ = random_walk(
            self.G, seed_nodes, metapath=self.full_metapath, restart_prob=self.restart_prob)
        # Take the node reached at the end of each traversal (columns hops, 2*hops, ...),
        # flattened row by row, and pair each with its walk's start node repeated once
        # per traversal so that src[i] and dst[i] come from the same walk.
        src = F.reshape(paths[:, self.metapath_hops::self.metapath_hops], (-1,))
        dst = F.repeat(paths[:, 0], self.num_traversals, 0)

        # Each seed contributes num_random_walks * num_traversals (src, dst) entries.
        src, dst, counts = _select_pinsage_neighbors(
            src, dst, (self.num_random_walks * self.num_traversals), self.num_neighbors)
        neighbor_graph = convert.heterograph(
            {(self.ntype, '_E', self.ntype): (src, dst)},
            {self.ntype: self.G.number_of_nodes(self.ntype)}
        )
        neighbor_graph.edata[self.weight_column] = counts

        return neighbor_graph

class PinSAGESampler(RandomWalkNeighborSampler):
    """PinSAGE-like neighbor sampler.

    This callable works on a bidirectional bipartite graph with edge types
    ``(ntype, fwtype, other_type)`` and ``(other_type, bwtype, ntype)`` (where ``ntype``,
    ``fwtype``, ``bwtype`` and ``other_type`` could be arbitrary type names).  It will
    generate a homogeneous graph of node type ``ntype`` where the neighbors of each given
    node are the most commonly visited nodes of the same type by multiple random walks
    starting from that given node.  Each random walk consists of multiple metapath-based
    traversals, with a probability of termination after each traversal.  The metapath is
    always ``[fwtype, bwtype]``, walking from node type ``ntype`` to node type
    ``other_type`` then back to ``ntype``.

    The edges of the returned homogeneous graph will connect to the given nodes from their most
    commonly visited nodes, with a feature indicating the number of visits.

    Parameters
    ----------
    G : DGLGraph
        The bidirectional bipartite graph.

        The graph should only have two node types: ``ntype`` and ``other_type``.
        The graph should only have two edge types, one connecting from ``ntype`` to
        ``other_type``, and another connecting from ``other_type`` to ``ntype``.
    ntype : str
        The node type on which the neighbor graph is constructed.
    other_type : str
        The other node type.
    num_traversals : int
        The maximum number of metapath-based traversals for a single random walk.

        Usually considered a hyperparameter.
    termination_prob : float
        Termination probability after each metapath-based traversal.

        Usually considered a hyperparameter.
    num_random_walks : int
        Number of random walks to try for each given node.

        Usually considered a hyperparameter.
    num_neighbors : int
        Number of neighbors (or most commonly visited nodes) to select for each given node.
    weight_column : str, default "weights"
        The name of the edge feature to be stored on the returned graph with the number of
        visits.

    Examples
    --------
    Generate a random bidirectional bipartite graph with 3000 "A" nodes and 5000 "B" nodes.

    >>> g = scipy.sparse.random(3000, 5000, 0.003)
    >>> G = dgl.heterograph({
    ...     ('A', 'AB', 'B'): g.nonzero(),
    ...     ('B', 'BA', 'A'): g.T.nonzero()})

    Then we create a PinSage neighbor sampler that samples a graph of node type "A".  Each
    node would have (a maximum of) 10 neighbors.

    >>> sampler = dgl.sampling.PinSAGESampler(G, 'A', 'B', 3, 0.5, 200, 10)

    This is how we select the neighbors for node #0, #1 and #2 of type "A" according to
    PinSAGE algorithm:

    >>> seeds = torch.LongTensor([0, 1, 2])
    >>> frontier = sampler(seeds)
    >>> frontier.all_edges(form='uv')
    (tensor([ 230,    0,  802,   47,   50, 1639, 1533,  406, 2110, 2687, 2408, 2823,
                0,  972, 1230, 1658, 2373, 1289, 1745, 2918, 1818, 1951, 1191, 1089,
             1282,  566, 2541, 1505, 1022,  812]),
     tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
             2, 2, 2, 2, 2, 2]))

    For an end-to-end example of PinSAGE model, including sampling on multiple layers
    and computing with the sampled graphs, please refer to our PinSage example
    in ``examples/pytorch/pinsage``.

    References
    ----------
    Graph Convolutional Neural Networks for Web-Scale Recommender Systems
        Ying et al., 2018, https://arxiv.org/abs/1806.01973
    """
    def __init__(self, G, ntype, other_type, num_traversals, termination_prob,
                 num_random_walks, num_neighbors, weight_column='weights'):
        metagraph = G.metagraph()
        # Look up the edge type going ntype -> other_type and the one coming back.
        # Presumably metagraph[u][v] is keyed by edge-type name; if several edge types
        # connect the pair, the first one listed is used.
        fw_etype = list(metagraph[ntype][other_type])[0]
        bw_etype = list(metagraph[other_type][ntype])[0]
        # The round trip ntype -> other_type -> ntype is the metapath for one traversal.
        super().__init__(G, num_traversals,
                         termination_prob, num_random_walks, num_neighbors,
                         metapath=[fw_etype, bw_etype], weight_column=weight_column)

# NOTE(review): presumably this registers the native functions of the
# 'dgl.sampling.pinsage' namespace into this module; it is where
# _CAPI_DGLSamplingSelectPinSageNeighbors (called above) is expected to come from.
_init_api('dgl.sampling.pinsage', __name__)