NeighborSampler

class dgl.dataloading.NeighborSampler(fanouts, edge_dir='in', prob=None, replace=False, prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None)

Bases: dgl.dataloading.base.BlockSampler

Sampler that builds the computational dependency of node representations via neighbor sampling for multilayer GNNs.

This sampler makes every node gather messages from a fixed number of neighbors per edge type. By default, the neighbors are picked uniformly; if prob is given, they are picked with probability proportional to the named edge feature.
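The layer-wise idea described above can be sketched in plain Python. This is only an illustration under stated assumptions, not DGL's implementation: the graph is a hypothetical dict mapping each node to a list of its in-neighbors, and the helper names sample_in_neighbors and sample_layers are made up for this example. Sampling starts from the seed nodes with the last layer's fanout and walks backwards to the first layer, so the i-th returned edge list corresponds to the i-th GNN layer.

```python
import random


def sample_in_neighbors(in_adj, node, fanout, rng):
    """Pick up to `fanout` in-neighbors of `node` uniformly, without replacement."""
    neighbors = in_adj.get(node, [])
    # A fanout of -1 means "take every neighbor"; a node with fewer
    # neighbors than the fanout simply contributes all of them.
    if fanout == -1 or len(neighbors) <= fanout:
        return list(neighbors)
    return rng.sample(neighbors, fanout)


def sample_layers(in_adj, seeds, fanouts, seed=0):
    """Return (input_nodes, layers), where layers[i] holds the sampled
    (src, dst) edges for the i-th GNN layer (hypothetical helper)."""
    rng = random.Random(seed)
    layers = []
    frontier = list(seeds)
    # Walk from the output layer back to the input layer.
    for fanout in reversed(fanouts):
        edges = [(src, dst)
                 for dst in frontier
                 for src in sample_in_neighbors(in_adj, dst, fanout, rng)]
        layers.insert(0, edges)
        # The next frontier keeps the current nodes and adds new sources.
        seen = set(frontier)
        for src, _ in edges:
            if src not in seen:
                seen.add(src)
                frontier.append(src)
    return frontier, layers


# Toy graph: node -> list of in-neighbors (an assumption for illustration).
toy_in_adj = {0: [1, 2, 3], 1: [2, 4], 2: [0], 3: [], 4: [1, 3]}
input_nodes, layers = sample_layers(toy_in_adj, seeds=[0], fanouts=[2, 2])
```

Here layers[-1] contains only edges pointing at the seed node, while input_nodes lists every node whose features the first layer would need.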

Parameters
  • fanouts (list[int] or list[dict[etype, int]]) –

    Number of neighbors to sample per edge type for each GNN layer, with the i-th element being the fanout for the i-th GNN layer.

    If an element is a single integer, DGL assumes that every edge type has that fanout on that layer.

    If -1 is provided for one edge type on one layer, then all inbound edges of that edge type will be included.

  • edge_dir (str, default 'in') – Either 'in', where the neighbors are sampled according to incoming edges, or 'out', where they are sampled according to outgoing edges; same as in dgl.sampling.sample_neighbors().

  • prob (str, optional) – If given, the probability of each neighbor being sampled is proportional to the edge feature value with the given name in g.edata. The feature must be a scalar on each edge.

  • replace (bool, default False) – Whether to sample with replacement.

  • prefetch_node_feats (list[str] or dict[ntype, list[str]], optional) – The source node data to prefetch for the first MFG, corresponding to the input node features necessary for the first GNN layer.

  • prefetch_labels (list[str] or dict[ntype, list[str]], optional) – The destination node data to prefetch for the last MFG, corresponding to the node labels of the minibatch.

  • prefetch_edge_feats (list[str] or dict[etype, list[str]], optional) – The edge data names to prefetch for all the MFGs, corresponding to the edge features necessary for all GNN layers.

  • output_device (device, optional) – The device of the output subgraphs or MFGs. Defaults to the device of the minibatch of seed nodes.
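To make the fanouts conventions concrete, the following plain-Python sketch expands a fanouts list into one dict per layer and computes how many neighbors a node would contribute when sampling without replacement. Both helpers, expand_fanouts and effective_count, are hypothetical names introduced for this illustration and are not part of the DGL API.

```python
def expand_fanouts(fanouts, canonical_etypes):
    """Turn each per-layer fanout into a dict keyed by edge type.

    An integer entry is applied to every edge type; a dict entry is
    copied as given. (Hypothetical helper, not a DGL function.)
    """
    expanded = []
    for layer_fanout in fanouts:
        if isinstance(layer_fanout, dict):
            expanded.append(dict(layer_fanout))
        else:
            expanded.append({etype: layer_fanout for etype in canonical_etypes})
    return expanded


def effective_count(fanout, degree):
    """Number of neighbors picked without replacement for one node.

    -1 selects every neighbor; otherwise at most `fanout` are picked.
    """
    return degree if fanout == -1 else min(fanout, degree)


follows = ('user', 'follows', 'user')
plays = ('user', 'plays', 'game')
per_layer = expand_fanouts([5, {follows: 10, plays: -1}], [follows, plays])
```

In this sketch the first layer uses a fanout of 5 for both edge types, while the second layer samples 10 'follows' neighbors and keeps every 'plays' neighbor.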

Examples

Node classification

To train a 3-layer GNN for node classification on a set of nodes train_nid on a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for the first, second, and third layer respectively (assuming the backend is PyTorch):

>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(blocks)

If you are training on a heterogeneous graph and want a different number of neighbors for each edge type, provide a list of dicts instead. Each dict specifies the number of neighbors to pick per edge type for one layer.

>>> sampler = dgl.dataloading.NeighborSampler([
...     {('user', 'follows', 'user'): 5,
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3}] * 3)
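Note that multiplying a list by 3, as above, repeats the same dict object three times, which is fine as long as every layer uses identical fanouts. If the layers should differ, one option is to build independent copies and adjust them, as in this plain-Python sketch (the variable names are illustrative only):

```python
follows = ('user', 'follows', 'user')
plays = ('user', 'plays', 'game')
played_by = ('game', 'played-by', 'user')

base = {follows: 5, plays: 4, played_by: 3}

# `[base] * 3` yields three references to one dict.
shared = [base] * 3

# Independent copies let each layer carry its own fanouts.
fanouts = [dict(base) for _ in range(3)]
fanouts[0][follows] = 10   # sample more 'follows' neighbors in the first layer
fanouts[2][plays] = -1     # keep all 'plays' neighbors in the last layer
```

The resulting fanouts list has the same shape as the list of dicts passed to NeighborSampler above.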

If you would like non-uniform neighbor sampling:

>>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p')
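As a rough illustration of what "proportional to the edge feature value" means, here is a plain-Python sketch of weighted sampling without replacement for a single node. It is not DGL's algorithm; the function name is hypothetical, and the handling of zero-weight edges (they are never picked) is an assumption made for this example.

```python
import random


def weighted_sample_without_replacement(neighbors, weights, fanout, rng):
    """Pick up to `fanout` distinct neighbors, each draw proportional to its weight."""
    # Assumption: edges with zero weight are never eligible.
    pool = [(n, w) for n, w in zip(neighbors, weights) if w > 0]
    if len(pool) <= fanout:
        return [n for n, _ in pool]
    picked = []
    for _ in range(fanout):
        current_weights = [w for _, w in pool]
        # Draw one index proportionally to the remaining weights,
        # then remove that neighbor so it cannot be drawn again.
        idx = rng.choices(range(len(pool)), weights=current_weights, k=1)[0]
        picked.append(pool.pop(idx)[0])
    return picked


rng = random.Random(0)
picked = weighted_sample_without_replacement(
    ['a', 'b', 'c', 'd'], [0.0, 1.0, 2.0, 3.0], fanout=2, rng=rng)
```

With these weights, 'a' can never appear in picked, and 'd' is three times as likely as 'b' to be chosen on the first draw.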

Edge classification and link prediction

This class can also work for edge classification and link prediction together with as_edge_prediction_sampler().

>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)

See the documentation of as_edge_prediction_sampler() for more details.

Notes

For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.