Source code for dgl.nn.mxnet.utils

"""Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name

from mxnet import nd, gluon
import numpy as np
from ... import DGLGraph

def matmul_maybe_select(A, B):
    """Perform Matrix multiplication C = A * B but A could be an integer id vector.

    If A is an integer vector, we treat it as multiplying a one-hot encoded tensor.
    In this case, the expensive dense matrix multiply can be replaced by a much
    cheaper index lookup.

    For example,
    ::

        A = [2, 0, 1],
        B = [[0.1, 0.2],
             [0.3, 0.4],
             [0.5, 0.6]]

    then matmul_maybe_select(A, B) is equivalent to
    ::

        [[0, 0, 1],     [[0.1, 0.2],
         [1, 0, 0],  *   [0.3, 0.4],
         [0, 1, 0]]      [0.5, 0.6]]

    In all other cases, perform a normal matmul.

    Parameters
    ----------
    A : mxnet.NDArray
        lhs tensor
    B : mxnet.NDArray
        rhs tensor

    Returns
    -------
    C : mxnet.NDArray
        result tensor
    """
    if A.dtype in (np.int32, np.int64) and len(A.shape) == 1:
        return nd.take(B, A, axis=0)
    else:
        return nd.dot(A, B)
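
# A minimal usage sketch (illustrative addition, not part of the original
# module; the helper name `_demo_matmul_maybe_select` is made up). It mirrors
# the docstring example: an integer id vector selects rows of B, which gives
# the same result as multiplying by the equivalent one-hot matrix.
def _demo_matmul_maybe_select():
    """Illustrative example only; call manually to compare both code paths."""
    B = nd.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
    ids = nd.array([2, 0, 1], dtype=np.int64)              # integer ids -> index lookup path
    onehot = nd.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])   # dense lhs -> nd.dot path
    print(matmul_maybe_select(ids, B))      # rows 2, 0, 1 of B
    print(matmul_maybe_select(onehot, B))   # same values via dense matmul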

def bmm_maybe_select(A, B, index):
    """Slice submatrices of A by the given index and perform bmm.

    B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of
    N matrices of shape (D1, D2). The input index is an integer vector of length M.
    A could be either:
    (1) a dense tensor of shape (M, D1),
    (2) an integer vector of length M.
    The result C is a 2D matrix of shape (M, D2).

    For case (1), C is computed by bmm:
    ::

        C[i, :] = matmul(A[i, :], B[index[i], :, :])

    For case (2), C is computed by index select:
    ::

        C[i, :] = B[index[i], A[i], :]

    Parameters
    ----------
    A : mxnet.NDArray
        lhs tensor
    B : mxnet.NDArray
        rhs tensor
    index : mxnet.NDArray
        index tensor

    Returns
    -------
    C : mxnet.NDArray
        result tensor
    """
    if A.dtype in (np.int32, np.int64) and len(A.shape) == 1:
        return B[index, A, :]
    else:
        BB = nd.take(B, index, axis=0)
        return nd.batch_dot(A.expand_dims(1), BB).squeeze(1)
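
# A minimal usage sketch (illustrative addition, not part of the original
# module; the helper name `_demo_bmm_maybe_select` is made up). B stacks two
# (2, 3) matrices; `index` picks one matrix per output row, and A is either a
# per-row lhs vector (batched matmul) or a row id into the selected matrix.
def _demo_bmm_maybe_select():
    """Illustrative example only; call manually to compare both code paths."""
    B = nd.arange(12).reshape((2, 2, 3))                   # N=2 matrices of shape (2, 3)
    index = nd.array([1, 0, 1], dtype=np.int64)            # which matrix each output row uses
    A_dense = nd.array([[1., 0.], [0., 1.], [1., 1.]])     # case (1): dense lhs of shape (M, D1)
    print(bmm_maybe_select(A_dense, B, index))             # C[i, :] = A[i, :] @ B[index[i]]
    A_ids = nd.array([0, 1, 1], dtype=np.int64)            # case (2): integer row ids
    print(bmm_maybe_select(A_ids, B, index))               # C[i, :] = B[index[i], A[i], :]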

def normalize(x, p=2, axis=1, eps=1e-12):
    r"""Performs :math:`L_p` normalization of inputs over specified dimension.

    For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as

    .. math::
        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    With the default arguments it uses the Euclidean norm over vectors along dimension
     :math:`1` for normalization.

    Args:
        x: input ndarray of any shape
        ord (float): the exponent value in the norm formulation. Default: 2
        dim (int): the dimension to reduce. Default: 1
        eps (float): small value to avoid division by zero. Default: 1e-12
    """
    denom = nd.clip(nd.norm(x, ord=p, axis=axis, keepdims=True), eps, float('inf'))
    return x / denom
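
# A minimal usage sketch (illustrative addition, not part of the original
# module; the helper name `_demo_normalize` is made up). Each row of the input
# is rescaled to unit L2 norm; an all-zero row stays zero because the norm is
# clipped from below by `eps` before dividing.
def _demo_normalize():
    """Illustrative example only; each non-zero row ends up with L2 norm 1."""
    x = nd.array([[3., 4.], [0., 0.]])
    y = normalize(x, p=2, axis=1)
    print(y)                           # first row -> [0.6, 0.8]; zero row unchanged
    print(nd.norm(y, ord=2, axis=1))   # approximately [1., 0.]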

class Sequential(gluon.nn.Sequential):
    r"""A sequential container for stacking graph neural network blocks.

    We support two modes: sequentially apply GNN blocks on the same graph or
    a list of given graphs. In the second case, the number of graphs equals
    the number of blocks inside this container.

    Examples
    --------

    Mode 1: sequentially apply GNN modules on the same graph

    >>> import dgl
    >>> from mxnet import nd
    >>> from mxnet.gluon import nn
    >>> import dgl.function as fn
    >>> from dgl.nn.mxnet import Sequential
    >>> class ExampleLayer(nn.Block):
    >>>     def __init__(self, **kwargs):
    >>>         super().__init__(**kwargs)
    >>>     def forward(self, graph, n_feat, e_feat):
    >>>         with graph.local_scope():
    >>>             graph.ndata['h'] = n_feat
    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    >>>             n_feat += graph.ndata['h']
    >>>             graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
    >>>             e_feat += graph.edata['e']
    >>>             return n_feat, e_feat
    >>>
    >>> g = dgl.DGLGraph()
    >>> g.add_nodes(3)
    >>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    >>> net = Sequential()
    >>> net.add(ExampleLayer())
    >>> net.add(ExampleLayer())
    >>> net.add(ExampleLayer())
    >>> net.initialize()
    >>> n_feat = nd.random.randn(3, 4)
    >>> e_feat = nd.random.randn(9, 4)
    >>> net(g, n_feat, e_feat)
    (
    [[ 12.412863   99.61184    21.472883  -57.625923 ]
     [ 10.08097   100.68611    20.627377  -60.13458  ]
     [ 11.7912245 101.80654    22.427956  -58.32772  ]]
    <NDArray 3x4 @cpu(0)>,
    [[ 21.818504  198.12076    42.72387  -115.147736]
     [ 23.070837  195.49811    43.42292  -116.17203 ]
     [ 24.330334  197.10927    42.40048  -118.06538 ]
     [ 21.907919  199.11469    42.1187   -115.35658 ]
     [ 22.849625  198.79213    43.866085 -113.65381 ]
     [ 20.926125  198.116      42.64334  -114.246704]
     [ 23.003159  197.06662    41.796425 -117.14977 ]
     [ 21.391375  198.3348     41.428078 -116.30361 ]
     [ 21.291483  200.0701     40.8239   -118.07314 ]]
    <NDArray 9x4 @cpu(0)>)

    Mode 2: sequentially apply GNN modules on different graphs

    >>> import dgl
    >>> from mxnet import nd
    >>> from mxnet.gluon import nn
    >>> import dgl.function as fn
    >>> import networkx as nx
    >>> from dgl.nn.mxnet import Sequential
    >>> class ExampleLayer(nn.Block):
    >>>     def __init__(self, **kwargs):
    >>>         super().__init__(**kwargs)
    >>>     def forward(self, graph, n_feat):
    >>>         with graph.local_scope():
    >>>             graph.ndata['h'] = n_feat
    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    >>>             n_feat += graph.ndata['h']
    >>>             return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)
    >>>
    >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
    >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
    >>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
    >>> net = Sequential()
    >>> net.add(ExampleLayer())
    >>> net.add(ExampleLayer())
    >>> net.add(ExampleLayer())
    >>> net.initialize()
    >>> n_feat = nd.random.randn(32, 4)
    >>> net([g1, g2, g3], n_feat)
    [[-101.289566  -22.584694  -89.25348  -151.6447  ]
     [-130.74239   -49.494812 -120.250854 -199.81546 ]
     [-112.32089   -50.036713 -116.13266  -190.38638 ]
     [-119.23065   -26.78553  -111.11185  -166.08322 ]]
    <NDArray 4x4 @cpu(0)>
    """
    def __init__(self, prefix=None, params=None):
        super(Sequential, self).__init__(prefix=prefix, params=params)

    def forward(self, graph, *feats):
        r"""Sequentially apply modules to the input.

        Parameters
        ----------
        graph : DGLGraph or list of DGLGraphs
            The graph(s) to apply modules on.
        *feats :
            Input features.
            The output of the :math:`i`-th block should match the input of
            the :math:`(i+1)`-th block.
        """
        if isinstance(graph, list):
            for graph_i, module in zip(graph, self):
                if not isinstance(feats, tuple):
                    feats = (feats,)
                feats = module(graph_i, *feats)
        elif isinstance(graph, DGLGraph):
            for module in self:
                if not isinstance(feats, tuple):
                    feats = (feats,)
                feats = module(graph, *feats)
        else:
            raise TypeError('The first argument of forward must be a DGLGraph'
                            ' or a list of DGLGraphs')
        return feats