NeighborSampler
class dgl.graphbolt.NeighborSampler(datapipe, graph, fanouts, replace=False, prob_name=None, deduplicate=True)
Bases: torch.utils.data.datapipes.datapipe.IterDataPipe[torch.utils.data.datapipes.iter.callable.T_co]
Sample neighbor edges from a graph and return a subgraph.
Functional name: sample_neighbor.
The neighbor sampler samples a subgraph from the given data and returns an induced subgraph along with compacted information. For a node classification task, it uses the provided nodes directly as seed nodes. For link prediction, an additional preprocessing step is needed: the unique nodes are gathered from the given node pairs, both positive and negative, and these nodes are used as the seed nodes for the subsequent steps.
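The link-prediction preprocessing described above can be sketched in plain Python. This is only an illustration under stated assumptions: the helper name `unique_seeds_from_pairs` and the list-of-tuples input are hypothetical and are not part of GraphBolt, which works on tensors.

```python
def unique_seeds_from_pairs(positive_pairs, negative_pairs):
    """Collect unique node IDs from positive and negative node pairs.

    Hypothetical helper for illustration only. Returns the seed list (in
    first-seen order) and the pairs rewritten as positions in that list,
    i.e. the "compacted" pairs.
    """
    position = {}  # original node ID -> index in the seed list
    compacted = []
    for src, dst in list(positive_pairs) + list(negative_pairs):
        for node in (src, dst):
            if node not in position:
                position[node] = len(position)
        compacted.append((position[src], position[dst]))
    seeds = list(position)  # dicts preserve insertion order
    return seeds, compacted


positive = [(0, 1), (1, 2)]
negative = [(0, 5), (1, 4)]
seeds, compacted = unique_seeds_from_pairs(positive, negative)
print(seeds)      # [0, 1, 2, 5, 4]
print(compacted)  # [(0, 1), (1, 2), (0, 3), (1, 4)]
```

The seed list is what a neighbor sampler would expand from, while the compacted pairs index into that list instead of the original node IDs.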
- Parameters
datapipe (DataPipe) – The datapipe.
graph (FusedCSCSamplingGraph) – The graph on which to perform subgraph sampling.
fanouts (list[torch.Tensor] or list[int]) – The number of edges to sample for each node, with or without considering edge types. The length of this list implicitly sets the number of sampling layers. Note: fanouts are ordered from the outermost layer to the innermost layer. For example, the fanout "[15, 10, 5]" means that 15 applies to the outermost layer, 10 to the intermediate layer, and 5 to the innermost layer.
replace (bool) – Whether sampling is performed with or without replacement. If True, a value can be selected multiple times. Otherwise, each value can be selected only once.
prob_name (str, optional) – The name of an edge attribute used as the sampling weights for each node. This attribute tensor should contain (unnormalized) probabilities corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor whose number of elements equals the total number of edges.
deduplicate (bool) – Whether seeds are deduplicated between hops. If True, repeated elements in seeds are reduced to a single occurrence. Otherwise, repeated elements are kept.
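To make the interplay of these parameters concrete, here is a minimal pure-Python sketch of layer-wise neighbor sampling on a CSC graph. It is not GraphBolt's implementation: the functions `sample_in_edges` and `sample_layers` are hypothetical, a plain list `weights` stands in for the edge attribute named by `prob_name`, and the exact counts returned when sampling with replacement are an assumption of this sketch.

```python
import random


def sample_in_edges(indptr, node, fanout, replace=False, weights=None, rng=random):
    """Return positions (into `indices`) of up to `fanout` in-edges of `node`."""
    edge_ids = range(indptr[node], indptr[node + 1])
    if weights is None:
        edges = list(edge_ids)
        probs = [1.0] * len(edges)
    else:
        # Edges with zero (or False) weight can never be picked.
        edges = [e for e in edge_ids if weights[e] > 0]
        probs = [float(weights[e]) for e in edges]
    if not edges:
        return []
    if replace:
        # With replacement: one edge may be drawn several times.
        return rng.choices(edges, weights=probs, k=fanout)
    # Without replacement: each edge is drawn at most once.
    chosen = []
    while edges and len(chosen) < fanout:
        pick = rng.choices(range(len(edges)), weights=probs, k=1)[0]
        chosen.append(edges.pop(pick))
        probs.pop(pick)
    return chosen


def sample_layers(indptr, indices, seeds, fanouts, replace=False,
                  weights=None, deduplicate=True, rng=random):
    """Return one (src, dst) edge list per layer, outermost layer first."""
    layers = []
    # The last fanout belongs to the innermost layer, which is sampled first.
    for fanout in reversed(fanouts):
        edges = [
            (indices[e], dst)
            for dst in seeds
            for e in sample_in_edges(indptr, dst, fanout, replace, weights, rng)
        ]
        layers.insert(0, edges)
        seeds = list(seeds) + [src for src, _ in edges]
        if deduplicate:
            seeds = list(dict.fromkeys(seeds))  # keep first occurrence only
    return layers


indptr = [0, 2, 4, 5, 6, 7, 8]
indices = [1, 2, 0, 3, 5, 4, 3, 5]
rng = random.Random(0)
# Fanouts larger than every in-degree keep all edges, so the result is
# deterministic up to ordering.
outer, inner = sample_layers(indptr, indices, [0], [10, 10], rng=rng)
print(sorted(inner))  # [(1, 0), (2, 0)]
print(sorted(outer))  # [(0, 1), (1, 0), (2, 0), (3, 1), (5, 2)]
```

With `deduplicate=False`, a node reached through several edges would appear several times in the next hop's seeds and be sampled once per occurrence.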
Examples
>>> import torch
>>> import dgl.graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7, 8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> datapipe = gb.ItemSampler(item_set, batch_size=1)
>>> datapipe = datapipe.sample_uniform_negative(graph, 2)
>>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])
>>> next(iter(datapipe)).sampled_subgraphs
[SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
        indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
        indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6]),
        indices=tensor([1, 4, 0, 5, 5, 3]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5]),
)]
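For readers unfamiliar with the `indptr`/`indices` pair used in the example, the following sketch decodes it with plain Python lists. It assumes the usual CSC convention in which each column corresponds to a destination node and `indices` holds the source nodes of its incoming edges; the helper `in_neighbors` is hypothetical and not a GraphBolt function.

```python
def in_neighbors(indptr, indices):
    """Map every node to the list of its in-neighbors in a CSC graph."""
    return {
        node: indices[indptr[node]:indptr[node + 1]]
        for node in range(len(indptr) - 1)
    }


indptr = [0, 2, 4, 5, 6, 7, 8]
indices = [1, 2, 0, 3, 5, 4, 3, 5]
neighbors = in_neighbors(indptr, indices)
for node, sources in neighbors.items():
    print(node, sources)
# 0 [1, 2]
# 1 [0, 3]
# 2 [5]
# 3 [4]
# 4 [3]
# 5 [5]
```

Node `v` owns the slice `indices[indptr[v]:indptr[v + 1]]`, so the length of `indptr` is the number of nodes plus one and its last entry equals the number of edges.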
sample_subgraphs(seeds, seeds_timestamp=None)
Sample subgraphs from the given seeds.
Any subclass of SubgraphSampler should implement this method.
- Parameters
seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes.
- Returns
Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.
List[SampledSubgraph] – The sampled subgraphs.
Examples
>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
>>> class MySubgraphSampler(SubgraphSampler):
>>>     def __init__(self, datapipe, graph, fanouts):
>>>         super().__init__(datapipe)
>>>         self.graph = graph
>>>         self.fanouts = fanouts
>>>     # seeds_timestamp matches the documented signature; it is unused here.
>>>     def sample_subgraphs(self, seeds, seeds_timestamp=None):
>>>         # Sample one subgraph per layer, starting from the given seeds.
>>>         subgraphs = []
>>>         subgraphs_nodes = []
>>>         for fanout in reversed(self.fanouts):
>>>             subgraph = self.graph.sample_neighbors(seeds, fanout)
>>>             # Prepend so the outermost layer ends up first.
>>>             subgraphs.insert(0, subgraph)
>>>             subgraphs_nodes.append(subgraph.nodes)
>>>             seeds = subgraph.nodes
>>>         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
>>>         return subgraphs_nodes, subgraphs