Source code for dgl.graphbolt.negative_sampler

"""Negative samplers."""

from _collections_abc import Mapping

from torch.utils.data import functional_datapipe

from .minibatch_transformer import MiniBatchTransformer

__all__ = [
    "NegativeSampler",
]


[docs]@functional_datapipe("sample_negative") class NegativeSampler(MiniBatchTransformer): """ A negative sampler used to generate negative samples and return a mix of positive and negative samples. Functional name: :obj:`sample_negative`. Parameters ---------- datapipe : DataPipe The datapipe. negative_ratio : int The proportion of negative samples to positive samples. """ def __init__( self, datapipe, negative_ratio, ): super().__init__(datapipe, self._sample) assert negative_ratio > 0, "Negative_ratio should be positive Integer." self.negative_ratio = negative_ratio def _sample(self, minibatch): """ Generate a mix of positive and negative samples. Parameters ---------- minibatch : MiniBatch An instance of 'MiniBatch' class requires the 'node_pairs' field. This function is responsible for generating negative edges corresponding to the positive edges defined by the 'node_pairs'. In cases where negative edges already exist, this function will overwrite them. Returns ------- MiniBatch An instance of 'MiniBatch' encompasses both positive and negative samples. """ node_pairs = minibatch.node_pairs assert node_pairs is not None if isinstance(node_pairs, Mapping): minibatch.negative_srcs, minibatch.negative_dsts = {}, {} for etype, pos_pairs in node_pairs.items(): self._collate( minibatch, self._sample_with_etype(pos_pairs, etype), etype ) else: self._collate(minibatch, self._sample_with_etype(node_pairs)) return minibatch def _sample_with_etype(self, node_pairs, etype=None): """Generate negative pairs for a given etype form positive pairs for a given etype. Parameters ---------- node_pairs : Tuple[Tensor, Tensor] A tuple of tensors that represent source-destination node pairs of positive edges, where positive means the edge must exist in the graph. etype : str Canonical edge type. Returns ------- Tuple[Tensor, Tensor] A collection of negative node pairs. """ raise NotImplementedError def _collate(self, minibatch, neg_pairs, etype=None): """Collates positive and negative samples into minibatch. Parameters ---------- minibatch : MiniBatch The input minibatch, which contains positive node pairs, will be filled with negative information in this function. neg_pairs : Tuple[Tensor, Tensor] A tuple of tensors represents source-destination node pairs of negative edges, where negative means the edge may not exist in the graph. etype : str Canonical edge type. """ neg_src, neg_dst = neg_pairs if neg_src is not None: neg_src = neg_src.view(-1, self.negative_ratio) if neg_dst is not None: neg_dst = neg_dst.view(-1, self.negative_ratio) if etype is not None: minibatch.negative_srcs[etype] = neg_src minibatch.negative_dsts[etype] = neg_dst else: minibatch.negative_srcs = neg_src minibatch.negative_dsts = neg_dst