InSubgraphSampler
class dgl.graphbolt.InSubgraphSampler(datapipe, graph)

Bases: torch.utils.data.datapipes.datapipe.IterDataPipe[torch.utils.data.datapipes.iter.callable.T_co]

Sample the subgraph induced on the inbound edges of the given nodes.
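Functional name: sample_in_subgraph.

The in-subgraph sampler takes each batch of seed nodes from the input datapipe, collects all inbound edges of those seeds in graph, and returns the induced subgraph in compacted form: node IDs are relabeled to a contiguous local range, and the mapping back to the original IDs is kept in original_row_node_ids and original_column_node_ids.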
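The functional name means the sampler can be attached by method chaining on the upstream datapipe, so `item_sampler.sample_in_subgraph(graph)` builds the same pipeline as `gb.InSubgraphSampler(item_sampler, graph)`. The plain-Python sketch below imitates that registration pattern under simplifying assumptions. `ToyDataPipe`, `functional_name`, and `ToyInSubgraphSampler` are hypothetical names, not the mechanism PyTorch actually uses (that is the `functional_datapipe` decorator), and the toy sampler only tags each batch instead of sampling anything.

```python
# Hypothetical, simplified imitation of functional-name registration for datapipes.
class ToyDataPipe:
    """Minimal iterable pipe that wraps a source iterable."""

    _functions = {}  # registry: functional name -> pipe class

    def __init__(self, source):
        self.source = source

    def __iter__(self):
        return iter(self.source)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails: resolve a registered
        # functional name and bind this pipe as the new pipe's source.
        cls = ToyDataPipe._functions.get(name)
        if cls is None:
            raise AttributeError(name)
        return lambda *args, **kwargs: cls(self, *args, **kwargs)


def functional_name(name):
    """Register a ToyDataPipe subclass under `name` (hypothetical helper)."""
    def decorator(cls):
        ToyDataPipe._functions[name] = cls
        return cls
    return decorator


@functional_name("sample_in_subgraph")
class ToyInSubgraphSampler(ToyDataPipe):
    def __init__(self, datapipe, graph):
        super().__init__(datapipe)
        self.graph = graph  # placeholder: no real sampling happens in this toy

    def __iter__(self):
        for seeds in self.source:
            yield ("sampled", seeds)


batches = ToyDataPipe([[0, 1], [2, 3]])
chained = batches.sample_in_subgraph(graph=None)      # functional form
explicit = ToyInSubgraphSampler(batches, graph=None)  # class form
print(list(chained) == list(explicit))  # True
```

Both forms produce a sampler whose source is the upstream pipe; the functional form simply reads left to right when several stages are chained.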
Parameters
- datapipe (DataPipe) – The datapipe that yields batches of seed nodes, e.g. an ItemSampler.
- graph (FusedCSCSamplingGraph) – The graph on which to perform in_subgraph sampling.
Examples
>>> import dgl.graphbolt as gb
>>> import torch
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> item_set = gb.ItemSet(len(indptr) - 1, names="seed_nodes")
>>> item_sampler = gb.ItemSampler(item_set, batch_size=2)
>>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
>>> for data in insubgraph_sampler:
...     print(data.sampled_subgraphs[0].sampled_csc)
...     print(data.sampled_subgraphs[0].original_row_node_ids)
...     print(data.sampled_subgraphs[0].original_column_node_ids)
CSCFormatBase(indptr=tensor([0, 3, 5]),
              indices=tensor([0, 1, 2, 3, 4]),
)
tensor([0, 1, 4, 2, 3])
tensor([0, 1])
CSCFormatBase(indptr=tensor([0, 2, 4]),
              indices=tensor([2, 3, 4, 0]),
)
tensor([2, 3, 0, 5, 1])
tensor([2, 3])
CSCFormatBase(indptr=tensor([0, 3, 5]),
              indices=tensor([2, 3, 1, 4, 0]),
)
tensor([4, 5, 0, 3, 1])
tensor([4, 5])
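To make the compaction in the example above concrete, here is a plain-Python sketch (no torch or GraphBolt) of one way to extract the in-subgraph of a seed batch from CSC arrays. The function `toy_in_subgraph` is hypothetical. It assumes the seeds come first in the row IDs and other source nodes are numbered in order of first appearance; that ordering reproduces the outputs printed above, but it is not guaranteed to match GraphBolt's internal kernel.

```python
def toy_in_subgraph(indptr, indices, seeds):
    """Hypothetical sketch: gather all in-edges of `seeds` from a CSC graph
    and relabel node IDs into a compact local range.

    Returns (compacted_indptr, compacted_indices, row_node_ids, column_node_ids).
    """
    # Row IDs start with the seeds, so seed i keeps local ID i.
    row_node_ids = list(seeds)
    local_id = {node: i for i, node in enumerate(row_node_ids)}
    compacted_indptr = [0]
    compacted_indices = []
    for seed in seeds:
        # In CSC, the in-neighbors of `seed` are indices[indptr[seed]:indptr[seed + 1]].
        for src in indices[indptr[seed]:indptr[seed + 1]]:
            if src not in local_id:
                # First time this source node appears: give it the next local ID.
                local_id[src] = len(row_node_ids)
                row_node_ids.append(src)
            compacted_indices.append(local_id[src])
        compacted_indptr.append(len(compacted_indices))
    return compacted_indptr, compacted_indices, row_node_ids, list(seeds)


indptr = [0, 3, 5, 7, 9, 12, 14]
indices = [0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]
for seeds in ([0, 1], [2, 3], [4, 5]):
    print(toy_in_subgraph(indptr, indices, seeds))
# ([0, 3, 5], [0, 1, 2, 3, 4], [0, 1, 4, 2, 3], [0, 1])
# ([0, 2, 4], [2, 3, 4, 0], [2, 3, 0, 5, 1], [2, 3])
# ([0, 3, 5], [2, 3, 1, 4, 0], [4, 5, 0, 3, 1], [4, 5])
```

Each compacted index is a position in the row ID list, so `row_node_ids[i]` recovers the original ID of local node `i`, which is the role original_row_node_ids plays in the example.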
sample_subgraphs(seeds, seeds_timestamp=None)

Sample subgraphs from the given seeds.
Any subclass of SubgraphSampler should implement this method.
Parameters
- seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes.

Returns
- Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.
- List[SampledSubgraph] – The sampled subgraphs.
Examples
>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
... class MySubgraphSampler(SubgraphSampler):
...     def __init__(self, datapipe, graph, fanouts):
...         super().__init__(datapipe)
...         self.graph = graph
...         self.fanouts = fanouts
...     def sample_subgraphs(self, seeds, seeds_timestamp=None):
...         # Sample one layer per fanout, starting from the output layer.
...         subgraphs = []
...         subgraphs_nodes = []
...         for fanout in reversed(self.fanouts):
...             subgraph = self.graph.sample_neighbors(seeds, fanout)
...             # Prepend so that subgraphs[0] is the input-side layer.
...             subgraphs.insert(0, subgraph)
...             subgraphs_nodes.append(subgraph.nodes)
...             # Nodes reached in this layer become the seeds of the next one.
...             seeds = subgraph.nodes
...         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
...         return subgraphs_nodes, subgraphs
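The layering logic in the example above (iterate the fanouts in reverse, prepend each subgraph, feed each layer's nodes into the next) can be traced without a real graph. The plain-Python sketch below is hypothetical: `toy_sample_neighbors` keeps the first `fanout` in-neighbors of each seed instead of sampling at random, and each subgraph is a plain dict rather than a `SampledSubgraph`.

```python
def toy_sample_neighbors(indptr, indices, seeds, fanout):
    """Hypothetical stand-in for neighbor sampling: keep the first `fanout`
    in-neighbors of each seed (no randomness) and compact node IDs."""
    row_node_ids = list(seeds)  # seeds first, so seed i keeps local ID i
    local_id = {node: i for i, node in enumerate(row_node_ids)}
    csc_indptr, csc_indices = [0], []
    for seed in seeds:
        for src in indices[indptr[seed]:indptr[seed + 1]][:fanout]:
            if src not in local_id:
                local_id[src] = len(row_node_ids)
                row_node_ids.append(src)
            csc_indices.append(local_id[src])
        csc_indptr.append(len(csc_indices))
    return {
        "indptr": csc_indptr,
        "indices": csc_indices,
        "original_row_node_ids": row_node_ids,
        "original_column_node_ids": list(seeds),
    }


def toy_sample_subgraphs(indptr, indices, seeds, fanouts):
    """Follow the sample_subgraphs contract: return (input_nodes, subgraphs)."""
    subgraphs = []
    # Sample from the output layer inward, so walk the fanouts in reverse...
    for fanout in reversed(fanouts):
        subgraph = toy_sample_neighbors(indptr, indices, seeds, fanout)
        # ...and prepend, so subgraphs[i] is the layer sampled with fanouts[i].
        subgraphs.insert(0, subgraph)
        # This layer's row nodes are the seeds of the next (inner) layer.
        seeds = subgraph["original_row_node_ids"]
    return seeds, subgraphs


indptr = [0, 3, 5, 7, 9, 12, 14]
indices = [0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]
input_nodes, subgraphs = toy_sample_subgraphs(indptr, indices, [0], fanouts=[1, 2])
print(input_nodes)  # [0, 1, 2]
```

Because each layer's row IDs begin with that layer's seeds, the last layer's row IDs already contain every node visited, so this sketch returns them directly as the input nodes instead of concatenating per-layer nodes and taking a union.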