# Source code for dgl.nn.pytorch.conv.relgraphconv

"""Torch Module for Relational graph convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import functools
import numpy as np
import torch as th
from torch import nn

from .... import function as fn
from .. import utils
from ....base import DGLError
from .... import edge_subgraph

class RelGraphConv(nn.Module):
    r"""Relational graph convolution layer.

    Relational graph convolution is introduced in "`Modeling Relational Data with Graph
    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
    and can be described in DGL as below:

    .. math::

       h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
       \sum_{j\in\mathcal{N}^r(i)}e_{j,i}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})

    where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation
    :math:`r`. :math:`e_{j,i}` is the normalizer. :math:`\sigma` is an activation
    function. :math:`W_0` is the self-loop weight.

    The basis regularization decomposes :math:`W_r` by:

    .. math::

       W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}

    where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined
    with coefficients :math:`a_{rb}^{(l)}`.

    The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`
    block diagonal matrices. We refer to :math:`B` as the number of bases.

    The block regularization decomposes :math:`W_r` by:

    .. math::

       W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)}

    where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block
    bases with shape :math:`\mathbb{R}^{(d^{(l+1)}/B)\times(d^{(l)}/B)}`.

    Parameters
    ----------
    in_feat : int
        Input feature size; i.e., the number of dimensions of :math:`h_j^{(l)}`.
    out_feat : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    num_rels : int
        Number of relations.
    regularizer : str, optional
        Which weight regularizer to use, "basis" or "bdd". Default: "basis".

        * "basis" is short for basis-decomposition.
        * "bdd" is short for block-diagonal-decomposition. It requires both
          ``in_feat`` and ``out_feat`` to be divisible by the number of bases.
    num_bases : int, optional
        Number of bases. If it is None, not positive, or larger than ``num_rels``,
        the number of relations is used instead. Default: None.
    bias : bool, optional
        True if bias is added. Default: True.
    activation : callable, optional
        Activation function. Default: None.
    self_loop : bool, optional
        True to include self loop message. Default: True.
    low_mem : bool, optional
        True to use a low-memory implementation of the relation message passing
        function. This option trades speed for memory consumption and will slow down
        the forward/backward passes. Turn it on when you encounter an OOM problem
        during training or evaluation. Default: False.
    dropout : float, optional
        Dropout rate. Default: 0.0
    layer_norm : bool, optional
        True to apply layer normalization to the aggregated messages (before bias,
        self-loop and activation are applied). Default: False

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import RelGraphConv
    >>>
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> feat = th.ones(6, 10)
    >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)
    >>> conv.weight.shape
    torch.Size([2, 10, 2])
    >>> etype = th.tensor(np.array([0,1,2,0,1,2]).astype(np.int64))
    >>> res = conv(g, feat, etype)
    >>> res.shape  # the values depend on the random weight initialization
    torch.Size([6, 2])

    >>> # One-hot input
    >>> one_hot_feat = th.tensor(np.array([0,1,2,3,4,5]).astype(np.int64))
    >>> res = conv(g, one_hot_feat, etype)
    >>> res.shape
    torch.Size([6, 2])
    """
    def __init__(self,
                 in_feat,
                 out_feat,
                 num_rels,
                 regularizer="basis",
                 num_bases=None,
                 bias=True,
                 activation=None,
                 self_loop=True,
                 low_mem=False,
                 dropout=0.0,
                 layer_norm=False):
        super().__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_rels = num_rels
        self.regularizer = regularizer
        self.num_bases = num_bases
        # Fall back to one basis per relation when num_bases is unset or out of range.
        if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
            self.num_bases = self.num_rels
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.low_mem = low_mem
        self.layer_norm = layer_norm

        if regularizer == "basis":
            # add basis weights
            self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
            if self.num_bases < self.num_rels:
                # linear combination coefficients
                self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
            nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
            if self.num_bases < self.num_rels:
                nn.init.xavier_uniform_(self.w_comp,
                                        gain=nn.init.calculate_gain('relu'))
            # message func
            self.message_func = self.basis_message_func
        elif regularizer == "bdd":
            if in_feat % self.num_bases != 0 or out_feat % self.num_bases != 0:
                raise ValueError(
                    'Feature size must be a multiplier of num_bases (%d).'
                    % self.num_bases
                )
            # add block diagonal weights
            self.submat_in = in_feat // self.num_bases
            self.submat_out = out_feat // self.num_bases

            # in_feat and out_feat are both divisible by num_bases (checked above);
            # each row stores the num_bases blocks of one relation, flattened.
            self.weight = nn.Parameter(th.Tensor(
                self.num_rels, self.num_bases * self.submat_in * self.submat_out))
            nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
            # message func
            self.message_func = self.bdd_message_func
        else:
            raise ValueError("Regularizer must be either 'basis' or 'bdd'")

        # bias
        if self.bias:
            self.h_bias = nn.Parameter(th.Tensor(out_feat))
            nn.init.zeros_(self.h_bias)

        # layer norm
        if self.layer_norm:
            self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight,
                                    gain=nn.init.calculate_gain('relu'))

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _etypes_to_tensor(etypes, device):
        """Return ``etypes`` as a per-edge type-ID tensor.

        A list of per-type edge counts is expanded into type IDs
        (e.g. ``[2, 1]`` -> ``tensor([0, 0, 1])``); any other value is returned unchanged.
        """
        if isinstance(etypes, list):
            # Explicit int64 so that an empty list does not become a float tensor.
            return th.repeat_interleave(th.arange(len(etypes), device=device),
                                        th.tensor(etypes, dtype=th.int64, device=device))
        return etypes

    def basis_message_func(self, edges, etypes):
        """Message function for basis regularizer.

        Parameters
        ----------
        edges : dgl.EdgeBatch
            Input to DGL message UDF.
        etypes : torch.Tensor or list[int]
            Edge type data. Could be either:

            * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
              Preferred format if ``low_mem == False``.
            * An integer list. The :math:`i^{th}` element is the number of edges of the
              :math:`i^{th}` type. This requires the input graph to store edges sorted by
              their type IDs. Required format if ``low_mem == True`` and the source
              features are not integer node IDs.

        Returns
        -------
        dict[str, torch.Tensor]
            ``{'msg': msg}`` with one ``out_feat``-sized message per edge, scaled by
            ``edges.data['norm']`` when present.
        """
        if self.num_bases < self.num_rels:
            # Materialize every relation's weight as a linear combination of the bases.
            weight = self.weight.view(self.num_bases,
                                      self.in_feat * self.out_feat)
            weight = th.matmul(self.w_comp, weight).view(
                self.num_rels, self.in_feat, self.out_feat)
        else:
            weight = self.weight

        h = edges.src['h']
        device = h.device

        if h.dtype == th.int64 and h.ndim == 1:
            # Each element is the node's ID, so the message is weight[etypes, h, :].
            # Gathering rows of a flattened view is faster than 2-D fancy indexing.
            etypes = self._etypes_to_tensor(etypes, device)
            idim = weight.shape[1]
            weight = weight.view(-1, weight.shape[2])
            flatidx = etypes * idim + h
            msg = weight.index_select(0, flatidx)
        elif self.low_mem:
            # A more memory-friendly implementation: multiply each relation's
            # contiguous slice of edges by its weight instead of gathering one
            # weight matrix per edge.
            assert isinstance(etypes, list)
            msg = [th.matmul(h_r, weight[rel])
                   for rel, h_r in enumerate(th.split(h, etypes))
                   if h_r.shape[0] > 0]
            # th.cat rejects an empty list (every relation had zero edges).
            msg = th.cat(msg) if msg else h.new_zeros((0, self.out_feat))
        else:
            # Use batched matmult
            etypes = self._etypes_to_tensor(etypes, device)
            weight = weight.index_select(0, etypes)
            msg = th.bmm(h.unsqueeze(1), weight).squeeze(1)

        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}

    def bdd_message_func(self, edges, etypes):
        """Message function for block-diagonal-decomposition regularizer.

        Parameters
        ----------
        edges : dgl.EdgeBatch
            Input to DGL message UDF.
        etypes : torch.Tensor or list[int]
            Edge type data. Could be either:

            * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
              Preferred format if ``low_mem == False``.
            * An integer list. The :math:`i^{th}` element is the number of edges of the
              :math:`i^{th}` type. This requires the input graph to store edges sorted by
              their type IDs. Required format if ``low_mem == True``.

        Returns
        -------
        dict[str, torch.Tensor]
            ``{'msg': msg}`` with one ``out_feat``-sized message per edge, scaled by
            ``edges.data['norm']`` when present.

        Raises
        ------
        TypeError
            If the source features are 1-D int64 node IDs.
        """
        h = edges.src['h']
        device = h.device

        if h.dtype == th.int64 and h.ndim == 1:
            raise TypeError('Block decomposition does not allow integer ID feature.')

        if self.low_mem:
            # A more memory-friendly implementation: process one relation at a time.
            assert isinstance(etypes, list)
            msg = []
            for rel, h_r in enumerate(th.split(h, etypes)):
                if h_r.shape[0] == 0:
                    continue
                tmp_w = self.weight[rel].view(self.num_bases, self.submat_in, self.submat_out)
                tmp_h = h_r.view(-1, self.num_bases, self.submat_in)
                # Multiply every input block by its matching diagonal block.
                msg.append(th.einsum('abc,bcd->abd', tmp_h, tmp_w).reshape(-1, self.out_feat))
            # th.cat rejects an empty list (every relation had zero edges).
            msg = th.cat(msg) if msg else h.new_zeros((0, self.out_feat))
        else:
            # Use batched matmult: one (submat_in x submat_out) block per (edge, basis).
            etypes = self._etypes_to_tensor(etypes, device)
            weight = self.weight.index_select(0, etypes).view(
                -1, self.submat_in, self.submat_out)
            node = h.view(-1, 1, self.submat_in)
            msg = th.bmm(node, weight).view(-1, self.out_feat)
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}

    def forward(self, g, feat, etypes, norm=None):
        """Forward computation.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        feat : torch.Tensor
            Input node features. Could be either

            * :math:`(|V|, D)` dense tensor
            * :math:`(|V|,)` int64 vector, representing the categorical values of each
              node. The input is then treated as a one-hot encoding feature.
        etypes : torch.Tensor or list[int]
            Edge type data. Could be either

            * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
              Preferred format if ``low_mem == False``.
            * An integer list. The :math:`i^{th}` element is the number of edges of the
              :math:`i^{th}` type. This requires the input graph to store edges sorted by
              their type IDs. Preferred format if ``low_mem == True``.
        norm : torch.Tensor, optional
            Edge normalizer. An :math:`(|E|, 1)` tensor storing the normalizer on each edge.

        Returns
        -------
        torch.Tensor
            New node features.

        Raises
        ------
        DGLError
            If a tensor ``etypes`` does not have one entry per edge, if the counts in a
            list ``etypes`` do not sum to the number of edges, or (in low-memory mode)
            if an edge type ID lies outside ``[0, num_rels)``.

        Notes
        -----
        Under the low_mem mode, when ``etypes`` is a tensor and ``feat`` is not an
        integer ID vector, DGL sorts the graph based on the edge types and computes
        message passing one type at a time. DGL recommends sorting the graph beforehand
        (and caching it if possible) and providing the integer list format to the
        ``etypes`` argument. Use DGL's :func:`~dgl.to_homogeneous` API to get a sorted
        homogeneous graph from a heterogeneous graph. Pass ``return_count=True`` to it
        to get the ``etypes`` as an integer list.
        """
        if isinstance(etypes, th.Tensor):
            if len(etypes) != g.num_edges():
                raise DGLError('"etypes" tensor must have length equal to the number of edges'
                               ' in the graph. But got {} and {}.'.format(
                                   len(etypes), g.num_edges()))
            if self.low_mem and not (feat.dtype == th.int64 and feat.ndim == 1):
                # Low-mem optimization is not enabled for node ID input. When enabled,
                # it first sorts the graph based on the edge types (the sorting will not
                # change the node IDs). It then converts the etypes tensor to an integer
                # list, where each element is the number of edges of the type.
                sorted_etypes, index = th.sort(etypes)
                if sorted_etypes.numel() > 0:
                    # Out-of-range IDs would otherwise be silently merged into the
                    # last relation by the searchsorted-based counting below.
                    lo, hi = int(sorted_etypes[0]), int(sorted_etypes[-1])
                    if lo < 0 or hi >= self.num_rels:
                        raise DGLError('"etypes" must contain edge type IDs in [0, {}).'
                                       ' But got values in [{}, {}].'.format(
                                           self.num_rels, lo, hi))
                g = edge_subgraph(g, index, relabel_nodes=False)
                # Create a new etypes to be an integer list of number of edges.
                # Build the probes on the etypes' own device/dtype for searchsorted.
                etype_device = sorted_etypes.device
                pos = _searchsorted(sorted_etypes,
                                    th.arange(self.num_rels, device=etype_device,
                                              dtype=sorted_etypes.dtype))
                num = th.tensor([len(etypes)], dtype=pos.dtype, device=pos.device)
                etypes = (th.cat([pos[1:], num]) - pos).tolist()
                if norm is not None:
                    norm = norm[index]
        elif sum(etypes) != g.num_edges():
            # A mismatched count list can otherwise broadcast into wrong messages.
            raise DGLError('The edge counts in "etypes" must sum to the number of edges'
                           ' in the graph. But got {} and {}.'.format(
                               sum(etypes), g.num_edges()))

        with g.local_scope():
            g.srcdata['h'] = feat
            if norm is not None:
                g.edata['norm'] = norm
            if self.self_loop:
                loop_message = utils.matmul_maybe_select(feat[:g.number_of_dst_nodes()],
                                                         self.loop_weight)
            # message passing
            g.update_all(functools.partial(self.message_func, etypes=etypes),
                         fn.sum(msg='msg', out='h'))
            # apply layer norm, bias, self loop, activation and dropout (in this order)
            node_repr = g.dstdata['h']
            if self.layer_norm:
                node_repr = self.layer_norm_weight(node_repr)
            if self.bias:
                node_repr = node_repr + self.h_bias
            if self.self_loop:
                node_repr = node_repr + loop_message
            if self.activation:
                node_repr = self.activation(node_repr)
            node_repr = self.dropout(node_repr)
            return node_repr

_TORCH_HAS_SEARCHSORTED = getattr(th, 'searchsorted', None)


def _searchsorted(sorted_sequence, values):
    """Return, for each element of ``values``, its left insertion index in ``sorted_sequence``.

    Uses ``torch.searchsorted`` when the installed PyTorch provides it (1.6.0+);
    otherwise computes the result with NumPy on the CPU and moves it back to the
    device of ``values``.
    """
    if not _TORCH_HAS_SEARCHSORTED:
        target_device = values.device
        positions = np.searchsorted(sorted_sequence.cpu().numpy(),
                                    values.cpu().numpy())
        return th.from_numpy(positions).to(target_device)
    return th.searchsorted(sorted_sequence, values)