# Source code for dgl.dataloading.negative_sampler

"""Negative samplers"""
from collections.abc import Mapping

from .. import backend as F

class _BaseNegativeSampler(object):
    """Base class for negative samplers.

    Subclasses implement :meth:`_generate`, which produces negative
    source-destination pairs for the given edges of a single canonical edge
    type. :meth:`__call__` dispatches to it once per edge type.
    """

    def _generate(self, g, eids, canonical_etype):
        """Generate negative pairs for ``eids`` of type ``canonical_etype``.

        Must be overridden by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def __call__(self, g, eids):
        """Returns negative examples.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        eids : Tensor or dict[etype, Tensor]
            The sampled edges in the minibatch.

        Returns
        -------
        tuple[Tensor, Tensor] or dict[etype, tuple[Tensor, Tensor]]
            The returned source-destination pairs as negative examples.
            A dict (keyed by canonical edge type) is returned when ``eids``
            is a mapping; otherwise a single pair is returned.
        """
        if isinstance(eids, Mapping):
            # Normalize user-supplied edge-type keys to canonical
            # (srctype, etype, dsttype) triples before sampling per type.
            eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
            neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()}
        else:
            # Bare edge IDs are only unambiguous on a single-edge-type graph.
            assert len(g.etypes) == 1, \
                'please specify a dict of etypes and ids for graphs with multiple edge types'
            neg_pair = self._generate(g, eids, g.canonical_etypes[0])

        return neg_pair

class Uniform(_BaseNegativeSampler):
    """Negative sampler that randomly chooses negative destination nodes
    for each source node according to a uniform distribution.

    For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates
    :attr:`k` pairs of negative edges ``(u, v')``, where ``v'`` is chosen
    uniformly from all the nodes of type ``dsttype``.  The resulting edges will
    also have type ``(srctype, etype, dsttype)``.

    Parameters
    ----------
    k : int
        The number of negative examples per edge.

    Examples
    --------
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))
    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(2)
    >>> neg_sampler(g, torch.tensor([0, 1]))
    (tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3]))

    The destination nodes are drawn at random, so the second tensor will
    generally differ between runs.
    """
    def __init__(self, k):
        self.k = k

    def _generate(self, g, eids, canonical_etype):
        """Generate ``self.k`` negative pairs for every edge in ``eids``.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        eids : Tensor
            IDs of the edges of type ``canonical_etype`` to corrupt.
        canonical_etype : tuple[str, str, str]
            The ``(srctype, etype, dsttype)`` triple of the edges.

        Returns
        -------
        tuple[Tensor, Tensor]
            Source nodes (each original source repeated ``k`` times) and
            randomly drawn destination node IDs of type ``dsttype``.
        """
        _, _, vtype = canonical_etype
        # Output destinations use the same dtype/context as the edge IDs.
        dtype = F.dtype(eids)
        ctx = F.context(eids)
        shape = (F.shape(eids)[0] * self.k,)
        src, _ = g.find_edges(eids, etype=canonical_etype)
        # One copy of each source node per negative destination.
        src = F.repeat(src, self.k, 0)
        # NOTE(review): assumes F.randint samples from [low, high), i.e. all
        # valid node IDs of type ``vtype`` — confirm against the backend.
        dst = F.randint(shape, dtype, ctx, 0, g.number_of_nodes(vtype))
        return src, dst