Source code for dgl.data.graph_serialize

"""For Graph Serialization"""
from __future__ import absolute_import
import os
from ..base import dgl_warning, DGLError
from ..heterograph import DGLHeteroGraph
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F
from .heterograph_serialize import save_heterographs

_init_api("dgl.data.graph_serialize")

__all__ = ['save_graphs', "load_graphs", "load_labels"]


@register_object("graph_serialize.StorageMetaData")
class StorageMetaData(ObjectBase):
    """StorageMetaData Object
    attributes available:
      num_graph [int]: return numbers of graphs
      nodes_num_list Value of NDArray: return number of nodes for each graph
      edges_num_list Value of NDArray: return number of edges for each graph
      labels [dict of backend tensors]: return dict of labels
      graph_data [list of GraphData]: return list of GraphData Object
    """


def is_local_path(filepath):
    return not (filepath.startswith("hdfs://") or
                filepath.startswith("viewfs://") or
                filepath.startswith("s3://"))


def check_local_file_exists(filename):
    if is_local_path(filename) and not os.path.exists(filename):
        raise DGLError("File {} does not exist.".format(filename))

@register_object("graph_serialize.GraphData")
class GraphData(ObjectBase):
    """GraphData Object"""

    @staticmethod
    def create(g):
        """Create GraphData"""
        # TODO(zihao): support serialize batched graph in the future.
        assert g.batch_size == 1, "Batched DGLGraph is not supported for serialization"
        ghandle = g._graph
        if len(g.ndata) != 0:
            node_tensors = dict()
            for key, value in g.ndata.items():
                node_tensors[key] = F.zerocopy_to_dgl_ndarray(value)
        else:
            node_tensors = None

        if len(g.edata) != 0:
            edge_tensors = dict()
            for key, value in g.edata.items():
                edge_tensors[key] = F.zerocopy_to_dgl_ndarray(value)
        else:
            edge_tensors = None

        return _CAPI_MakeGraphData(ghandle, node_tensors, edge_tensors)

    def get_graph(self):
        """Get DGLHeteroGraph from GraphData"""
        ghandle = _CAPI_GDataGraphHandle(self)
        hgi =_CAPI_DGLAsHeteroGraph(ghandle)
        g = DGLHeteroGraph(hgi, ['_U'], ['_E'])
        node_tensors_items = _CAPI_GDataNodeTensors(self).items()
        edge_tensors_items = _CAPI_GDataEdgeTensors(self).items()
        for k, v in node_tensors_items:
            g.ndata[k] = F.zerocopy_from_dgl_ndarray(v)
        for k, v in edge_tensors_items:
            g.edata[k] = F.zerocopy_from_dgl_ndarray(v)
        return g


