Make Your Own Dataset

This tutorial assumes that you already know the basics of training a GNN for node classification and how to create, load, and store a DGL graph.

By the end of this tutorial, you will be able to

  • Create your own graph dataset for node classification, link prediction, or graph classification.

(Time estimate: 15 minutes)

DGLDataset Object Overview

Your custom graph dataset should inherit the dgl.data.DGLDataset class and implement the following methods:

  • __getitem__(self, i): retrieve the i-th example of the dataset. An example often consists of a single DGL graph and, optionally, its label.

  • __len__(self): the number of examples in the dataset.

  • process(self): load and process raw data from disk.

Creating a Dataset for Graph Classification from CSV

Creating a graph classification dataset involves implementing __getitem__ to return both the graph and its graph-level label.

This tutorial demonstrates how to create a graph classification dataset with the following synthetic CSV data:

  • graph_edges.csv: containing three columns:

    • graph_id: the ID of the graph.

    • src: the source node of an edge of the given graph.

    • dst: the destination node of an edge of the given graph.

  • graph_properties.csv: containing three columns:

    • graph_id: the ID of the graph.

    • label: the label of the graph.

    • num_nodes: the number of nodes in the graph.
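To make the schema concrete, here is a small hand-made sample with the same columns. The values are hypothetical; the downloaded files contain different data:

```python
import pandas as pd

# Hypothetical rows matching the schema of graph_edges.csv:
# each row is one edge of one graph.
edges = pd.DataFrame({
    'graph_id': [0, 0, 1],
    'src': [0, 1, 0],
    'dst': [1, 2, 1],
})

# Hypothetical rows matching the schema of graph_properties.csv:
# each row is one graph with its label and node count.
properties = pd.DataFrame({
    'graph_id': [0, 1],
    'label': [0, 1],
    'num_nodes': [3, 2],
})

print(edges)
print(properties)
```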

import urllib.request

import pandas as pd

urllib.request.urlretrieve(
    'https://data.dgl.ai/tutorial/dataset/graph_edges.csv', './graph_edges.csv')
urllib.request.urlretrieve(
    'https://data.dgl.ai/tutorial/dataset/graph_properties.csv', './graph_properties.csv')
edges = pd.read_csv('./graph_edges.csv')
properties = pd.read_csv('./graph_properties.csv')

edges.head()

properties.head()

import dgl
import torch
from dgl.data import DGLDataset

class SyntheticDataset(DGLDataset):
    def __init__(self):
        super().__init__(name='synthetic')

    def process(self):
        edges = pd.read_csv('./graph_edges.csv')
        properties = pd.read_csv('./graph_properties.csv')
        self.graphs = []
        self.labels = []

        # Create a graph for each graph ID from the edges table.
        # First process the properties table into two dictionaries with graph IDs as keys.
        # The label and number of nodes are values.
        label_dict = {}
        num_nodes_dict = {}
        for _, row in properties.iterrows():
            label_dict[row['graph_id']] = row['label']
            num_nodes_dict[row['graph_id']] = row['num_nodes']

        # For the edges, first group the table by graph IDs.
        edges_group = edges.groupby('graph_id')

        # For each graph ID...
        for graph_id in edges_group.groups:
            # Find the edges as well as the number of nodes and the label.
            edges_of_id = edges_group.get_group(graph_id)
            src = edges_of_id['src'].to_numpy()
            dst = edges_of_id['dst'].to_numpy()
            num_nodes = num_nodes_dict[graph_id]
            label = label_dict[graph_id]

            # Create a graph and add it to the list of graphs and labels.
            g = dgl.graph((src, dst), num_nodes=num_nodes)
            self.graphs.append(g)
            self.labels.append(label)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)

dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)



Out:

Graph(num_nodes=15, num_edges=45,
      ndata_schemes={}
      edata_schemes={}) tensor(0)
