Source code for dgl.graphbolt.utils

"""Utility functions for external use."""

from typing import Dict, Tuple, Union

import torch

from .minibatch import MiniBatch

[docs]def add_reverse_edges( edges: Union[ Dict[str, Tuple[torch.Tensor, torch.Tensor]], Tuple[torch.Tensor, torch.Tensor], ], reverse_etypes_mapping: Dict[str, str] = None, ): r""" This function finds the reverse edges of the given `edges` and returns the composition of them. In a homogeneous graph, reverse edges have inverted source and destination node IDs. While in a heterogeneous graph, reversing also involves swapping node IDs and their types. This function could be used before `exclude_edges` function to help find targeting edges. Note: The found reverse edges may not really exists in the original graph. And repeat edges could be added becasue reverse edges may already exists in the `edges`. Parameters ---------- edges : Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]], Tuple[torch.Tensor, torch.Tensor]] - If sampled subgraph is homogeneous, then `edges` should be a pair of of tensors. - If sampled subgraph is heterogeneous, then `edges` should be a dictionary of edge types and the corresponding edges to exclude. reverse_etypes_mapping : Dict[str, str], optional The mapping from the original edge types to their reverse edge types. Returns ------- Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]], Tuple[torch.Tensor, torch.Tensor]] The node pairs contain both the original edges and their reverse counterparts. Examples -------- >>> edges = {"A:r:B": (torch.tensor([0, 1]), torch.tensor([1, 2]))} >>> print(gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"})) {'A:r:B': (tensor([0, 1]), tensor([1, 2])), 'B:rr:A': (tensor([1, 2]), tensor([0, 1]))} >>> edges = (torch.tensor([0, 1]), torch.tensor([2, 1])) >>> print(gb.add_reverse_edges(edges)) (tensor([0, 1, 2, 1]), tensor([2, 1, 0, 1])) """ if isinstance(edges, tuple): u, v = edges return ([u, v]),[v, u])) else: combined_edges = edges.copy() for etype, reverse_etype in reverse_etypes_mapping.items(): if etype in edges: if reverse_etype in combined_edges: u, v = combined_edges[reverse_etype] u =[u, edges[etype][1]]) v =[v, edges[etype][0]]) combined_edges[reverse_etype] = (u, v) else: combined_edges[reverse_etype] = ( edges[etype][1], edges[etype][0], ) return combined_edges
[docs]def exclude_seed_edges( minibatch: MiniBatch, include_reverse_edges: bool = False, reverse_etypes_mapping: Dict[str, str] = None, ): """ Exclude seed edges with or without their reverse edges from the sampled subgraphs in the minibatch. Parameters ---------- minibatch : MiniBatch The minibatch. reverse_etypes_mapping : Dict[str, str] = None The mapping from the original edge types to their reverse edge types. """ edges_to_exclude = minibatch.node_pairs if include_reverse_edges: edges_to_exclude = add_reverse_edges( minibatch.node_pairs, reverse_etypes_mapping ) minibatch.sampled_subgraphs = [ subgraph.exclude_edges(edges_to_exclude) for subgraph in minibatch.sampled_subgraphs ] return minibatch