Source code for dgl.graphbolt.subgraph_sampler

"""Subgraph samplers"""

from collections import defaultdict
from typing import Dict

import torch
from torch.utils.data import functional_datapipe

from .base import etype_str_to_tuple
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch_transformer import MiniBatchTransformer

__all__ = [
    "SubgraphSampler",
]


@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):
    """A subgraph sampler used to sample a subgraph from a given set of nodes
    from a larger graph.

    Functional name: :obj:`sample_subgraph`.

    This class is the base class of all subgraph samplers. Any subclass of
    SubgraphSampler should implement either the :meth:`sample_subgraphs`
    method or the :meth:`sampling_stages` method to define the fine-grained
    sampling stages to take advantage of optimizations provided by the
    GraphBolt DataLoader.

    The pipeline built in ``__init__`` is:
    ``_preprocess`` -> ``sampling_stages(...)`` -> ``_postprocess``.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    args : Non-Keyword Arguments
        Arguments to be passed into sampling_stages.
    kwargs : Keyword Arguments
        Arguments to be passed into sampling_stages.
    """

    def __init__(
        self,
        datapipe,
        *args,
        **kwargs,
    ):
        # Attach the temporary `_seed_nodes` / `_seeds_timestamp` attributes.
        datapipe = datapipe.transform(self._preprocess)
        # Let the (possibly overridden) stages consume those attributes.
        datapipe = self.sampling_stages(datapipe, *args, **kwargs)
        # Strip the temporary attributes again before handing the batch on.
        datapipe = datapipe.transform(self._postprocess)
        super().__init__(datapipe)

    @staticmethod
    def _postprocess(minibatch):
        """Remove the temporary attributes set by :meth:`_preprocess`.

        Deletes ``_seed_nodes`` and ``_seeds_timestamp`` from ``minibatch``
        and returns the same minibatch object.
        """
        delattr(minibatch, "_seed_nodes")
        delattr(minibatch, "_seeds_timestamp")
        return minibatch

    @staticmethod
    def _preprocess(minibatch):
        """Derive the seeds to sample from and stash them on the minibatch.

        The first non-``None`` attribute among ``node_pairs``, ``seed_nodes``
        and ``seeds`` (checked in that order) decides how seeds are built.
        The result is stored in ``minibatch._seed_nodes`` and
        ``minibatch._seeds_timestamp``; compacted ids are written to the
        matching ``compacted_*`` attributes.

        Raises
        ------
        ValueError
            If none of ``node_pairs``, ``seed_nodes`` and ``seeds`` is set.
        """
        if minibatch.node_pairs is not None:
            (
                seeds,
                seeds_timestamp,
                minibatch.compacted_node_pairs,
                minibatch.compacted_negative_srcs,
                minibatch.compacted_negative_dsts,
            ) = SubgraphSampler._node_pairs_preprocess(minibatch)
        elif minibatch.seed_nodes is not None:
            seeds = minibatch.seed_nodes
            seeds_timestamp = (
                minibatch.timestamp if hasattr(minibatch, "timestamp") else None
            )
        elif minibatch.seeds is not None:
            (
                seeds,
                seeds_timestamp,
                minibatch.compacted_seeds,
            ) = SubgraphSampler._seeds_preprocess(minibatch)
        else:
            raise ValueError(
                f"Invalid minibatch {minibatch}: One of `node_pairs`, "
                "`seed_nodes` and `seeds` should have a value."
            )
        minibatch._seed_nodes = seeds
        minibatch._seeds_timestamp = seeds_timestamp
        return minibatch

    @staticmethod
    def _node_pairs_preprocess(minibatch):
        """Collect and compact every node referenced by ``node_pairs`` and the
        optional negative sources/destinations of a minibatch.

        Parameters
        ----------
        minibatch : MiniBatch
            The minibatch. Reads ``node_pairs``, ``negative_srcs``,
            ``negative_dsts`` and, when present, ``timestamp``.

        Returns
        -------
        seeds
            The seeds returned by ``unique_and_compact`` or
            ``compact_temporal_nodes``.
        nodes_timestamp
            The timestamps returned by ``compact_temporal_nodes``, or ``None``
            when the minibatch has no ``timestamp`` attribute.
        compacted_node_pairs
            A ``(src, dst)`` tuple, or a dict of such tuples keyed by edge
            type for heterogeneous input.
        compacted_negative_srcs
            Compacted negative sources reshaped to the original shape, or
            ``None`` when no negative sources are given.
        compacted_negative_dsts
            Compacted negative destinations reshaped to the original shape,
            or ``None`` when no negative destinations are given.
        """
        use_timestamp = hasattr(minibatch, "timestamp")
        node_pairs = minibatch.node_pairs
        neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
        has_neg_src = neg_src is not None
        has_neg_dst = neg_dst is not None
        is_heterogeneous = isinstance(node_pairs, Dict)
        if is_heterogeneous:
            # A negative dict only counts if every edge type has a tensor.
            has_neg_src = has_neg_src and all(
                item is not None for item in neg_src.values()
            )
            has_neg_dst = has_neg_dst and all(
                item is not None for item in neg_dst.values()
            )
            # Collect nodes from all types of input.
            nodes = defaultdict(list)
            nodes_timestamp = None
            if use_timestamp:
                nodes_timestamp = defaultdict(list)
            for etype, (src, dst) in node_pairs.items():
                src_type, _, dst_type = etype_str_to_tuple(etype)
                nodes[src_type].append(src)
                nodes[dst_type].append(dst)
                if use_timestamp:
                    # Both endpoints of an edge share the edge's timestamp.
                    nodes_timestamp[src_type].append(minibatch.timestamp[etype])
                    nodes_timestamp[dst_type].append(minibatch.timestamp[etype])
            if has_neg_src:
                for etype, src in neg_src.items():
                    src_type, _, _ = etype_str_to_tuple(etype)
                    nodes[src_type].append(src.view(-1))
                    if use_timestamp:
                        # One timestamp per flattened negative node.
                        nodes_timestamp[src_type].append(
                            minibatch.timestamp[etype].repeat_interleave(
                                src.shape[-1]
                            )
                        )
            if has_neg_dst:
                for etype, dst in neg_dst.items():
                    _, _, dst_type = etype_str_to_tuple(etype)
                    nodes[dst_type].append(dst.view(-1))
                    if use_timestamp:
                        nodes_timestamp[dst_type].append(
                            minibatch.timestamp[etype].repeat_interleave(
                                dst.shape[-1]
                            )
                        )
            # Unique and compact the collected nodes.
            if use_timestamp:
                seeds, nodes_timestamp, compacted = compact_temporal_nodes(
                    nodes, nodes_timestamp
                )
            else:
                seeds, compacted = unique_and_compact(nodes)
                nodes_timestamp = None
            compacted_node_pairs = {}
            compacted_negative_srcs = {}
            compacted_negative_dsts = {}
            # Map back in same order as collect: the per-type lists are
            # consumed front to back, so the loop order below must mirror the
            # collection order above.
            for etype in node_pairs:
                src_type, _, dst_type = etype_str_to_tuple(etype)
                src = compacted[src_type].pop(0)
                dst = compacted[dst_type].pop(0)
                compacted_node_pairs[etype] = (src, dst)
            if has_neg_src:
                for etype in neg_src:
                    src_type, _, _ = etype_str_to_tuple(etype)
                    compacted_negative_srcs[etype] = (
                        compacted[src_type].pop(0).view(neg_src[etype].shape)
                    )
            if has_neg_dst:
                for etype in neg_dst:
                    _, _, dst_type = etype_str_to_tuple(etype)
                    compacted_negative_dsts[etype] = (
                        compacted[dst_type].pop(0).view(neg_dst[etype].shape)
                    )
        else:
            compacted_negative_srcs = None
            compacted_negative_dsts = None
            # Collect nodes from all types of input.
            nodes = list(node_pairs)
            nodes_timestamp = None
            if use_timestamp:
                # Timestamp for source and destination nodes are the same.
                nodes_timestamp = [minibatch.timestamp, minibatch.timestamp]
            if has_neg_src:
                nodes.append(neg_src.view(-1))
                if use_timestamp:
                    nodes_timestamp.append(
                        minibatch.timestamp.repeat_interleave(neg_src.shape[-1])
                    )
            if has_neg_dst:
                nodes.append(neg_dst.view(-1))
                if use_timestamp:
                    nodes_timestamp.append(
                        minibatch.timestamp.repeat_interleave(neg_dst.shape[-1])
                    )
            # Unique and compact the collected nodes.
            if use_timestamp:
                seeds, nodes_timestamp, compacted = compact_temporal_nodes(
                    nodes, nodes_timestamp
                )
            else:
                seeds, compacted = unique_and_compact(nodes)
                nodes_timestamp = None
            # Map back in same order as collect.
            compacted_node_pairs = tuple(compacted[:2])
            compacted = compacted[2:]
            if has_neg_src:
                # Since we need to calculate the neg_ratio according to the
                # compacted_negative_srcs shape, we need to reshape it back.
                compacted_negative_srcs = compacted.pop(0).view(neg_src.shape)
            if has_neg_dst:
                # Same as above.
                compacted_negative_dsts = compacted.pop(0).view(neg_dst.shape)
        return (
            seeds,
            nodes_timestamp,
            compacted_node_pairs,
            compacted_negative_srcs if has_neg_src else None,
            compacted_negative_dsts if has_neg_dst else None,
        )

    def _sample(self, minibatch):
        """Run :meth:`sample_subgraphs` on the preprocessed seeds and store
        the result in ``minibatch.input_nodes`` and
        ``minibatch.sampled_subgraphs``.
        """
        (
            minibatch.input_nodes,
            minibatch.sampled_subgraphs,
        ) = self.sample_subgraphs(
            minibatch._seed_nodes, minibatch._seeds_timestamp
        )
        return minibatch

    def sampling_stages(self, datapipe):
        """The sampling stages are defined here by chaining to the datapipe.
        The default implementation expects :meth:`sample_subgraphs` to be
        implemented. To define fine-grained stages, this method should be
        overridden.
        """
        return datapipe.transform(self._sample)

    @staticmethod
    def _seeds_preprocess(minibatch):
        """Preprocess `seeds` in a minibatch to construct `unique_seeds`,
        `node_timestamp` and `compacted_seeds` for further sampling. It
        optionally incorporates timestamps for temporal graphs, organizing and
        compacting seeds based on their types and timestamps.

        Parameters
        ----------
        minibatch: MiniBatch
            The minibatch.

        Returns
        -------
        unique_seeds: torch.Tensor or Dict[str, torch.Tensor]
            A tensor or a dictionary of tensors representing the unique seeds.
            In heterogeneous graphs, seeds are returned for each node type.
        nodes_timestamp: None or a torch.Tensor or Dict[str, torch.Tensor]
            Containing timestamps for each seed. This is only returned if
            `minibatch` includes timestamps and the graph is temporal.
        compacted_seeds: torch.Tensor or a Dict[str, torch.Tensor]
            Representation of compacted seeds corresponding to 'seeds', where
            all node ids inside are compacted.
        """
        use_timestamp = hasattr(minibatch, "timestamp")
        seeds = minibatch.seeds
        is_heterogeneous = isinstance(seeds, Dict)
        if is_heterogeneous:
            # Collect nodes from all types of input.
            nodes = defaultdict(list)
            nodes_timestamp = None
            if use_timestamp:
                nodes_timestamp = defaultdict(list)
            for etype, pair in seeds.items():
                assert pair.ndim == 1 or (
                    pair.ndim == 2 and pair.shape[1] == 2
                ), (
                    "Only tensor with shape 1*N and N*2 is "
                    + f"supported now, but got {pair.shape}."
                )
                # "src:rel:dst" -> [src, dst]; a plain node type -> [ntype].
                ntypes = etype.split(":")[::2]
                # Give 1-D seeds a trailing column so both cases index alike.
                pair = pair.view(pair.shape[0], -1)
                if use_timestamp:
                    # Rows beyond the timestamp length are treated as
                    # negatives that reuse the timestamps, repeated
                    # `negative_ratio` times.
                    negative_ratio = (
                        pair.shape[0] // minibatch.timestamp[etype].shape[0]
                        - 1
                    )
                    neg_timestamp = minibatch.timestamp[
                        etype
                    ].repeat_interleave(negative_ratio)
                for i, ntype in enumerate(ntypes):
                    nodes[ntype].append(pair[:, i])
                    if use_timestamp:
                        nodes_timestamp[ntype].append(
                            minibatch.timestamp[etype]
                        )
                        nodes_timestamp[ntype].append(neg_timestamp)
            # Unique and compact the collected nodes.
            if use_timestamp:
                (
                    unique_seeds,
                    nodes_timestamp,
                    compacted,
                ) = compact_temporal_nodes(nodes, nodes_timestamp)
            else:
                unique_seeds, compacted = unique_and_compact(nodes)
                nodes_timestamp = None
            compacted_seeds = {}
            # Map back in same order as collect.
            for etype, pair in seeds.items():
                if pair.ndim == 1:
                    compacted_seeds[etype] = compacted[etype].pop(0)
                else:
                    src_type, _, dst_type = etype_str_to_tuple(etype)
                    src = compacted[src_type].pop(0)
                    dst = compacted[dst_type].pop(0)
                    # Re-assemble the two columns into an N*2 tensor.
                    compacted_seeds[etype] = (
                        torch.cat((src, dst)).view(2, -1).T
                    )
        else:
            # Collect nodes from all types of input.
            nodes = [seeds.view(-1)]
            nodes_timestamp = None
            if use_timestamp:
                # Timestamp for source and destination nodes are the same.
                negative_ratio = (
                    seeds.shape[0] // minibatch.timestamp.shape[0] - 1
                )
                neg_timestamp = minibatch.timestamp.repeat_interleave(
                    negative_ratio
                )
                seeds_timestamp = torch.cat(
                    (minibatch.timestamp, neg_timestamp)
                )
                # NOTE(review): `seeds.view(-1)` flattens a 2-D `seeds` row by
                # row, while this list repeats the whole per-row timestamp
                # tensor `seeds.ndim` times. Confirm that
                # `compact_temporal_nodes` pairs nodes and timestamps as
                # intended for 2-D seeds.
                nodes_timestamp = [seeds_timestamp for _ in range(seeds.ndim)]
            # Unique and compact the collected nodes.
            if use_timestamp:
                (
                    unique_seeds,
                    nodes_timestamp,
                    compacted,
                ) = compact_temporal_nodes(nodes, nodes_timestamp)
            else:
                unique_seeds, compacted = unique_and_compact(nodes)
                nodes_timestamp = None
            # Map back in same order as collect.
            compacted_seeds = compacted[0].view(seeds.shape)

        return (
            unique_seeds,
            nodes_timestamp,
            compacted_seeds,
        )

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        """Sample subgraphs from the given seeds, possibly with temporal
        constraints.

        Any subclass of SubgraphSampler should implement this method. The
        default :meth:`sampling_stages` always calls it with both `seeds` and
        `seeds_timestamp`, so an override must accept both arguments.

        Parameters
        ----------
        seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The seed nodes.

        seeds_timestamp : Union[torch.Tensor, Dict[str, torch.Tensor]], optional
            The timestamps of the seed nodes. If given, the sampled
            subgraphs should not contain any nodes or edges that are newer
            than the timestamps of the seed nodes. Default: None.

        Returns
        -------
        Union[torch.Tensor, Dict[str, torch.Tensor]]
            The input nodes.
        List[SampledSubgraph]
            The sampled subgraphs.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        Examples
        --------
        >>> @functional_datapipe("my_sample_subgraph")
        >>> class MySubgraphSampler(SubgraphSampler):
        >>>     def __init__(self, datapipe, graph, fanouts):
        >>>         super().__init__(datapipe)
        >>>         self.graph = graph
        >>>         self.fanouts = fanouts
        >>>     def sample_subgraphs(self, seeds, seeds_timestamp=None):
        >>>         # Sample subgraphs from the given seeds.
        >>>         subgraphs = []
        >>>         subgraphs_nodes = []
        >>>         for fanout in reversed(self.fanouts):
        >>>             subgraph = self.graph.sample_neighbors(seeds, fanout)
        >>>             subgraphs.insert(0, subgraph)
        >>>             subgraphs_nodes.append(subgraph.nodes)
        >>>             seeds = subgraph.nodes
        >>>         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
        >>>         return subgraphs_nodes, subgraphs
        """
        raise NotImplementedError