dgl.dataloading.NeighborSampler
class dgl.dataloading.NeighborSampler(fanouts, edge_dir='in', prob=None, replace=False, prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None)
Sampler that builds the computational dependencies of node representations via neighbor sampling for multilayer GNNs.
This sampler makes every node gather messages from a fixed number of neighbors per edge type. By default the neighbors are picked uniformly at random; pass prob for non-uniform sampling.
Parameters
fanouts (list[int] or list[dict[etype, int]]) –
List of the number of neighbors to sample per edge type for each GNN layer, with the i-th element being the fanout for the i-th GNN layer.
If an element is a single integer, DGL assumes that every edge type has that same fanout in the corresponding layer.
If -1 is provided for one edge type on one layer, then all inbound edges of that edge type are included for that layer.
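As a rough illustration of these rules, the hypothetical helper below (not part of DGL) expands each layer's fanout into an explicit {edge type: fanout} dict. It assumes edge types are given as canonical (source type, edge type, destination type) triples and, as a simplifying choice of this sketch, requires a dict to name every edge type.

```python
from typing import Dict, List, Tuple, Union

# Canonical edge type: (source node type, edge type name, destination node type).
CanonicalEType = Tuple[str, str, str]
LayerFanout = Union[int, Dict[CanonicalEType, int]]


def expand_fanouts(fanouts: List[LayerFanout],
                   etypes: List[CanonicalEType]) -> List[Dict[CanonicalEType, int]]:
    """Hypothetical helper (not a DGL API): one explicit fanout dict per GNN layer.

    An int applies to every edge type of that layer; -1 means "take all edges".
    """
    expanded = []
    for layer, layer_fanout in enumerate(fanouts):
        if isinstance(layer_fanout, int):
            # A single integer: the same fanout for every edge type in this layer.
            expanded.append({etype: layer_fanout for etype in etypes})
        else:
            # A dict: this sketch insists that every edge type is listed.
            missing = set(etypes) - set(layer_fanout)
            if missing:
                raise ValueError(f"layer {layer}: no fanout for {sorted(missing)}")
            expanded.append(dict(layer_fanout))
    return expanded


etypes = [("user", "follows", "user"), ("user", "plays", "game")]
print(expand_fanouts([5, {etypes[0]: 2, etypes[1]: -1}], etypes))
```

Here the first layer samples 5 neighbors along both edge types, while the second samples 2 "follows" neighbors and keeps every "plays" edge.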
edge_dir (str, default 'in') – Either 'in', where the neighbors are sampled along incoming edges, or 'out', where they are sampled along outgoing edges; same as in dgl.sampling.sample_neighbors().
prob (str, optional) – If given, the probability of each neighbor being sampled is proportional to the edge feature value with the given name in g.edata. The feature must be a scalar on each edge.
replace (bool, default False) – Whether to sample with replacement.
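To make edge_dir and replace concrete, here is a minimal pure-Python sketch of one sampling step on a homogeneous graph stored as two parallel edge lists. The function name and graph representation are assumptions for illustration; this is not how DGL implements sampling internally. When sampling without replacement, this sketch keeps all candidates of a node that has fewer of them than the fanout.

```python
import random
from collections import defaultdict


def sample_one_layer(src, dst, seeds, fanout, edge_dir="in", replace=False, rng=random):
    """Hypothetical sketch: pick up to `fanout` neighbors of each seed node.

    src[i] -> dst[i] is edge i. With edge_dir="in" a seed's neighbors are the
    sources of its incoming edges; with edge_dir="out" they are the
    destinations of its outgoing edges.
    """
    if edge_dir not in ("in", "out"):
        raise ValueError("edge_dir must be 'in' or 'out'")
    nbrs = defaultdict(list)
    for u, v in zip(src, dst):
        if edge_dir == "in":
            nbrs[v].append(u)
        else:
            nbrs[u].append(v)

    picked = {}
    for s in seeds:
        cand = nbrs[s]
        if fanout == -1 or not cand:
            picked[s] = list(cand)                   # all neighbors (possibly none)
        elif replace:
            picked[s] = rng.choices(cand, k=fanout)  # exactly `fanout` draws, repeats allowed
        else:
            picked[s] = rng.sample(cand, min(fanout, len(cand)))  # each edge at most once
    return picked


src, dst = [0, 1, 2], [3, 3, 3]
print(sample_one_layer(src, dst, [3], fanout=5))                  # all three in-neighbors
print(sample_one_layer(src, dst, [3], fanout=5, replace=True))    # five draws from {0, 1, 2}
print(sample_one_layer(src, dst, [0], fanout=1, edge_dir="out"))  # {0: [3]}
```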
prefetch_node_feats (list[str] or dict[ntype, list[str]], optional) – The source node data to prefetch for the first MFG, corresponding to the input node features necessary for the first GNN layer.
prefetch_labels (list[str] or dict[ntype, list[str]], optional) – The destination node data to prefetch for the last MFG, corresponding to the node labels of the minibatch.
prefetch_edge_feats (list[str] or dict[etype, list[str]], optional) – The edge data names to prefetch for all the MFGs, corresponding to the edge features necessary for all GNN layers.
output_device (device, optional) – The device of the output subgraphs or MFGs. Defaults to the device of the minibatch of seed nodes.
Examples
Node classification
To train a 3-layer GNN for node classification on a set of nodes train_nid on a homogeneous graph where each node takes messages from 5, 10, and 15 neighbors for the first, second, and third layer respectively (assuming the backend is PyTorch):
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(blocks)  # placeholder for the user's training step
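Conceptually, the blocks for a multilayer model are built from the seed nodes outward, so the fanout of the last layer is applied first. The pure-Python sketch below illustrates that idea on a homogeneous graph with uniform sampling along incoming edges; sample_mfgs and the (src_nodes, dst_nodes, edges) tuple used for each block are hypothetical stand-ins for DGL's MFG objects, not its actual implementation.

```python
import random
from collections import defaultdict


def sample_mfgs(src, dst, seeds, fanouts, rng=random):
    """Hypothetical sketch: one (src_nodes, dst_nodes, edges) tuple per GNN layer.

    blocks[i] uses fanouts[i]. Sampling starts at the seeds and walks the
    fanouts in reverse, because the last layer produces the seeds' outputs.
    """
    in_nbrs = defaultdict(list)  # node -> sources of its incoming edges
    for u, v in zip(src, dst):
        in_nbrs[v].append(u)

    blocks = []
    frontier = list(dict.fromkeys(seeds))  # deduplicate while keeping order
    for fanout in reversed(fanouts):
        edges = []
        for v in frontier:
            cand = in_nbrs[v]
            k = len(cand) if fanout == -1 else min(fanout, len(cand))
            edges.extend((u, v) for u in rng.sample(cand, k))
        # Destination nodes are listed first among the source nodes, so each
        # layer can also read a node's own representation from the layer below.
        src_nodes = list(dict.fromkeys(frontier + [u for u, _ in edges]))
        blocks.insert(0, (src_nodes, frontier, edges))
        frontier = src_nodes  # the earlier layer must compute all of these
    return blocks


src = [1, 2, 3, 4, 5, 6]
dst = [0, 0, 0, 1, 2, 3]
blocks = sample_mfgs(src, dst, seeds=[0], fanouts=[2, 2])
input_nodes = blocks[0][0]    # nodes whose input features the first layer needs
output_nodes = blocks[-1][1]  # the seed nodes, here [0]
print(input_nodes, output_nodes)
```

In this sketch, input_nodes and output_nodes play the roles of the first two values yielded by the DataLoader loop above; they are also the nodes whose data prefetch_node_feats and prefetch_labels refer to.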
To sample a different number of neighbors for each edge type on a heterogeneous graph, provide a list of dicts instead. Each dict specifies the number of neighbors to pick per edge type for one layer.
>>> sampler = dgl.dataloading.NeighborSampler([
...     {('user', 'follows', 'user'): 5,
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3}] * 3)
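With a list of dicts, each layer samples every edge type independently: a destination node of type T draws up to the given fanout from each edge type whose destination type is T. The sketch below shows one such layer on a toy heterograph stored as nested Python dicts; the function and its data layout are hypothetical, chosen only for illustration.

```python
import random


def sample_hetero_layer(in_edges, seeds, fanouts, rng=random):
    """Hypothetical sketch of one layer of per-edge-type sampling (edge_dir="in").

    in_edges: {(srctype, etype, dsttype): {dst_id: [src_id, ...]}}
    seeds:    {node_type: [node_id, ...]}
    fanouts:  {(srctype, etype, dsttype): int}, -1 meaning "all edges"
    Returns   {(srctype, etype, dsttype): [(src_id, dst_id), ...]}
    """
    sampled = {}
    for cetype, adj in in_edges.items():
        dsttype = cetype[2]
        fanout = fanouts[cetype]
        edges = []
        for v in seeds.get(dsttype, []):  # only seeds of the destination type
            cand = adj.get(v, [])
            k = len(cand) if fanout == -1 else min(fanout, len(cand))
            edges.extend((u, v) for u in rng.sample(cand, k))
        sampled[cetype] = edges
    return sampled


in_edges = {
    ("user", "follows", "user"): {0: [1, 2, 3, 4, 5, 6]},
    ("user", "plays", "game"): {0: [0, 1]},
    ("game", "played-by", "user"): {0: [0, 1, 2, 3]},
}
fanouts = {("user", "follows", "user"): 5,
           ("user", "plays", "game"): 4,
           ("game", "played-by", "user"): 3}
layer = sample_hetero_layer(in_edges, {"user": [0]}, fanouts)
```

Because only users are seeds here, the ('user', 'plays', 'game') relation contributes no edges in this layer; game nodes enter the computation only as sources of 'played-by' edges.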
If you would like non-uniform neighbor sampling:
>>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p')
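The effect of prob can be pictured as weighted sampling over a node's candidate edges, using the values stored in g.edata['p'] as weights. The sketch below uses Python's random.choices for sampling with replacement and repeated single weighted draws for sampling without replacement; this is one common approach and not necessarily the algorithm DGL uses internally. The function name is hypothetical.

```python
import random


def weighted_sample(candidates, weights, k, replace=False, rng=random):
    """Hypothetical sketch: pick k candidates with probability proportional to weight.

    Weights must be non-negative; a candidate with weight 0 is never picked.
    """
    if replace:
        # Independent draws; the same candidate may appear several times.
        return rng.choices(candidates, weights=weights, k=k)
    # Without replacement: draw one candidate at a time, then remove it.
    pool = [(c, w) for c, w in zip(candidates, weights) if w > 0]
    picked = []
    while pool and len(picked) < k:
        idx = rng.choices(range(len(pool)), weights=[w for _, w in pool])[0]
        picked.append(pool.pop(idx)[0])
    return picked


# Candidates could be a node's in-neighbors, weights the matching g.edata['p'] values.
print(weighted_sample(["a", "b", "c"], [0.0, 1.0, 3.0], k=2))  # "b" and "c" in some order
```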
Edge classification and link prediction
This class can also be used for edge classification and link prediction together with as_edge_prediction_sampler().
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
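Roughly speaking, an edge-prediction wrapper turns each minibatch of edge IDs into seed nodes by collecting the edges' endpoints, then lets the wrapped node sampler build MFGs around those nodes. The sketch below shows only that first step for a homogeneous graph; seed_nodes_from_edges is a hypothetical name, and negative sampling and edge exclusion, which as_edge_prediction_sampler() can also handle, are left out.

```python
def seed_nodes_from_edges(src, dst, eids):
    """Hypothetical sketch: unique endpoints of the seed edges, in first-seen order.

    src[e] -> dst[e] is edge e; eids is one minibatch of edge IDs.
    """
    seen = set()
    nodes = []
    for e in eids:
        for n in (src[e], dst[e]):
            if n not in seen:
                seen.add(n)
                nodes.append(n)
    return nodes


src, dst = [0, 1, 2, 3], [1, 2, 3, 0]
print(seed_nodes_from_edges(src, dst, [0, 1]))  # [0, 1, 2]
```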
See the documentation of as_edge_prediction_sampler() for more details.
Notes
For the concept of MFGs (message flow graphs), please refer to User Guide Section 6 and the Minibatch Training Tutorials.
__init__(fanouts, edge_dir='in', prob=None, replace=False, prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None)
Initialize self. See help(type(self)) for accurate signature.
Methods
__init__(fanouts[, edge_dir, prob, replace, …]) – Initialize self.
assign_lazy_features(result) – Assign lazy features for prefetching.
sample(g, seed_nodes[, exclude_eids]) – Sample a list of blocks from the given seed nodes.
sample_blocks(g, seed_nodes[, exclude_eids]) – Generate a list of blocks from the given seed nodes.
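Both sample() and sample_blocks() accept an optional exclude_eids argument. In link prediction it is commonly used to keep the minibatch's own target edges out of the sampled neighborhoods, so the model cannot simply read off the edges it is asked to predict. The sketch below shows the idea on a homogeneous graph: excluded edge IDs are filtered out before uniform sampling along incoming edges. The function is hypothetical and does not reproduce DGL's implementation.

```python
import random
from collections import defaultdict


def sample_in_edges_excluding(dst, seeds, fanout, exclude_eids=(), rng=random):
    """Hypothetical sketch: sample up to `fanout` inbound edge IDs per seed,
    never returning an ID listed in exclude_eids.

    dst[e] is the destination node of edge e.
    """
    excluded = set(exclude_eids)
    in_eids = defaultdict(list)  # node -> IDs of its inbound, non-excluded edges
    for eid, v in enumerate(dst):
        if eid not in excluded:
            in_eids[v].append(eid)

    result = {}
    for v in seeds:
        cand = in_eids[v]
        k = len(cand) if fanout == -1 else min(fanout, len(cand))
        result[v] = rng.sample(cand, k)
    return result


dst = [0, 0, 0, 1]
# Node 0 keeps edges 0 and 2; node 1 keeps none, since its only edge is excluded.
print(sample_in_edges_excluding(dst, [0, 1], fanout=-1, exclude_eids=[1, 3]))
```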