"""Utility functions for external use."""
from typing import Dict, Union
import torch
from .minibatch import MiniBatch
def add_reverse_edges(
    edges: Union[Dict[str, torch.Tensor], torch.Tensor],
    reverse_etypes_mapping: Dict[str, str] = None,
):
    r"""Return the given `edges` combined with their reverse edges.

    In a homogeneous graph, a reverse edge has its source and destination
    node IDs swapped. In a heterogeneous graph, reversing additionally places
    the swapped node pairs under the reverse edge type given by
    `reverse_etypes_mapping`. This function could be used before the
    `exclude_edges` function to help find targeting edges.

    Note: The generated reverse edges may not really exist in the original
    graph. Duplicate edges could also be produced, because the reverse edges
    may already exist in `edges`.

    Parameters
    ----------
    edges : Union[Dict[str, torch.Tensor], torch.Tensor]
        - If sampled subgraph is homogeneous, then `edges` should be a N*2
          tensor.
        - If sampled subgraph is heterogeneous, then `edges` should be a
          dictionary of edge types and the corresponding N*2 edge tensors.
    reverse_etypes_mapping : Dict[str, str], optional
        The mapping from the original edge types to their reverse edge types.
        Required when `edges` is a dictionary; ignored when `edges` is a
        tensor. Edge types in the mapping that are absent from `edges` are
        skipped.

    Returns
    -------
    Union[Dict[str, torch.Tensor], torch.Tensor]
        The node pairs containing both the original edges and their reverse
        counterparts. For dictionary input a new dictionary is returned and
        the input dictionary is left unmodified.

    Raises
    ------
    AssertionError
        If any edge tensor is not of shape N*2, or if `edges` is a dictionary
        and `reverse_etypes_mapping` is not provided.

    Examples
    --------
    >>> edges = {"A:r:B": torch.tensor([[0, 1], [1, 2]])}
    >>> print(gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"}))
    {'A:r:B': tensor([[0, 1],
            [1, 2]]), 'B:rr:A': tensor([[1, 0],
            [2, 1]])}

    >>> edges = torch.tensor([[0, 1], [1, 2]])
    >>> print(gb.add_reverse_edges(edges))
    tensor([[0, 1],
            [1, 2],
            [1, 0],
            [2, 1]])
    """
    if isinstance(edges, torch.Tensor):
        assert edges.ndim == 2 and edges.shape[1] == 2, (
            "Only tensor with shape N*2 is supported now, but got "
            + f"{edges.shape}."
        )
        # Flipping along dim 1 swaps the source and destination columns.
        return torch.cat((edges, edges.flip(dims=(1,))))

    # Fail with a clear message instead of an AttributeError on `None.items()`.
    assert (
        reverse_etypes_mapping is not None
    ), "`reverse_etypes_mapping` is required when `edges` is a dictionary."

    # Shallow copy so that the caller's dictionary is never mutated.
    combined_edges = edges.copy()
    for etype, reverse_etype in reverse_etypes_mapping.items():
        if etype not in edges:
            continue
        etype_edges = edges[etype]
        assert etype_edges.ndim == 2 and etype_edges.shape[1] == 2, (
            "Only tensor with shape N*2 is supported now, but got "
            + f"{etype_edges.shape}."
        )
        # Always reverse the ORIGINAL edges of `etype`, so edges appended to
        # `combined_edges` in earlier iterations are not reversed again.
        reversed_edges = etype_edges.flip(dims=(1,))
        if reverse_etype in combined_edges:
            combined_edges[reverse_etype] = torch.cat(
                (combined_edges[reverse_etype], reversed_edges)
            )
        else:
            combined_edges[reverse_etype] = reversed_edges
    return combined_edges
def exclude_seed_edges(
    minibatch: MiniBatch,
    include_reverse_edges: bool = False,
    reverse_etypes_mapping: Dict[str, str] = None,
):
    """
    Remove the seed edges of a minibatch, optionally together with their
    reverse edges, from every sampled subgraph it holds.

    The minibatch is updated in place: its `sampled_subgraphs` attribute is
    replaced by the list of results of calling `exclude_edges` on each
    subgraph.

    Parameters
    ----------
    minibatch : MiniBatch
        The minibatch. Its `seeds` are the edges to exclude.
    include_reverse_edges : bool, optional
        If True, the reverse edges of the seeds (as computed by
        `add_reverse_edges`) are excluded as well. Default: False.
    reverse_etypes_mapping : Dict[str, str] = None
        The mapping from the original edge types to their reverse edge types.
        Only forwarded to `add_reverse_edges` when `include_reverse_edges`
        is True.

    Returns
    -------
    MiniBatch
        The same `minibatch` object that was passed in.
    """
    excluded = minibatch.seeds
    if include_reverse_edges:
        excluded = add_reverse_edges(excluded, reverse_etypes_mapping)

    pruned_subgraphs = []
    for sampled_subgraph in minibatch.sampled_subgraphs:
        pruned_subgraphs.append(sampled_subgraph.exclude_edges(excluded))
    minibatch.sampled_subgraphs = pruned_subgraphs
    return minibatch