[docs]def save_graphs(filename, g_list, labels=None): r"""Save graphs and optionally their labels to file. Besides saving to local files, DGL supports writing the graphs directly to S3 (by providing a ``"s3://..."`` path) or to HDFS (by providing ``"hdfs://..."`` a path). The function saves both the graph structure and node/edge features to file in DGL's own binary format. For graph-level features, pass them via the :attr:`labels` argument. Parameters ---------- filename : str The file name to store the graphs and labels. g_list: list The graphs to be saved. labels: dict[str, Tensor] labels should be dict of tensors, with str as keys Examples ---------- >>> import dgl >>> import torch as th Create :class:`DGLGraph` objects and initialize node and edge features. >>> g1 = dgl.graph(([0, 1, 2], [1, 2, 3])) >>> g2 = dgl.graph(([0, 2], [2, 3])) >>> g2.edata["e"] = th.ones(2, 4) Save Graphs into file >>> from dgl.data.utils import save_graphs >>> graph_labels = {"glabel": th.tensor([0, 1])} >>> save_graphs("./data.bin", [g1, g2], graph_labels) See Also -------- load_graphs """ # if it is local file, do some sanity check if is_local_path(filename): if os.path.isdir(filename): raise DGLError("Filename {} is an existing directory.".format(filename)) f_path = os.path.dirname(filename) if f_path and not os.path.exists(f_path): os.makedirs(f_path) g_sample = g_list[0] if isinstance(g_list, list) else g_list if type(g_sample) == DGLHeteroGraph: # Doesn't support DGLHeteroGraph's derived class save_heterographs(filename, g_list, labels) else: raise DGLError( "Invalid argument g_list. Must be a DGLGraph or a list of DGLGraphs.")
[docs]def load_graphs(filename, idx_list=None): """Load graphs and optionally their labels from file saved by :func:`save_graphs`. Besides loading from local files, DGL supports loading the graphs directly from S3 (by providing a ``"s3://..."`` path) or from HDFS (by providing ``"hdfs://..."`` a path). Parameters ---------- filename: str The file name to load graphs from. idx_list: list[int], optional The indices of the graphs to be loaded if the file contains multiple graphs. Default is loading all the graphs stored in the file. Returns -------- graph_list: list[DGLGraph] The loaded graphs. labels: dict[str, Tensor] The graph labels stored in file. If no label is stored, the dictionary is empty. Regardless of whether the ``idx_list`` argument is given or not, the returned dictionary always contains the labels of all the graphs. Examples ---------- Following the example in :func:`save_graphs`. >>> from dgl.data.utils import load_graphs >>> glist, label_dict = load_graphs("./data.bin") # glist will be [g1, g2] >>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1] See Also -------- save_graphs """ # if it is local file, do some sanity check check_local_file_exists(filename) version = _CAPI_GetFileVersion(filename) if version == 1: dgl_warning( "You are loading a graph file saved by old version of dgl. \ Please consider saving it again with the current format.") return load_graph_v1(filename, idx_list) elif version == 2: return load_graph_v2(filename, idx_list) else: raise DGLError("Invalid DGL Version Number.")
def load_graph_v2(filename, idx_list=None): """Internal functions for loading DGLHeteroGraphs.""" if idx_list is None: idx_list = [] assert isinstance(idx_list, list) heterograph_list = _CAPI_LoadGraphFiles_V2(filename, idx_list) label_dict = load_labels_v2(filename) return [gdata.get_graph() for gdata in heterograph_list], label_dict def load_graph_v1(filename, idx_list=None): """"Internal functions for loading DGLGraphs (V0).""" if idx_list is None: idx_list = [] assert isinstance(idx_list, list) metadata = _CAPI_LoadGraphFiles_V1(filename, idx_list, False) label_dict = {} for k, v in metadata.labels.items(): label_dict[k] = F.zerocopy_from_dgl_ndarray(v) return [gdata.get_graph() for gdata in metadata.graph_data], label_dict
[docs]def load_labels(filename): """ Load label dict from file Parameters ---------- filename: str filename to load DGLGraphs Returns ---------- labels: dict dict of labels stored in file (empty dict returned if no label stored) Examples ---------- Following the example in save_graphs. >>> from dgl.data.utils import load_labels >>> label_dict = load_graphs("./data.bin") """ # if it is local file, do some sanity check check_local_file_exists(filename) version = _CAPI_GetFileVersion(filename) if version == 1: return load_labels_v1(filename) elif version == 2: return load_labels_v2(filename) else: raise Exception("Invalid DGL Version Number")
def load_labels_v2(filename): """Internal functions for loading labels from V2 format""" label_dict = {} nd_dict = _CAPI_LoadLabels_V2(filename) for k, v in nd_dict.items(): label_dict[k] = F.zerocopy_from_dgl_ndarray(v) return label_dict def load_labels_v1(filename): """Internal functions for loading labels from V1 format""" metadata = _CAPI_LoadGraphFiles_V1(filename, [], True) label_dict = {} for k, v in metadata.labels.items(): label_dict[k] = F.zerocopy_from_dgl_ndarray(v) return label_dict