SAINTSampler
class dgl.dataloading.SAINTSampler(mode, budget, cache=True, prefetch_ndata=None, prefetch_edata=None, output_device='cpu')
Bases: dgl.dataloading.base.Sampler
Random node/edge/walk sampler from GraphSAINT: Graph Sampling Based Inductive Learning Method.
On each call, the sampler samples a node subset and returns the subgraph induced by those nodes. There are three options for sampling node subsets:
- For the 'node' sampler, the probability of sampling a node is proportional to its out-degree.
- The 'edge' sampler first samples an edge subset and then uses the end nodes of those edges.
- The 'walk' sampler uses the nodes visited by random walks. It uniformly selects a number of root nodes and then performs a fixed-length random walk from each root node.
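The three strategies above can be illustrated with a minimal pure-Python sketch. This is not DGL's implementation: the function names are hypothetical, the graph is a plain edge list, and two simplifications are assumed (nodes are drawn with replacement, and edges are drawn uniformly, since the text above does not specify the edge distribution). It only shows how each mode turns a budget into a node subset and how that subset induces a subgraph.

```python
import random


def node_subset_by_degree(out_degrees, budget, rng):
    """'node' mode sketch: draw `budget` nodes, weighted by out-degree."""
    nodes = list(range(len(out_degrees)))
    # Assumption: draws are with replacement, so duplicates collapse in the set.
    return set(rng.choices(nodes, weights=out_degrees, k=budget))


def node_subset_by_edges(edges, budget, rng):
    """'edge' mode sketch: draw `budget` edges, keep their end nodes."""
    # Assumption: edges are drawn uniformly without replacement.
    picked = rng.sample(edges, k=min(budget, len(edges)))
    return {node for edge in picked for node in edge}


def node_subset_by_walks(adj, num_roots, length, rng):
    """'walk' mode sketch: uniform roots, then fixed-length random walks."""
    roots = rng.choices(list(adj), k=num_roots)
    visited = set(roots)
    for node in roots:
        for _ in range(length):
            if not adj[node]:  # dead end: stop this walk early
                break
            node = rng.choice(adj[node])
            visited.add(node)
    return visited


def induced_subgraph(edges, nodes):
    """Keep only the edges whose two end nodes are both in `nodes`."""
    return [(u, v) for u, v in edges if u in nodes and v in nodes]


# A toy directed graph with 5 nodes (hypothetical data).
edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 0), (3, 4)]
out_degrees = [2, 1, 1, 2, 0]
adj = {0: [1, 2], 1: [2], 2: [3], 3: [0, 4], 4: []}

rng = random.Random(0)
by_node = node_subset_by_degree(out_degrees, budget=3, rng=rng)
by_edge = node_subset_by_edges(edges, budget=2, rng=rng)
by_walk = node_subset_by_walks(adj, num_roots=2, length=3, rng=rng)

# Each subset is then turned into a node-induced subgraph.
sub_edges = induced_subgraph(edges, {0, 1, 2})  # [(0, 1), (0, 2), (1, 2)]
```

In all three cases the sampled edges or walks are only a means of choosing nodes; the returned subgraph contains every original edge between the chosen nodes.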
- Parameters
mode (str) – The sampler to use, which can be 'node', 'edge', or 'walk'.
budget (int or tuple[int]) – Sampler configuration.
  - For the 'node' sampler, budget specifies the number of nodes in each sampled subgraph.
  - For the 'edge' sampler, budget specifies the number of edges to sample for inducing a subgraph.
  - For the 'walk' sampler, budget is a tuple. budget[0] specifies the number of root nodes from which to generate random walks, and budget[1] specifies the length of each random walk.
cache (bool, optional) – If False, it will not cache the probability arrays for sampling. Setting it to False is required if you want to use the sampler across different graphs.
prefetch_ndata (list[str], optional) – The node data to prefetch for the subgraph. See guide-minibatch-prefetching for a detailed explanation of prefetching.
prefetch_edata (list[str], optional) – The edge data to prefetch for the subgraph. See guide-minibatch-prefetching for a detailed explanation of prefetching.
output_device (device, optional) – The device of the output subgraphs.
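Because the expected shape of budget depends on mode, it can help to check it before constructing the sampler. The sketch below is one possible way to do that. check_budget and make_saint_sampler are hypothetical helpers, not part of DGL, and the checks only encode the shapes described in the parameter list above. The constructor call uses the keyword names from the signature on this page.

```python
def check_budget(mode, budget):
    """Hypothetical helper: validate `budget` against the shape each mode expects."""
    if mode in ("node", "edge"):
        # A single integer: number of nodes ('node') or edges ('edge').
        if isinstance(budget, bool) or not isinstance(budget, int):
            raise TypeError(f"budget must be an int for mode {mode!r}")
    elif mode == "walk":
        # A pair: (number of root nodes, length of each random walk).
        is_pair = isinstance(budget, tuple) and len(budget) == 2
        if not is_pair or not all(isinstance(b, int) for b in budget):
            raise TypeError("budget must be a tuple of two ints for mode 'walk'")
    else:
        raise ValueError("mode must be 'node', 'edge', or 'walk'")
    return budget


def make_saint_sampler(mode, budget, **kwargs):
    """Hypothetical wrapper: validate the budget, then build the sampler."""
    # Imported lazily so the validation above can be used without DGL installed.
    from dgl.dataloading import SAINTSampler

    return SAINTSampler(mode=mode, budget=check_budget(mode, budget), **kwargs)


# Example calls (they require DGL, so they are left commented out):
# sampler = make_saint_sampler("node", 6000)
# sampler = make_saint_sampler("walk", (3000, 2), cache=False,
#                              prefetch_ndata=["feat", "label"])
```

The second commented call passes cache=False, which the parameter description says is required when the same sampler is used across different graphs. The 'feat' and 'label' names are the node data keys assumed in the example below.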
Examples
>>> import torch
>>> from dgl.dataloading import SAINTSampler, DataLoader
>>> num_iters = 1000
>>> sampler = SAINTSampler(mode='node', budget=6000)
>>> # Assume g.ndata['feat'] and g.ndata['label'] hold node features and labels
>>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)