# Source code for dgl.nn.pytorch.utils

"""Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name

import torch as th
from torch import nn
import torch.nn.functional as F
from ... import DGLGraph
from ...base import dgl_warning
from ... import function as fn

def matmul_maybe_select(A, B):
    """Compute ``C = A @ B``, treating an integer id vector ``A`` as one-hot rows.

    When ``A`` is a 1-D ``int64`` tensor, entry ``A[i]`` is read as the column
    of the single ``1`` in row ``i`` of an implicit one-hot matrix. Multiplying
    that one-hot matrix by ``B`` just picks rows of ``B``, so the dense matmul
    is replaced by a much cheaper row lookup.

    For example,
    ::

        A = [2, 0, 1],
        B = [[0.1, 0.2],
             [0.3, 0.4],
             [0.5, 0.6]]

    then ``matmul_maybe_select(A, B)`` is equivalent to
    ::

        [[0, 0, 1],     [[0.1, 0.2],
         [1, 0, 0],  *   [0.3, 0.4],
         [0, 1, 0]]      [0.5, 0.6]]

    Any other ``A`` (a different dtype, or not 1-D) goes through ``torch.matmul``.

    Parameters
    ----------
    A : torch.Tensor
        Left-hand operand, or a 1-D ``int64`` vector of row ids into ``B``.
    B : torch.Tensor
        Right-hand operand.

    Returns
    -------
    C : torch.Tensor
        The matrix product, or for an id vector the selected rows of ``B``.
    """
    selects_rows = A.dtype == th.int64 and A.dim() == 1
    if not selects_rows:
        return th.matmul(A, B)
    # Row i of the implicit one-hot matrix picks row A[i] of B.
    return B.index_select(0, A)

def bmm_maybe_select(A, B, index):
    """Slice submatrices of ``B`` by ``index`` and multiply each by a row of ``A``.

    ``B`` is a 3-D tensor of shape ``(N, D1, D2)``, i.e. a stack of ``N``
    matrices of shape ``(D1, D2)``. ``index`` is an integer vector of length
    ``M`` choosing one of those matrices per output row. ``A`` is either

    (1) a dense tensor of shape ``(M, D1)``, or
    (2) an ``int64`` vector of length ``M`` (row ids, i.e. implicit one-hot rows).

    The result ``C`` has shape ``(M, D2)``.

    For case (1), ``C`` is computed by a batched matmul:
    ::

        C[i, :] = matmul(A[i, :], B[index[i], :, :])

    For case (2), ``C`` is computed by index selection:
    ::

        C[i, :] = B[index[i], A[i], :]

    Parameters
    ----------
    A : torch.Tensor
        Dense ``(M, D1)`` tensor, or 1-D ``int64`` vector of length ``M``.
    B : torch.Tensor
        ``(N, D1, D2)`` stack of matrices.
    index : torch.Tensor
        Integer vector of length ``M``; ``index[i]`` selects ``B[index[i]]``.

    Returns
    -------
    C : torch.Tensor
        Result tensor of shape ``(M, D2)``.
    """
    if A.dtype == th.int64 and A.dim() == 1:
        # Faster equivalent of B[index, A, :]. Flatten the first two axes of B
        # into (N * D1) rows of length D2: row A[i] of matrix index[i] then
        # sits at flat row index[i] * D1 + A[i]. D1 must be read *before*
        # flattening (afterwards shape[1] is D2). reshape (not view) also
        # accepts a non-contiguous B.
        num_rows = B.shape[1]
        flat_B = B.reshape(-1, B.shape[2])
        flatidx = index * num_rows + A
        return flat_B.index_select(0, flatidx)
    BB = B.index_select(0, index)  # (M, D1, D2)
    # Squeeze only the axis added by unsqueeze so that M == 1 or D2 == 1
    # still yields an (M, D2) result.
    return th.bmm(A.unsqueeze(1), BB).squeeze(1)

# pylint: disable=W0235
class Identity(nn.Module):
    """A placeholder operator that hands back its input unchanged.

    PyTorch 1.2 already provides ``torch.nn.Identity``; this local copy is
    kept until the code switches to importing that directly.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` as-is, without copying."""
        return x

class Sequential(nn.Sequential):
    r"""A sequential container for stacking graph neural network modules.

    DGL supports two modes: sequentially apply GNN modules on 1) the same graph or
    2) a list of given graphs. In the second case, the number of graphs equals the
    number of modules inside this container.

    Parameters
    ----------
    *args :
        Sub-modules of torch.nn.Module that will be added to the container in
        the order by which they are passed in the constructor.

    Examples
    --------
    The following example uses PyTorch backend.

    Mode 1: sequentially apply GNN modules on the same graph

    >>> import torch
    >>> import dgl
    >>> import torch.nn as nn
    >>> import dgl.function as fn
    >>> from dgl.nn.pytorch import Sequential
    >>> class ExampleLayer(nn.Module):
    >>>     def __init__(self):
    >>>         super().__init__()
    >>>     def forward(self, graph, n_feat, e_feat):
    >>>         with graph.local_scope():
    >>>             graph.ndata['h'] = n_feat
    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    >>>             n_feat += graph.ndata['h']
    >>>             graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
    >>>             e_feat += graph.edata['e']
    >>>             return n_feat, e_feat
    >>>
    >>> g = dgl.DGLGraph()
    >>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    >>> n_feat = torch.rand(3, 4)
    >>> e_feat = torch.rand(9, 4)
    >>> net(g, n_feat, e_feat)
    (tensor([[39.8597, 45.4542, 25.1877, 30.8086],
             [40.7095, 45.3985, 25.4590, 30.0134],
             [40.7894, 45.2556, 25.5221, 30.4220]]),
     tensor([[80.3772, 89.7752, 50.7762, 60.5520],
             [80.5671, 89.3736, 50.6558, 60.6418],
             [80.4620, 89.5142, 50.3643, 60.3126],
             [80.4817, 89.8549, 50.9430, 59.9108],
             [80.2284, 89.6954, 50.0448, 60.1139],
             [79.7846, 89.6882, 50.5097, 60.6213],
             [80.2654, 90.2330, 50.2787, 60.6937],
             [80.3468, 90.0341, 50.2062, 60.2659],
             [80.0556, 90.2789, 50.2882, 60.5845]]))

    Mode 2: sequentially apply GNN modules on different graphs

    >>> import torch
    >>> import dgl
    >>> import torch.nn as nn
    >>> import dgl.function as fn
    >>> import networkx as nx
    >>> from dgl.nn.pytorch import Sequential
    >>> class ExampleLayer(nn.Module):
    >>>     def __init__(self):
    >>>         super().__init__()
    >>>     def forward(self, graph, n_feat):
    >>>         with graph.local_scope():
    >>>             graph.ndata['h'] = n_feat
    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    >>>             n_feat += graph.ndata['h']
    >>>             return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
    >>>
    >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
    >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
    >>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
    >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    >>> n_feat = torch.rand(32, 4)
    >>> net([g1, g2, g3], n_feat)
    tensor([[209.6221, 225.5312, 193.8920, 220.1002],
            [250.0169, 271.9156, 240.2467, 267.7766],
            [220.4007, 239.7365, 213.8648, 234.9637],
            [196.4630, 207.6319, 184.2927, 208.7465]])
    """

    def __init__(self, *args):
        super(Sequential, self).__init__(*args)

    def forward(self, graph, *feats):
        r"""Sequentially apply modules to the input.

        Parameters
        ----------
        graph : DGLGraph or list of DGLGraphs
            The graph(s) to apply modules on. A single graph is passed to every
            module; with a list, the i-th module receives the i-th graph. The
            list and the modules are paired with ``zip``, so surplus graphs
            (or surplus modules) are skipped.
        *feats :
            Input features.
            The output of the :math:`i`-th module should match the input
            of the :math:`(i+1)`-th module in the sequential.

        Returns
        -------
        object
            The output of the last module that ran. A tuple output is unpacked
            into positional arguments for the next module; any other output is
            passed as the single feature argument. If no module runs, the
            ``feats`` tuple is returned unchanged.

        Raises
        ------
        TypeError
            If ``graph`` is neither a list nor a DGLGraph.
        """
        if isinstance(graph, list):
            graphs = graph
        elif isinstance(graph, DGLGraph):
            # Every module works on the same graph.
            graphs = [graph] * len(self)
        else:
            raise TypeError('The first argument of forward must be a DGLGraph'
                            ' or a list of DGLGraph s')

        out = feats
        for graph_i, layer in zip(graphs, self):
            args = out if isinstance(out, tuple) else (out,)
            out = layer(graph_i, *args)
        return out

class WeightBasis(nn.Module):
    r"""Basis decomposition from `Modeling Relational Data with Graph
    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__

    It can be described as below:

    .. math::

        W_o = \sum_{b=1}^B a_{ob} V_b

    Each weight output :math:`W_o` is essentially a linear combination of basis
    transformations :math:`V_b` with coefficients :math:`a_{ob}`.

    It is useful as a form of regularization on a large parameter matrix. Thus,
    the number of weight outputs is usually larger than the number of bases.

    Parameters
    ----------
    shape : tuple[int]
        Shape of the basis parameter.
    num_bases : int
        Number of bases.
    num_outputs : int
        Number of outputs.
    """

    def __init__(self,
                 shape,
                 num_bases,
                 num_outputs):
        super(WeightBasis, self).__init__()
        self.shape = shape
        self.num_bases = num_bases
        self.num_outputs = num_outputs

        if num_bases >= num_outputs:
            dgl_warning('The number of weight outputs should be larger than the number'
                        ' of bases.')

        relu_gain = nn.init.calculate_gain('relu')
        # Bases V_b stacked along the first axis: (num_bases, *shape).
        self.weight = nn.Parameter(th.Tensor(self.num_bases, *shape))
        nn.init.xavier_uniform_(self.weight, gain=relu_gain)
        # Combination coefficients a_{ob}: (num_outputs, num_bases).
        self.w_comp = nn.Parameter(th.Tensor(self.num_outputs, self.num_bases))
        nn.init.xavier_uniform_(self.w_comp, gain=relu_gain)

    def forward(self):
        r"""Compose every output weight from the shared bases.

        Returns
        -------
        weight : torch.Tensor
            Composed weight tensor of shape ``(num_outputs,) + shape``.
        """
        # Flatten each basis into one row so all outputs come from a single matmul.
        flat_bases = self.weight.view(self.num_bases, -1)
        composed = self.w_comp @ flat_bases
        return composed.view(self.num_outputs, *self.shape)

class JumpingKnowledge(nn.Module):
    r"""The Jumping Knowledge aggregation module from `Representation Learning on
    Graphs with Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__

    It aggregates the output representations of multiple GNN layers with

    **concatenation**

    .. math::

        h_i^{(1)} \, \Vert \, \ldots \, \Vert \, h_i^{(T)}

    or **max pooling**

    .. math::

        \max \left( h_i^{(1)}, \ldots, h_i^{(T)} \right)

    or **LSTM**

    .. math::

        \sum_{t=1}^T \alpha_i^{(t)} h_i^{(t)}

    with attention scores :math:`\alpha_i^{(t)}` obtained from a BiLSTM

    Parameters
    ----------
    mode : str
        The aggregation to apply. It can be 'cat', 'max', or 'lstm',
        corresponding to the equations above in order.
    in_feats : int, optional
        This argument is only required if :attr:`mode` is ``'lstm'``.
        The output representation size of a single GNN layer. Note that
        all GNN layers need to have the same output representation size.
    num_layers : int, optional
        This argument is only required if :attr:`mode` is ``'lstm'``.
        The number of GNN layers for output aggregation.

    Examples
    --------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import JumpingKnowledge

    >>> # Output representations of two GNN layers
    >>> num_nodes = 3
    >>> in_feats = 4
    >>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]

    >>> # Case1
    >>> model = JumpingKnowledge()
    >>> model(feat_list).shape
    torch.Size([3, 8])

    >>> # Case2
    >>> model = JumpingKnowledge(mode='max')
    >>> model(feat_list).shape
    torch.Size([3, 4])

    >>> # Case3
    >>> model = JumpingKnowledge(mode='lstm', in_feats=in_feats, num_layers=len(feat_list))
    >>> model(feat_list).shape
    torch.Size([3, 4])
    """
    def __init__(self, mode='cat', in_feats=None, num_layers=None):
        super(JumpingKnowledge, self).__init__()
        assert mode in ['cat', 'max', 'lstm'], \
            "Expect mode to be 'cat', or 'max' or 'lstm', got {}".format(mode)
        self.mode = mode

        if mode == 'lstm':
            assert in_feats is not None, 'in_feats is required for lstm mode'
            assert num_layers is not None, 'num_layers is required for lstm mode'
            # Each LSTM direction gets half of num_layers * in_feats; the
            # bidirectional output (2 * hidden_size features per step) is
            # scored to a single attention logit by self.att.
            hidden_size = (num_layers * in_feats) // 2
            self.lstm = nn.LSTM(in_feats, hidden_size, bidirectional=True, batch_first=True)
            self.att = nn.Linear(2 * hidden_size, 1)

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters. This comes into effect only for the lstm mode.
        """
        if self.mode == 'lstm':
            self.lstm.reset_parameters()
            self.att.reset_parameters()

    def forward(self, feat_list):
        r"""

        Description
        -----------
        Aggregate output representations across multiple GNN layers.

        Parameters
        ----------
        feat_list : list[Tensor]
            feat_list[i] is the output representations of a GNN layer.

        Returns
        -------
        Tensor
            The aggregated representations. For ``'cat'`` the inputs are joined
            along the last dimension; for ``'max'`` and ``'lstm'`` the result has
            the shape of a single entry of ``feat_list``.
        """
        if self.mode == 'cat':
            return th.cat(feat_list, dim=-1)
        if self.mode == 'max':
            # Tensor.max(dim=...) returns a (values, indices) pair; keep only
            # the values so callers receive a tensor.
            return th.stack(feat_list, dim=-1).max(dim=-1)[0]
        # LSTM attention. Shapes below assume each entry is (N, in_feats).
        stacked_feat_list = th.stack(feat_list, dim=1)  # (N, num_layers, in_feats)
        alpha, _ = self.lstm(stacked_feat_list)
        alpha = self.att(alpha).squeeze(-1)             # (N, num_layers)
        alpha = th.softmax(alpha, dim=-1)
        return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)

[docs]class LabelPropagation(nn.Module):
r"""Label Propagation from Learning from Labeled and Unlabeled Data with Label
Propagation <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>__

.. math::

\mathbf{Y}^{(t+1)} = \alpha \tilde{A} \mathbf{Y}^{(t)} + (1 - \alpha) \mathbf{Y}^{(0)}

where unlabeled data is initially set to zero and inferred from labeled data via
propagation. :math:\alpha is a weight parameter for balancing between updated labels
and initial labels. :math:\tilde{A} denotes the normalized adjacency matrix.

Parameters
----------
k: int
The number of propagation steps.
alpha : float
The :math:\alpha coefficient in range [0, 1].
norm_type : str, optional
The type of normalization applied to the adjacency matrix, must be one of the
following choices:

* row: row-normalized adjacency as :math:D^{-1}A

* sym: symmetrically normalized adjacency as :math:D^{-1/2}AD^{-1/2}

Default: 'sym'.
clamp : bool, optional
A bool flag to indicate whether to clamp the labels to [0, 1] after propagation.
Default: True.
normalize: bool, optional
A bool flag to indicate whether to apply row-normalization after propagation.
Default: False.
reset : bool, optional
A bool flag to indicate whether to reset the known labels after each
propagation step. Default: False.

Examples
--------
>>> import torch
>>> import dgl
>>> from dgl.nn import LabelPropagation

>>> label_propagation = LabelPropagation(k=5, alpha=0.5, clamp=False, normalize=True)
>>> g = dgl.rand_graph(5, 10)
>>> labels = torch.tensor([0, 2, 1, 3, 0]).long()
>>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()
>>> new_labels = label_propagation(g, labels, mask)
"""
def __init__(self, k, alpha, norm_type='sym', clamp=True, normalize=False, reset=False):
super(LabelPropagation, self).__init__()
self.k = k
self.alpha = alpha
self.norm_type = norm_type
self.clamp = clamp
self.normalize = normalize
self.reset = reset

[docs]    def forward(self, g, labels, mask=None):
r"""Compute the label propagation process.

Parameters
----------
g : DGLGraph
The input graph.
labels : torch.Tensor
The input node labels. There are three cases supported.

* A LongTensor of shape :math:(N, 1) or :math:(N,) for node class labels in
multiclass classification, where :math:N is the number of nodes.
* A LongTensor of shape :math:(N, C) for one-hot encoding of node class labels
in multiclass classification, where :math:C is the number of classes.
* A LongTensor of shape :math:(N, L) for node labels in multilabel binary
classification, where :math:L is the number of labels.
The bool indicators of shape :math:(N,) with True denoting labeled nodes.
Default: None, indicating all nodes are labeled.

Returns
-------
torch.Tensor
The propagated node labels of shape :math:(N, D) with float type, where :math:D
is the number of classes or labels.
"""
with g.local_scope():
# multi-label / multi-class
if len(labels.size()) > 1 and labels.size(1) > 1:
labels = labels.to(th.float32)
# single-label multi-class
else:
labels = F.one_hot(labels.view(-1)).to(th.float32)

y = labels
y = th.zeros_like(labels)

init = (1 - self.alpha) * y
in_degs = g.in_degrees().float().clamp(min=1)
out_degs = g.out_degrees().float().clamp(min=1)
if self.norm_type == 'sym':
norm_i = th.pow(in_degs, -0.5).to(labels.device).unsqueeze(1)
norm_j = th.pow(out_degs, -0.5).to(labels.device).unsqueeze(1)
elif self.norm_type == 'row':
norm_i = th.pow(in_degs, -1.).to(labels.device).unsqueeze(1)
else:
raise ValueError(f"Expect norm_type to be 'sym' or 'row', got {self.norm_type}")

for _ in range(self.k):
g.ndata['h'] = y * norm_j if self.norm_type == 'sym' else y
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
y = init + self.alpha * g.ndata['h'] * norm_i

if self.clamp:
y = y.clamp_(0., 1.)
if self.normalize:
y = F.normalize(y, p=1)
if self.reset: