"""Dataset for stochastic block model."""
import math
import os
import random

import numpy as np
import numpy.random as npr
import scipy as sp
import scipy.sparse  # ensure the sparse submodule is loaded for ``sp.sparse``

from .. import batch
from ..convert import from_scipy
from .dgl_dataset import DGLDataset
from .utils import load_graphs, load_info, save_graphs, save_info


def sbm(n_blocks, block_size, p, q, rng=None):
    """(Symmetric) Stochastic Block Model

    Parameters
    ----------
    n_blocks : int
        Number of blocks.
    block_size : int
        Block size.
    p : float
        Intra-community edge density; divided by the number of nodes
        internally to obtain a per-edge probability.
    q : float
        Inter-community edge density; divided by the number of nodes
        internally to obtain a per-edge probability.
    rng : numpy.random.RandomState, optional
        Random number generator.

    Returns
    -------
    scipy sparse matrix
        The adjacency matrix of the generated graph.
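
    Examples
    --------
    A small sketch (illustrative; the sampled edges vary with the RNG):

    >>> adj = sbm(n_blocks=2, block_size=50, p=10.0, q=2.0)
    >>> adj.shape
    (100, 100)
    >>> import numpy as np
    >>> np.allclose(adj.toarray(), adj.toarray().T)  # symmetric by construction
    True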
    """
    n = n_blocks * block_size
    # Scale the densities so they become per-entry probabilities.
    p /= n
    q /= n
    rng = np.random.RandomState() if rng is None else rng

    rows = []
    cols = []
    # Sample only blocks on or above the block diagonal; the matrix is
    # symmetrized below. Intra-community blocks use density p,
    # inter-community blocks use density q.
    for i in range(n_blocks):
        for j in range(i, n_blocks):
            density = p if i == j else q
            block = sp.sparse.random(
                block_size,
                block_size,
                density,
                random_state=rng,
                data_rvs=lambda n: np.ones(n),  # all sampled entries equal 1
            )
            rows.append(block.row + i * block_size)
            cols.append(block.col + j * block_size)

    rows = np.hstack(rows)
    cols = np.hstack(cols)
    a = sp.sparse.coo_matrix(
        (np.ones(rows.shape[0]), (rows, cols)), shape=(n, n)
    )
    # Keep the upper triangle and mirror it so the graph is undirected.
    adj = sp.sparse.triu(a) + sp.sparse.triu(a, 1).transpose()
    return adj
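
# Sanity-check sketch (illustrative, not part of the original module): with
# ``n_blocks`` communities, the expected average degree of a graph from
# ``sbm`` is roughly (p + (n_blocks - 1) * q) / n_blocks, e.g. about 6.0 for
# n_blocks=2, p=10.0, q=2.0:
#
# >>> adj = sbm(2, 5000, 10.0, 2.0)
# >>> adj.nnz / adj.shape[0]  # empirical average degree, ~6.0 up to noise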


class SBMMixtureDataset(DGLDataset):
    r"""Symmetric Stochastic Block Model Mixture

    Reference: Appendix C of `Supervised Community Detection with Hierarchical
    Graph Neural Networks <https://arxiv.org/abs/1705.08415>`_

    Parameters
    ----------
    n_graphs : int
        Number of graphs.
    n_nodes : int
        Number of nodes.
    n_communities : int
        Number of communities.
    k : int, optional
        Multiplier. Default: 2
    avg_deg : int, optional
        Average degree. Default: 3
    pq : list of pairs of nonnegative float, or str, optional
        Random densities. This parameter is reserved for future extension;
        for now it always uses the default value.
        Default: Appendix_C
    rng : numpy.random.RandomState, optional
        Random number generator. If not given, it's
        numpy.random.RandomState() with ``seed=None``, which reads data from
        /dev/urandom (or the Windows analogue) if available, or seeds from
        the clock otherwise.
        Default: None

    Raises
    ------
    RuntimeError
        Raised if ``pq`` is neither a list nor a string.

    Examples
    --------
    >>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2)
    >>> from torch.utils.data import DataLoader
    >>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn)
    >>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:
    ...     # your code here
    """

    def __init__(
        self,
        n_graphs,
        n_nodes,
        n_communities,
        k=2,
        avg_deg=3,
        pq="Appendix_C",
        rng=None,
    ):
        self._n_graphs = n_graphs
        self._n_nodes = n_nodes
        self._n_communities = n_communities
        assert n_nodes % n_communities == 0
        self._block_size = n_nodes // n_communities
        self._k = k
        self._avg_deg = avg_deg
        self._pq = pq
        self._rng = rng
        super(SBMMixtureDataset, self).__init__(
            name="sbmmixture",
            hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng),
        )

    def process(self):
        pq = self._pq
        if type(pq) is list:
            assert len(pq) == self._n_graphs
        elif type(pq) is str:
            generator = {"Appendix_C": self._appendix_c}[pq]
            pq = [generator() for _ in range(self._n_graphs)]
        else:
            raise RuntimeError("pq must be a list or a string.")

        # Pass the stored RNG so generation is reproducible when one is given.
        self._graphs = [
            from_scipy(
                sbm(self._n_communities, self._block_size, *x, rng=self._rng)
            )
            for x in pq
        ]
        self._line_graphs = [
            g.line_graph(backtracking=False) for g in self._graphs
        ]
        in_degrees = lambda g: g.in_degrees().float()
        self._graph_degrees = [in_degrees(g) for g in self._graphs]
        self._line_graph_degrees = [
            in_degrees(lg) for lg in self._line_graphs
        ]
        self._pm_pds = list(zip(*[g.edges() for g in self._graphs]))[0]

    def has_cache(self):
        graph_path = os.path.join(
            self.save_path, "graphs_{}.bin".format(self.hash)
        )
        line_graph_path = os.path.join(
            self.save_path, "line_graphs_{}.bin".format(self.hash)
        )
        info_path = os.path.join(
            self.save_path, "info_{}.pkl".format(self.hash)
        )
        return (
            os.path.exists(graph_path)
            and os.path.exists(line_graph_path)
            and os.path.exists(info_path)
        )

    def save(self):
        graph_path = os.path.join(
            self.save_path, "graphs_{}.bin".format(self.hash)
        )
        line_graph_path = os.path.join(
            self.save_path, "line_graphs_{}.bin".format(self.hash)
        )
        info_path = os.path.join(
            self.save_path, "info_{}.pkl".format(self.hash)
        )
        save_graphs(graph_path, self._graphs)
        save_graphs(line_graph_path, self._line_graphs)
        save_info(
            info_path,
            {
                "graph_degree": self._graph_degrees,
                "line_graph_degree": self._line_graph_degrees,
                "pm_pds": self._pm_pds,
            },
        )

    def load(self):
        graph_path = os.path.join(
            self.save_path, "graphs_{}.bin".format(self.hash)
        )
        line_graph_path = os.path.join(
            self.save_path, "line_graphs_{}.bin".format(self.hash)
        )
        info_path = os.path.join(
            self.save_path, "info_{}.pkl".format(self.hash)
        )
        self._graphs, _ = load_graphs(graph_path)
        self._line_graphs, _ = load_graphs(line_graph_path)
        info = load_info(info_path)
        self._graph_degrees = info["graph_degree"]
        self._line_graph_degrees = info["line_graph_degree"]
        self._pm_pds = info["pm_pds"]

    def __len__(self):
        r"""Number of graphs in the dataset."""
        return len(self._graphs)

    def __getitem__(self, idx):
        r"""Get one example by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        graph : :class:`dgl.DGLGraph`
            The original graph
        line_graph : :class:`dgl.DGLGraph`
            The line graph of `graph`
        graph_degree : numpy.ndarray
            In-degrees for each node in `graph`
        line_graph_degree : numpy.ndarray
            In-degrees for each node in `line_graph`
        pm_pd : numpy.ndarray
            Edge indicator matrices Pm and Pd
        """
        return (
            self._graphs[idx],
            self._line_graphs[idx],
            self._graph_degrees[idx],
            self._line_graph_degrees[idx],
            self._pm_pds[idx],
        )

    def _appendix_c(self):
        # Sample (p, q) as in Appendix C of the referenced paper:
        # q ~ U(0, avg_deg - sqrt(avg_deg)) and p = k * avg_deg - q,
        # so p + q = k * avg_deg.
        q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))
        p = self._k * self._avg_deg - q
        # Return the pair in either order with equal probability.
        if random.random() < 0.5:
            return p, q
        else:
            return q, p

    def collate_fn(self, x):
        r"""The `collate` function for a dataloader

        Parameters
        ----------
        x : tuple
            A batch of data that contains:

            - graph: :class:`dgl.DGLGraph`
                The original graph
            - line_graph: :class:`dgl.DGLGraph`
                The line graph of `graph`
            - graph_degree: numpy.ndarray
                In-degrees for each node in `graph`
            - line_graph_degree: numpy.ndarray
                In-degrees for each node in `line_graph`
            - pm_pd: numpy.ndarray
                Edge indicator matrices Pm and Pd

        Returns
        -------
        g_batch : :class:`dgl.DGLGraph`
            Batched graphs
        lg_batch : :class:`dgl.DGLGraph`
            Batched line graphs
        degg_batch : numpy.ndarray
            A batch of in-degrees for each node in `g_batch`
        deglg_batch : numpy.ndarray
            A batch of in-degrees for each node in `lg_batch`
        pm_pd_batch : numpy.ndarray
            A batch of edge indicator matrices Pm and Pd
        """
        g, lg, deg_g, deg_lg, pm_pd = zip(*x)
        g_batch = batch.batch(g)
        lg_batch = batch.batch(lg)
        degg_batch = np.concatenate(deg_g, axis=0)
        deglg_batch = np.concatenate(deg_lg, axis=0)
        # Offset each graph's node indices by i * n_nodes so the indicator
        # matrices index into the batched graph.
        pm_pd_batch = np.concatenate(
            [x + i * self._n_nodes for i, x in enumerate(pm_pd)], axis=0
        )
        return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch


SBMMixture = SBMMixtureDataset
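
# A minimal usage sketch (illustrative, not part of the module): build a small
# mixture from an explicit list of (p, q) pairs instead of the default
# Appendix C sampler; the densities below are arbitrary example values.
#
# >>> data = SBMMixtureDataset(n_graphs=2, n_nodes=200, n_communities=2,
# ...                          pq=[(10.0, 2.0), (8.0, 4.0)])
# >>> g, lg, deg_g, deg_lg, pm_pd = data[0]
# >>> g.num_nodes()
# 200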