dgl.dataloading.as_edge_prediction_sampler

dgl.dataloading.as_edge_prediction_sampler(sampler, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None, prefetch_labels=None)

Create an edge-wise sampler from a node-wise sampler.

For each batch of edges, the sampler applies the provided node-wise sampler to their source and destination nodes to extract subgraphs. If a negative sampler is provided, it also generates negative edges and extracts subgraphs for their incident nodes.

For each iteration, the sampler will yield

  • A tensor of input nodes necessary for computing the representation on edges, or a dictionary of node type names and such tensors.

  • A subgraph that contains only the edges in the minibatch and their incident nodes. Note that this graph has the same metagraph as the original graph.

  • If a negative sampler is given, another graph that contains the “negative edges”, connecting the source and destination nodes yielded from the given negative sampler.

  • The subgraphs or MFGs returned by the provided node-wise sampler, generated from the incident nodes of the edges in the minibatch (as well as those of the negative edges if applicable).
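
The items listed above arrive as a 3-tuple without a negative sampler and as a 4-tuple with one, as the loops in the Examples section show. A minimal sketch of handling both shapes in one place is given below. The helper name unpack_edge_batch is hypothetical and not part of DGL, and placeholder strings stand in for the tensors and graphs a real dataloader would yield.

```python
def unpack_edge_batch(batch):
    """Split one item yielded by an edge-prediction dataloader into named parts.

    Hypothetical helper: without a negative sampler the item has three
    entries, with a negative sampler it has four.
    """
    if len(batch) == 3:
        input_nodes, pair_graph, blocks = batch
        neg_pair_graph = None
    elif len(batch) == 4:
        input_nodes, pair_graph, neg_pair_graph, blocks = batch
    else:
        raise ValueError(f"expected 3 or 4 items, got {len(batch)}")
    return {
        "input_nodes": input_nodes,
        "pair_graph": pair_graph,
        "neg_pair_graph": neg_pair_graph,
        "blocks": blocks,
    }


# Placeholder values only; a real batch holds tensors, graphs and MFGs.
without_neg = unpack_edge_batch(("nodes", "pos_graph", ["block0", "block1"]))
with_neg = unpack_edge_batch(("nodes", "pos_graph", "neg_graph", ["block0", "block1"]))
```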

Parameters
  • sampler (Sampler) – The node-wise sampler object. Its sample method must additionally accept an optional third argument exclude_eids, representing the edge IDs to exclude from the neighborhood. The argument will be either a tensor for homogeneous graphs or a dict of edge types and tensors for heterogeneous graphs.

  • exclude (str, optional) –

    Whether and how to exclude dependencies related to the sampled edges in the minibatch. Possible values are

    • None, for not excluding any edges.

    • 'self', for excluding the edges in the current minibatch.

    • 'reverse_id', for excluding not only the edges in the current minibatch but also their reverse edges according to the ID mapping in the argument reverse_eids.

    • 'reverse_types', for excluding not only the edges in the current minibatch but also their reverse edges stored in another type according to the argument reverse_etypes.

    • User-defined exclusion rule. It is a callable with edges in the current minibatch as a single argument and should return the edges to be excluded.

  • reverse_eids (Tensor or dict[etype, Tensor], optional) –

    A tensor mapping each edge ID to the ID of its reverse edge. The i-th element is the ID of the i-th edge’s reverse edge.

    If the graph is heterogeneous, this argument requires a dictionary of edge types and the reverse edge ID mapping tensors.

  • reverse_etypes (dict[etype, etype], optional) – The mapping from the original edge types to their reverse edge types.

  • negative_sampler (callable, optional) – The negative sampler.

  • prefetch_labels (list[str] or dict[etype, list[str]], optional) –

    The edge labels to prefetch for the returned positive pair graph.

    See 6.8 Feature Prefetching for a detailed explanation of prefetching.
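
The user-defined exclusion rule described under exclude can be pictured with a plain-Python stand-in. The sketch below is only an illustration under stated assumptions: it uses Python lists where DGL would pass a tensor of edge IDs for a homogeneous graph, and the names make_exclude_fn and reverse_of are hypothetical. The rule it builds returns the minibatch edges together with their reverse edges, which mimics the intent of 'reverse_id'.

```python
def make_exclude_fn(reverse_of):
    """Build a rule that excludes minibatch edges plus their reverse edges.

    reverse_of[i] is the ID of edge i's reverse edge. This list is a
    hypothetical stand-in for the reverse_eids tensor.
    """
    def exclude_fn(batch_eids):
        excluded = list(batch_eids)
        excluded.extend(reverse_of[e] for e in batch_eids)
        return excluded
    return exclude_fn


# Four undirected edges stored as eight directed edges: edge i <-> edge i + 4.
reverse_of = [4, 5, 6, 7, 0, 1, 2, 3]
exclude_fn = make_exclude_fn(reverse_of)
excluded = exclude_fn([1, 6])
```

A callable passed to DGL would need to accept and return the ID containers DGL actually uses (tensors, or dicts of tensors for heterogeneous graphs) rather than lists.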

Examples

The following example shows how to train a 3-layer GNN for edge classification on a set of edges train_eid on a homogeneous undirected graph. Each node takes messages from all neighbors.

Given an array of source node IDs src and another array of destination node IDs dst, the following code creates a bidirectional graph:

>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

For \(i < |E|\), edge \(i\)’s reverse edge in the graph above is edge \(i + |E|\), and vice versa. Therefore, we can create a reverse edge mapping reverse_eids by:

>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])
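
As a quick sanity check, the same mapping can be mirrored with plain Python lists. This is a sketch and not DGL code: build_reverse_eids is a hypothetical helper, and the small src and dst lists are made-up data. It confirms that applying the mapping twice returns each edge to itself and that every edge's reverse has swapped endpoints.

```python
def build_reverse_eids(num_undirected_edges):
    """Reverse-edge mapping for a graph built as (src + dst, dst + src)."""
    E = num_undirected_edges
    return list(range(E, 2 * E)) + list(range(0, E))


reverse_eids = build_reverse_eids(3)

# Applying the mapping twice must return every edge to itself.
is_involution = all(
    reverse_eids[reverse_eids[i]] == i for i in range(len(reverse_eids))
)

# With concrete (made-up) endpoints, each reverse edge has swapped endpoints.
src, dst = [0, 1, 2], [1, 2, 0]
all_src, all_dst = src + dst, dst + src
endpoints_swapped = all(
    (all_src[reverse_eids[i]], all_dst[reverse_eids[i]]) == (all_dst[i], all_src[i])
    for i in range(len(reverse_eids))
)
```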

By passing reverse_eids to the edge sampler, the edges in the current mini-batch and their reverse edges will be excluded from the extracted subgraphs to avoid information leakage.

>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_id', reverse_eids=reverse_eids)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

For link prediction, one can provide a negative sampler to sample negative edges. The code below uses DGL’s Uniform to generate 5 negative samples per edge:

>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_id', reverse_eids=reverse_eids,
...     negative_sampler=neg_sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
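
The idea behind uniform negative sampling can be sketched in plain Python: for each positive edge, keep its source node and draw k destination nodes uniformly at random. This is an illustration of the concept only, not DGL's implementation. The function name uniform_negative_edges and its parameters are hypothetical, and the sketch makes no attempt to avoid drawing a destination that is a true neighbor.

```python
import random


def uniform_negative_edges(src, k, num_nodes, seed=None):
    """For each positive edge's source node, draw k random destination nodes."""
    rng = random.Random(seed)
    neg_src, neg_dst = [], []
    for u in src:
        for _ in range(k):
            neg_src.append(u)
            # Destination chosen uniformly from all node IDs.
            neg_dst.append(rng.randrange(num_nodes))
    return neg_src, neg_dst


# Two positive edges with sources 0 and 2, five negatives per edge.
neg_src, neg_dst = uniform_negative_edges(src=[0, 2], k=5, num_nodes=10, seed=0)
```

With k negatives per positive edge, a batch of B positive edges produces B * k negative pairs, which is why the negative pair graph is typically larger than the positive one.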

For heterogeneous graphs, reverse edges may belong to a different relation. For example, the relations “user-click-item” and “item-clicked-by-user” in the graph below are reverses of each other.

>>> g = dgl.heterograph({
...     ('user', 'click', 'item'): (user, item),
...     ('item', 'clicked-by', 'user'): (item, user)})

To correctly exclude edges from each mini-batch, set exclude='reverse_types' and pass a dictionary {'click': 'clicked-by', 'clicked-by': 'click'} to the reverse_etypes argument.

>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})
>>> dataloader = dgl.dataloading.DataLoader(
...     g, {'click': train_eid}, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)
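
The effect of 'reverse_types' can be pictured with a plain-Python sketch. It rests on an assumption that should be checked against the DGL version in use: an edge and its reverse share the same ID within their respective edge types. That holds for the graph above because both relations are built from the same user and item arrays in the same order. The function name exclude_with_reverse_types is hypothetical, and lists stand in for tensors.

```python
def exclude_with_reverse_types(batch_eids, reverse_etypes):
    """Return minibatch edges plus same-ID edges of the reverse edge type.

    Assumes edge i of a type and edge i of its reverse type are reverses
    of each other (an assumption of this sketch).
    """
    excluded = {etype: list(eids) for etype, eids in batch_eids.items()}
    for etype, eids in batch_eids.items():
        rev = reverse_etypes[etype]
        excluded.setdefault(rev, []).extend(eids)
    return excluded


reverse_etypes = {'click': 'clicked-by', 'clicked-by': 'click'}
excluded = exclude_with_reverse_types({'click': [0, 3]}, reverse_etypes)
```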

For link prediction, provide a negative sampler to generate negative samples:

>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
...     negative_sampler=neg_sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, {'click': train_eid}, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)