SubgraphSampler

class dgl.graphbolt.SubgraphSampler(datapipe)

Bases: torch.utils.data.datapipes.datapipe.IterDataPipe[torch.utils.data.datapipes.iter.callable.T_co]

A subgraph sampler used to sample subgraphs of a larger graph around a given set of seed nodes.

Functional name: sample_subgraph.

This class is the base class of all subgraph samplers. Any subclass of SubgraphSampler should implement the sample_subgraphs() method.
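The relationship between the base class, its functional name, and sample_subgraphs() can be modeled in plain Python. The sketch below is a hypothetical, torch-free toy: ToyPipe, toy_functional, ToySubgraphSampler, and OneHopSampler are illustrative names, not GraphBolt APIs. toy_functional imitates how functional_datapipe exposes a datapipe class as a chainable method, and the toy base class pulls seed batches from upstream and delegates each one to sample_subgraphs(). In GraphBolt the items flowing through the pipeline are minibatch objects rather than bare seed lists, so treat this only as a model of the control flow.

```python
class ToyPipe:
    """Stand-in for an IterDataPipe: wraps an iterable of seed batches."""

    def __init__(self, source):
        self.source = source

    def __iter__(self):
        return iter(self.source)


def toy_functional(name):
    """Imitates functional_datapipe: registers cls as ToyPipe.<name>()."""

    def register(cls):
        def method(self, *args, **kwargs):
            # Chaining pipe.<name>(...) builds cls(pipe, ...).
            return cls(self, *args, **kwargs)

        setattr(ToyPipe, name, method)
        return cls

    return register


@toy_functional("toy_sample_subgraph")
class ToySubgraphSampler(ToyPipe):
    """Toy base class: pulls seeds upstream, delegates to sample_subgraphs()."""

    def __iter__(self):
        for seeds in self.source:
            yield self.sample_subgraphs(seeds)

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        raise NotImplementedError("subclasses must implement sample_subgraphs()")


@toy_functional("toy_one_hop")
class OneHopSampler(ToySubgraphSampler):
    """Concrete toy subclass: one hop over an in-neighbor adjacency dict."""

    def __init__(self, datapipe, adj):
        super().__init__(datapipe)
        self.adj = adj

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        # seeds_timestamp is ignored; this toy has no temporal constraint.
        edges = [(u, v) for v in seeds for u in self.adj.get(v, [])]
        input_nodes = sorted(set(seeds) | {u for u, _ in edges})
        return input_nodes, [edges]


adj = {0: [1, 2], 1: [2], 2: []}
pipe = ToyPipe([[0], [1]]).toy_one_hop(adj)
print(list(pipe))  # [([0, 1, 2], [[(1, 0), (2, 0)]]), ([1, 2], [[(2, 1)]])]
```

Iterating a pipe built with the base functional name (toy_sample_subgraph) raises NotImplementedError, which mirrors the requirement that every subclass supply its own sample_subgraphs().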

Parameters

datapipe (DataPipe) – The upstream datapipe whose minibatches provide the seed nodes.

sample_subgraphs(seeds, seeds_timestamp=None)

Sample subgraphs from the given seeds.

Any subclass of SubgraphSampler should implement this method.

Parameters

  • seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes.

  • seeds_timestamp (optional) – The timestamps of the seed nodes, if any. Default: None.
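Because seeds is either a single collection of node IDs (homogeneous graph) or a dict keyed by node type (heterogeneous graph), a subclass has to handle both shapes. A minimal sketch of one way to do that, assuming plain Python lists stand in for torch.Tensor; map_seeds and dedup are hypothetical helpers, and the node-type names "user" and "item" are made up for illustration:

```python
def map_seeds(seeds, fn):
    """Apply fn to homogeneous seeds, or to each node type's seeds.

    A dict input (heterogeneous graph) yields a dict with the same keys;
    any other input (homogeneous graph) is passed to fn directly.
    """
    if isinstance(seeds, dict):
        return {ntype: fn(ids) for ntype, ids in seeds.items()}
    return fn(seeds)


def dedup(ids):
    # Drop duplicate seed IDs while keeping first-seen order.
    return list(dict.fromkeys(ids))


print(map_seeds([3, 1, 3], dedup))                          # [3, 1]
print(map_seeds({"user": [0, 0, 2], "item": [5]}, dedup))   # {'user': [0, 2], 'item': [5]}
```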

Returns

  • Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.

  • List[SampledSubgraph] – The sampled subgraphs.
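To make the two return values concrete, here is a minimal, torch-free sketch. It assumes the graph is a plain dict of in-neighbor lists, uses Python lists in place of tensors and edge lists in place of SampledSubgraph objects, and keeps the first fanout neighbors of each node instead of sampling at random. Like the example below, it assumes fanouts are listed from the input layer to the seed layer, so the last fanout is applied first. The standalone function sample_subgraphs here is a hypothetical helper that only mirrors the method's name.

```python
def sample_subgraphs(adj, seeds, fanouts):
    """Toy multi-hop sampler returning (input_nodes, subgraphs).

    adj maps each node to its in-neighbors. Instead of sampling at random,
    this sketch deterministically keeps the first `fanout` neighbors.
    """
    subgraphs = []
    frontier = sorted(set(seeds))
    # fanouts[-1] applies to the hop next to the seeds, so walk in reverse.
    for fanout in reversed(fanouts):
        edges = [(u, v) for v in frontier for u in adj.get(v, [])[:fanout]]
        # Prepend, so subgraphs[0] is the hop next to the input nodes.
        subgraphs.insert(0, edges)
        # Next frontier: the current nodes plus every sampled neighbor.
        frontier = sorted(set(frontier) | {u for u, _ in edges})
    # After the outermost hop, the frontier is the set of input nodes.
    return frontier, subgraphs


adj = {0: [1, 2, 3], 1: [4], 2: [4, 5]}
input_nodes, subgraphs = sample_subgraphs(adj, seeds=[0], fanouts=[2, 1])
print(input_nodes)  # [0, 1, 2, 4]
print(subgraphs)    # [[(1, 0), (2, 0), (4, 1)], [(1, 0)]]
```

Carrying the previous frontier into the next one means the seed nodes also appear among the input nodes, similar to how DGL message-flow graphs include destination nodes among their source nodes.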

Examples

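>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
... class MySubgraphSampler(SubgraphSampler):
...     def __init__(self, datapipe, graph, fanouts):
...         super().__init__(datapipe)
...         self.graph = graph
...         self.fanouts = fanouts
...     def sample_subgraphs(self, seeds, seeds_timestamp=None):
...         # Sample one hop per fanout, starting from the seed layer.
...         subgraphs = []
...         subgraphs_nodes = []
...         for fanout in reversed(self.fanouts):
...             subgraph = self.graph.sample_neighbors(seeds, fanout)
...             # Prepend so the list runs from the input layer to the seeds.
...             subgraphs.insert(0, subgraph)
...             # Assumes the sampled subgraph exposes its node IDs as
...             # `.nodes`; adapt this to the subgraph type you use.
...             subgraphs_nodes.append(subgraph.nodes)
...             seeds = subgraph.nodes
...         # The input nodes are all unique nodes touched by any hop.
...         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
...         return subgraphs_nodes, subgraphs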