Source code for dgl.graphbolt.impl.sampled_subgraph_impl

"""Sampled subgraph for FusedCSCSamplingGraph."""
# pylint: disable= invalid-name
from dataclasses import dataclass
from typing import Dict, Union

import torch

from ..base import CSCFormatBase, etype_str_to_tuple
from ..internal import get_attributes
from ..sampled_subgraph import SampledSubgraph

__all__ = ["SampledSubgraphImpl"]


@dataclass
class SampledSubgraphImpl(SampledSubgraph):
    r"""Sampled subgraph of FusedCSCSamplingGraph.

    Attributes
    ----------
    sampled_csc : CSCFormatBase or Dict[str, CSCFormatBase]
        Sampled edges in CSC form. When a dict, keys are edge-type strings
        of the form ``"str:str:str"``. Each value (or the single object)
        must carry tensor ``indptr`` and ``indices``. Required: although
        the field defaults to ``None``, construction fails without it.
    original_column_node_ids : torch.Tensor or Dict[str, torch.Tensor]
        Original ids of the column (destination) nodes, optionally keyed by
        node type.
    original_row_node_ids : torch.Tensor or Dict[str, torch.Tensor]
        Original ids of the row (source) nodes, optionally keyed by node
        type.
    original_edge_ids : torch.Tensor or Dict[str, torch.Tensor]
        Original ids of the sampled edges, optionally keyed by edge type.

    Examples
    --------
    >>> sampled_csc = {"A:relation:B": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]),
    ...     indices=torch.tensor([0, 1, 2]))}
    >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
    >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
    >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
    >>> subgraph = gb.SampledSubgraphImpl(
    ...     sampled_csc=sampled_csc,
    ...     original_column_node_ids=original_column_node_ids,
    ...     original_row_node_ids=original_row_node_ids,
    ...     original_edge_ids=original_edge_ids
    ... )
    >>> print(subgraph.sampled_csc)
    {"A:relation:B": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]),
    ...     indices=torch.tensor([0, 1, 2]))}
    >>> print(subgraph.original_column_node_ids)
    {'B': tensor([10, 11, 12])}
    >>> print(subgraph.original_row_node_ids)
    {'A': tensor([13, 14, 15])}
    >>> print(subgraph.original_edge_ids)
    {"A:relation:B": tensor([19, 20, 21])}
    """

    # Single CSC structure, or a dict of them keyed by "str:str:str" edge type.
    sampled_csc: Union[CSCFormatBase, Dict[str, CSCFormatBase]] = None
    original_column_node_ids: Union[
        Dict[str, torch.Tensor], torch.Tensor
    ] = None
    original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
    original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None

    def __post_init__(self):
        """Validate ``sampled_csc`` right after dataclass construction.

        Raises
        ------
        AssertionError
            If ``sampled_csc`` is ``None``, if a dict key is not a
            ``"str:str:str"`` edge-type string, or if any CSC entry lacks
            ``indptr``/``indices`` or holds non-tensor values.
        """

        def _check_csc(csc):
            # Shared by the dict and single-object branches so both apply
            # exactly the same checks, in the same order.
            assert (
                csc.indptr is not None and csc.indices is not None
            ), "Node pair should be have indptr and indice."
            assert isinstance(csc.indptr, torch.Tensor) and isinstance(
                csc.indices, torch.Tensor
            ), "Nodes in pairs should be of type torch.Tensor."

        # The field defaults to None; without this check the non-dict
        # branch below would fail with an unexplained AttributeError.
        assert self.sampled_csc is not None, "sampled_csc should not be None."
        if isinstance(self.sampled_csc, dict):
            for etype, pair in self.sampled_csc.items():
                assert (
                    isinstance(etype, str)
                    and len(etype_str_to_tuple(etype)) == 3
                ), "Edge type should be a string in format of str:str:str."
                _check_csc(pair)
        else:
            _check_csc(self.sampled_csc)

    def __repr__(self) -> str:
        return _sampled_subgraph_str(self, "SampledSubgraphImpl")
def _sampled_subgraph_str(sampled_subgraph: SampledSubgraph, classname) -> str:
    """Render ``sampled_subgraph`` as a multi-line ``classname(...)`` string.

    Each attribute reported by ``get_attributes`` (visited in reverse order)
    becomes a ``name=value,`` entry on its own line. Entries after the first
    are prefixed with ``len(classname)`` spaces. Continuation lines of a
    multi-line value are indented by ``len(name) + len(classname) + 1``
    spaces.

    Parameters
    ----------
    sampled_subgraph : SampledSubgraph
        The subgraph whose attributes are rendered.
    classname : str
        Name written before the opening parenthesis.

    Returns
    -------
    str
        The formatted representation, or ``classname()`` when there are no
        attributes to show.
    """

    def _add_indent(_str, indent):
        # Indent every line except the first, which follows ``name=``.
        lines = _str.split("\n")
        return "\n".join(
            [lines[0]] + [" " * indent + line for line in lines[1:]]
        )

    entries = []
    # ``reversed`` rather than ``list.reverse()`` so the list returned by
    # ``get_attributes`` is not mutated in place.
    for name in reversed(get_attributes(sampled_subgraph)):
        val = str(getattr(sampled_subgraph, name))
        indent = len(name) + len(classname) + 1
        entries.append(f"{name}={_add_indent(val, indent)},\n")
    # Join with the indent as separator instead of appending a trailing
    # indent and slicing it off afterwards. The slice truncated the class
    # name when there were no attributes and dropped everything when
    # ``classname`` was empty.
    return classname + "(" + (" " * len(classname)).join(entries) + ")"