NeighborSampler

class dgl.graphbolt.NeighborSampler(datapipe, graph, fanouts, replace=False, prob_name=None, deduplicate=True)

Bases: torch.utils.data.datapipes.datapipe.IterDataPipe[torch.utils.data.datapipes.iter.callable.T_co]

Sample neighbor edges from a graph and return a subgraph.

Functional name: sample_neighbor.

The neighbor sampler is responsible for sampling a subgraph from the given data. It returns an induced subgraph along with compacted information. In a node classification task, the neighbor sampler uses the provided nodes directly as seed nodes. In link prediction, however, an additional preprocessing step is needed: the unique nodes are gathered from the given node pairs, including both positive and negative pairs, and these nodes are then used as the seed nodes for subsequent steps.
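The link-prediction preprocessing step can be sketched in plain Python. This is a minimal sketch, not the GraphBolt implementation: the helper `gather_seed_nodes` is hypothetical, node pairs are assumed to be lists of `(src, dst)` tuples, and returning the seeds as sorted unique IDs is an assumption chosen to mirror what `torch.unique` produces.

```python
from itertools import chain
from typing import Iterable, List, Tuple

Pair = Tuple[int, int]


def gather_seed_nodes(pos_pairs: Iterable[Pair], neg_pairs: Iterable[Pair]) -> List[int]:
    """Hypothetical helper: collect the unique nodes appearing in positive
    and negative node pairs so they can serve as seed nodes."""
    nodes = set()
    for src, dst in chain(pos_pairs, neg_pairs):
        nodes.add(src)
        nodes.add(dst)
    # Sorted output is an assumption, mirroring torch.unique.
    return sorted(nodes)


# One positive pair (0, 1) plus two negative pairs that share source node 0.
seeds = gather_seed_nodes([(0, 1)], [(0, 4), (0, 5)])
print(seeds)  # [0, 1, 4, 5]
```

Node 0 appears in all three pairs but is used only once as a seed.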

Parameters
  • datapipe (DataPipe) – The datapipe.

  • graph (FusedCSCSamplingGraph) – The graph on which to perform subgraph sampling.

  • fanouts (list[torch.Tensor] or list[int]) – The number of edges to sample for each node, either uniformly across edge types or separately for each edge type. The length of this list determines the number of sampling layers. Note: The fanouts are ordered from the outermost layer to the innermost layer. For example, the fanout '[15, 10, 5]' means that 15 corresponds to the outermost layer, 10 to the intermediate layer and 5 to the innermost layer.

  • replace (bool) – Boolean indicating whether sampling is performed with or without replacement. If True, the same edge can be selected multiple times. Otherwise, each edge can be selected at most once.

  • prob_name (str, optional) – The name of an edge attribute used as the weights of sampling for each node. This attribute tensor should contain (unnormalized) probabilities corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor, with the number of elements equalling the total number of edges.

  • deduplicate (bool) – Boolean indicating whether seeds are deduplicated between hops. If True, duplicate seed nodes are merged so that each appears only once. Otherwise, duplicates are kept.
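How these parameters interact can be illustrated with a simplified, pure-Python model of layer-wise neighbor sampling over a CSC graph. This is a sketch under stated assumptions, not GraphBolt's implementation: `sample_one_layer` and `sample_layers` are hypothetical helpers, the graph is given as plain `indptr`/`indices` lists, `probs` stands in for the edge attribute named by `prob_name`, and weighted sampling without replacement uses Efraimidis-Spirakis random keys, which may differ from what the library does internally.

```python
import random
from typing import List, Optional, Sequence, Tuple

Edge = Tuple[int, int]  # (source/row node, destination/column node)


def sample_one_layer(
    indptr: Sequence[int],
    indices: Sequence[int],
    seeds: Sequence[int],
    fanout: int,
    replace: bool = False,
    probs: Optional[Sequence[float]] = None,
    rng: Optional[random.Random] = None,
) -> List[Edge]:
    """Hypothetical helper: sample up to `fanout` in-edges per seed."""
    rng = rng or random.Random()
    edges: List[Edge] = []
    for dst in seeds:
        # In CSC, the in-edges of `dst` occupy indices[indptr[dst]:indptr[dst + 1]].
        eids = list(range(indptr[dst], indptr[dst + 1]))
        if probs is not None:
            # Edges with zero weight can never be picked.
            eids = [e for e in eids if probs[e] > 0]
        if not eids:
            continue
        if replace:
            # With replacement: exactly `fanout` draws; repeats are allowed.
            weights = None if probs is None else [probs[e] for e in eids]
            picked = rng.choices(eids, weights=weights, k=fanout)
        elif probs is None:
            # Uniform, without replacement: at most `fanout` distinct edges.
            picked = rng.sample(eids, min(fanout, len(eids)))
        else:
            # Weighted, without replacement: keep the edges with the
            # largest random keys u ** (1 / w) (Efraimidis-Spirakis).
            keyed = sorted(eids, key=lambda e: rng.random() ** (1.0 / probs[e]), reverse=True)
            picked = keyed[:fanout]
        edges.extend((indices[e], dst) for e in picked)
    return edges


def sample_layers(
    indptr: Sequence[int],
    indices: Sequence[int],
    seeds: Sequence[int],
    fanouts: Sequence[int],
    replace: bool = False,
    probs: Optional[Sequence[float]] = None,
    deduplicate: bool = True,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[List[Edge]]]:
    """Hypothetical helper: return (input nodes, per-layer edge lists),
    with the layers ordered from outermost to innermost."""
    layers: List[List[Edge]] = []
    frontier = list(seeds)
    # `fanouts` runs outermost -> innermost, so start from the seeds with
    # the last fanout and walk outward.
    for fanout in reversed(fanouts):
        edges = sample_one_layer(indptr, indices, frontier, fanout, replace, probs, rng)
        layers.insert(0, edges)
        # Next frontier: the current seeds followed by the newly reached sources.
        frontier = frontier + [src for src, _ in edges]
        if deduplicate:
            frontier = list(dict.fromkeys(frontier))  # keep first occurrences only
    return frontier, layers


# The graph from the example below: every node has at most two in-edges.
graph_indptr = [0, 2, 4, 5, 6, 7, 8]
graph_indices = [1, 2, 0, 3, 5, 4, 3, 5]
input_nodes, layers = sample_layers(
    graph_indptr, graph_indices, [0, 1, 4, 5], [5, 10, 15], rng=random.Random(0)
)
print(input_nodes)  # [0, 1, 4, 5, 2, 3]
```

In this sketch, the last fanout is applied to the seeds themselves, and with `deduplicate=True` the input nodes keep the seeds first, followed by newly reached nodes in the order they were discovered.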

Examples

>>> import torch
>>> import dgl.graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7, 8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> datapipe = gb.ItemSampler(item_set, batch_size=1)
>>> datapipe = datapipe.sample_uniform_negative(graph, 2)
>>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])
>>> next(iter(datapipe)).sampled_subgraphs
[SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
        indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
        indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6]),
        indices=tensor([1, 4, 0, 5, 5, 3]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5]),
)]
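One way to read the printed blocks is to translate each compacted CSC structure back into edges over original node IDs. The helper below is hypothetical and not part of GraphBolt; it assumes, as in a CSC layout, that each column is a destination node whose incoming edges list local row (source) IDs in `indices`, and that `original_row_node_ids` and `original_column_node_ids` map local IDs back to the input graph.

```python
from typing import List, Sequence, Tuple


def csc_to_original_edges(
    indptr: Sequence[int],
    indices: Sequence[int],
    original_row_node_ids: Sequence[int],
    original_column_node_ids: Sequence[int],
) -> List[Tuple[int, int]]:
    """Hypothetical helper: map a compacted CSC block to (src, dst) original IDs."""
    edges = []
    for col in range(len(indptr) - 1):
        dst = original_column_node_ids[col]
        for pos in range(indptr[col], indptr[col + 1]):
            # `indices` holds local row IDs; translate them to original IDs.
            edges.append((original_row_node_ids[indices[pos]], dst))
    return edges


# Values copied from the innermost (last) SampledSubgraphImpl printed above.
edges = csc_to_original_edges(
    indptr=[0, 2, 4, 5, 6],
    indices=[1, 4, 0, 5, 5, 3],
    original_row_node_ids=[0, 1, 4, 5, 2, 3],
    original_column_node_ids=[0, 1, 4, 5],
)
print(edges)  # [(1, 0), (2, 0), (0, 1), (3, 1), (3, 4), (5, 5)]
```

Each recovered pair is an edge of the input graph: in the original `indices`, node 0's in-neighbors are 1 and 2, node 1's are 0 and 3, node 4's is 3, and node 5's is itself. Because no node in that graph has more than two in-edges, every fanout in `[5, 10, 15]` is large enough to keep all of them.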
sample_subgraphs(seeds, seeds_timestamp=None)

Sample subgraphs from the given seeds.

Any subclass of SubgraphSampler should implement this method.

Parameters

seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes.

Returns

  • Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.

  • List[SampledSubgraph] – The sampled subgraphs.

Examples

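>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
... class MySubgraphSampler(SubgraphSampler):
...     def __init__(self, datapipe, graph, fanouts):
...         super().__init__(datapipe)
...         self.graph = graph
...         self.fanouts = fanouts
...     def sample_subgraphs(self, seeds, seeds_timestamp=None):
...         # Sample one subgraph per layer, starting from the seeds
...         # (innermost layer) and moving outward.
...         subgraphs = []
...         subgraphs_nodes = []
...         for fanout in reversed(self.fanouts):
...             subgraph = self.graph.sample_neighbors(seeds, fanout)
...             # Prepend so the list runs from outermost to innermost.
...             subgraphs.insert(0, subgraph)
...             subgraphs_nodes.append(subgraph.nodes)
...             # Nodes reached in this layer seed the next, outer layer.
...             seeds = subgraph.nodes
...         # The input nodes are the unique nodes across all layers.
...         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
...         return subgraphs_nodes, subgraphs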