NN Modules (PyTorch)¶
Conv Layers¶
Torch modules for graph convolutions.
GraphConv¶
class dgl.nn.pytorch.conv.GraphConv(in_feats, out_feats, norm='both', weight=True, bias=True, activation=None, allow_zero_in_degree=False)[source]¶
Bases: torch.nn.modules.module.Module
Graph convolution was introduced in GCN and mathematically is defined as follows:
\[h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)})\]where \(\mathcal{N}(i)\) is the set of neighbors of node \(i\), \(c_{ij}\) is the product of the square root of node degrees (i.e., \(c_{ij} = \sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}\)), and \(\sigma\) is an activation function.
- Parameters
in_feats (int) – Input feature size; i.e., the number of dimensions of \(h_j^{(l)}\).
out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
norm (str, optional) – How to apply the normalizer. If 'right', divide the aggregated messages by each node's in-degree, which is equivalent to averaging the received messages. If 'none', no normalization is applied. Default is 'both', where the \(c_{ij}\) in the paper is applied.
weight (bool, optional) – If True, apply a linear layer. Otherwise, aggregate the messages without a weight matrix.
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.
activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default: None.
allow_zero_in_degree (bool, optional) – If there are 0-in-degree nodes in the graph, output for those nodes will be invalid since no message will be passed to them. This is harmful for some applications, causing silent performance regression. This module will raise a DGLError if it detects 0-in-degree nodes in the input graph. Setting this to True suppresses the check and lets the users handle it by themselves. Default: False.
weight¶
The learnable weight tensor.
- Type
torch.Tensor
bias¶
The learnable bias tensor.
- Type
torch.Tensor
Note
Zero in-degree nodes will lead to invalid output value. This is because no message will be passed to those nodes, so the aggregation function will be applied on empty input. A common practice to avoid this is to add a self-loop for each node in the graph if it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph >>> g = dgl.add_self_loop(g)
Calling add_self_loop will not work for some graphs, for example, heterogeneous graphs, since the edge type cannot be decided for self-loop edges. Set allow_zero_in_degree to True for those cases to unblock the code and handle zero-in-degree nodes manually. A common practice is to filter out the zero-in-degree nodes after the conv is applied (see the sketch after the first example below).
Examples
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import GraphConv
>>> # Case 1: Homogeneous graph >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> conv = GraphConv(10, 2, norm='both', weight=True, bias=True) >>> res = conv(g, feat) >>> print(res) tensor([[ 1.3326, -0.2797], [ 1.4673, -0.3080], [ 1.3326, -0.2797], [ 1.6871, -0.3541], [ 1.7711, -0.3717], [ 1.0375, -0.2178]], grad_fn=<AddBackward0>) >>> # allow_zero_in_degree example >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> conv = GraphConv(10, 2, norm='both', weight=True, bias=True, allow_zero_in_degree=True) >>> res = conv(g, feat) >>> print(res) tensor([[-0.2473, -0.4631], [-0.3497, -0.6549], [-0.3497, -0.6549], [-0.4221, -0.7905], [-0.3497, -0.6549], [ 0.0000, 0.0000]], grad_fn=<AddBackward0>)
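With allow_zero_in_degree=True, as in the second call above, the rows for zero-in-degree nodes (node 5 here) carry no aggregated messages. One possible way to drop them afterwards, as the note above suggests (a sketch, not the only option):
>>> valid_idx = g.in_degrees() > 0
>>> res = res[valid_idx]  # keep only nodes that actually received messages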
>>> # Case 2: Unidirectional bipartite graph >>> u = [0, 1, 0, 0, 1] >>> v = [0, 1, 2, 3, 2] >>> g = dgl.bipartite((u, v)) >>> u_fea = th.rand(2, 5) >>> v_fea = th.rand(4, 5) >>> conv = GraphConv(5, 2, norm='both', weight=True, bias=True) >>> res = conv(g, (u_fea, v_fea)) >>> res tensor([[-0.2994, 0.6106], [-0.4482, 0.5540], [-0.5287, 0.8235], [-0.2994, 0.6106]], grad_fn=<AddBackward0>)
forward(graph, feat, weight=None)[source]¶
Compute graph convolution.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor or pair of torch.Tensor) – If a torch.Tensor is given, it represents the input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes. If a pair of torch.Tensor is given, which is the case for bipartite graph, the pair must contain two tensors of shape \((N_{in}, D_{in_{src}})\) and \((N_{out}, D_{in_{dst}})\).
weight (torch.Tensor, optional) – Optional external weight tensor.
- Returns
The output feature
- Return type
torch.Tensor
- Raises
DGLError – Case 1: If there are 0-in-degree nodes in the input graph, a DGLError is raised since no message will be passed to those nodes, which would produce invalid output. The error can be ignored by setting the allow_zero_in_degree parameter to True. Case 2: An external weight is provided while the module has defined its own weight parameter.
Note
Input shape: \((N, *, \text{in_feats})\) where * means any number of additional dimensions, \(N\) is the number of nodes.
Output shape: \((N, *, \text{out_feats})\) where all but the last dimension are the same shape as the input.
Weight shape: \((\text{in_feats}, \text{out_feats})\).
reset_parameters()[source]¶
Reinitialize learnable parameters.
Note
The model parameters are initialized as in the original implementation where the weight \(W^{(l)}\) is initialized using Glorot uniform initialization and the bias is initialized to be zero.
RelGraphConv¶
class dgl.nn.pytorch.conv.RelGraphConv(in_feat, out_feat, num_rels, regularizer='basis', num_bases=None, bias=True, activation=None, self_loop=True, low_mem=False, dropout=0.0, layer_norm=False)[source]¶
Bases: torch.nn.modules.module.Module
Relational graph convolution layer.
Relational graph convolution is introduced in “Modeling Relational Data with Graph Convolutional Networks” and can be described as follows:
\[h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}} \sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})\]where \(\mathcal{N}^r(i)\) is the neighbor set of node \(i\) w.r.t. relation \(r\). \(c_{i,r}\) is the normalizer equal to \(|\mathcal{N}^r(i)|\). \(\sigma\) is an activation function. \(W_0\) is the self-loop weight.
The basis regularization decomposes \(W_r\) by:
\[W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}\]where \(B\) is the number of bases, \(V_b^{(l)}\) are linearly combined with coefficients \(a_{rb}^{(l)}\).
The block-diagonal-decomposition regularization decomposes \(W_r\) into \(B\) block-diagonal matrices. We refer to \(B\) as the number of bases.
The block regularization decomposes \(W_r\) by:
\[W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)}\]where \(B\) is the number of bases, \(Q_{rb}^{(l)}\) are block bases with shape \(R^{(d^{(l+1)}/B)*(d^{l}/B)}\).
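As a rough illustration with plain tensors (not the module's internal code), the basis decomposition builds every relation's weight from a small set of shared bases; the sizes below are hypothetical:
>>> import torch as th
>>> num_rels, B, d_in, d_out = 3, 2, 10, 2
>>> V = th.randn(B, d_in, d_out)          # shared bases V_b
>>> a = th.randn(num_rels, B)             # per-relation coefficients a_rb
>>> W = th.einsum('rb,bio->rio', a, V)    # per-relation weights W_r
>>> W.shape
torch.Size([3, 10, 2])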
- Parameters
in_feat (int) – Input feature size; i.e., the number of dimensions of \(h_j^{(l)}\).
out_feat (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
num_rels (int) – Number of relations.
regularizer (str) – Which weight regularizer to use, "basis" or "bdd". "basis" is short for basis decomposition; "bdd" is short for block-diagonal decomposition.
num_bases (int, optional) – Number of bases. If None, uses the number of relations. Default: None.
bias (bool, optional) – True if bias is added. Default: True.
activation (callable, optional) – Activation function. Default: None.
self_loop (bool, optional) – True to include self-loop messages. Default: True.
low_mem (bool, optional) – True to use a low-memory implementation of the relation message passing function. This option trades speed for memory consumption and will slow down the forward/backward pass. Turn it on when you encounter OOM problems during training or evaluation. Default: False.
dropout (float, optional) – Dropout rate. Default: 0.0.
layer_norm (bool, optional) – If True, adds layer normalization. Default: False.
Examples
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import RelGraphConv >>> >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 10) >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2) >>> conv.weight.shape torch.Size([2, 10, 2]) >>> etype = th.tensor(np.array([0,1,2,0,1,2]).astype(np.int64)) >>> res = conv(g, feat, etype) >>> res tensor([[ 0.3996, -2.3303], [-0.4323, -0.1440], [ 0.3996, -2.3303], [ 2.1046, -2.8654], [-0.4323, -0.1440], [-0.1309, -1.0000]], grad_fn=<AddBackward0>)
>>> # One-hot input >>> one_hot_feat = th.tensor(np.array([0,1,2,3,4,5]).astype(np.int64)) >>> res = conv(g, one_hot_feat, etype) >>> res tensor([[ 0.5925, 0.0985], [-0.3953, 0.8408], [-0.9819, 0.5284], [-1.0085, -0.1721], [ 0.5962, 1.2002], [ 0.0365, -0.3532]], grad_fn=<AddBackward0>)
forward(g, feat, etypes, norm=None)[source]¶
Forward computation.
- Parameters
g (DGLGraph) – The graph.
feat (torch.Tensor) – Input node features. Could be either:
a \((|V|, D)\) dense tensor, or
a \((|V|,)\) int64 vector representing the categorical value of each node, in which case the input feature is treated as a one-hot encoding.
etypes (torch.Tensor) – Edge type tensor. Shape: \((|E|,)\)
norm (torch.Tensor, optional) – Optional edge normalizer tensor. Shape: \((|E|, 1)\). See the sketch after this method for one way to construct it.
- Returns
New node features.
- Return type
torch.Tensor
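The norm argument of forward() can hold per-edge coefficients such as \(1/c_{i,r}\). A minimal sketch reusing g, feat, etype and conv from the example above; note it uses the overall in-degree rather than the per-relation count in the formula, so it is only one possible choice:
>>> _, dst = g.edges()
>>> deg = g.in_degrees(dst).float().clamp(min=1)
>>> norm = (1.0 / deg).unsqueeze(1)       # shape (E, 1)
>>> res = conv(g, feat, etype, norm=norm)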
TAGConv¶
class dgl.nn.pytorch.conv.TAGConv(in_feats, out_feats, k=2, bias=True, activation=None)[source]¶
Bases: torch.nn.modules.module.Module
Topology Adaptive Graph Convolutional layer from paper Topology Adaptive Graph Convolutional Networks.
\[H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k},\]where \(A\) denotes the adjacency matrix, \(D_{ii} = \sum_{j=0} A_{ij}\) its diagonal degree matrix, \({\Theta}_{k}\) denotes the linear weights to sum the results of different hops together.
- Parameters
in_feats (int) – Input feature size; i.e., the number of dimensions of \(X\).
out_feats (int) – Output feature size; i.e., the number of dimensions of \(H^{K}\).
k (int, optional) – Number of hops \(K\). Default: 2.
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.
activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default: None.
lin¶
The learnable linear module.
- Type
torch.nn.Module
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import TAGConv >>> >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 10) >>> conv = TAGConv(10, 2, k=2) >>> res = conv(g, feat) >>> res tensor([[ 0.5490, -1.6373], [ 0.5490, -1.6373], [ 0.5490, -1.6373], [ 0.5513, -1.8208], [ 0.5215, -1.6044], [ 0.3304, -1.9927]], grad_fn=<AddmmBackward>)
forward(graph, feat)[source]¶
Compute topology adaptive graph convolution.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
GATConv¶
class dgl.nn.pytorch.conv.GATConv(in_feats, out_feats, num_heads, feat_drop=0.0, attn_drop=0.0, negative_slope=0.2, residual=False, activation=None, allow_zero_in_degree=False)[source]¶
Bases: torch.nn.modules.module.Module
Apply Graph Attention Network over an input signal.
\[h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)}\]where \(\alpha_{ij}\) is the attention score between node \(i\) and node \(j\):
\[ \begin{align}\begin{aligned}\alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})\\e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right)\end{aligned}\end{align} \]- Parameters
in_feats (int, or pair of ints) – Input feature size; i.e., the number of dimensions of \(h_i^{(l)}\). GATConv can be applied on homogeneous graphs and unidirectional bipartite graphs. If the layer is to be applied to a unidirectional bipartite graph, in_feats specifies the input feature size on both the source and destination nodes. If a scalar is given, the source and destination node feature sizes take the same value.
out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
num_heads (int) – Number of heads in Multi-Head Attention.
feat_drop (float, optional) – Dropout rate on features. Default: 0.
attn_drop (float, optional) – Dropout rate on attention weights. Default: 0.
negative_slope (float, optional) – LeakyReLU angle of negative slope. Default: 0.2.
residual (bool, optional) – If True, use residual connection. Default: False.
activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default: None.
allow_zero_in_degree (bool, optional) – If there are 0-in-degree nodes in the graph, output for those nodes will be invalid since no message will be passed to them. This is harmful for some applications, causing silent performance regression. This module will raise a DGLError if it detects 0-in-degree nodes in the input graph. Setting this to True suppresses the check and lets the users handle it by themselves. Default: False.
Note
Zero in-degree nodes will lead to invalid output value. This is because no message will be passed to those nodes, so the aggregation function will be applied on empty input. A common practice to avoid this is to add a self-loop for each node in the graph if it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph >>> g = dgl.add_self_loop(g)
Calling add_self_loop will not work for some graphs, for example, heterogeneous graphs, since the edge type cannot be decided for self-loop edges. Set allow_zero_in_degree to True for those cases to unblock the code and handle zero-in-degree nodes manually. A common practice is to filter out the zero-in-degree nodes after the conv is applied.
Examples
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import GATConv
>>> # Case 1: Homogeneous graph >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> gatconv = GATConv(10, 2, num_heads=3) >>> res = gatconv(g, feat) >>> res tensor([[[ 3.4570, 1.8634], [ 1.3805, -0.0762], [ 1.0390, -1.1479]], [[ 3.4570, 1.8634], [ 1.3805, -0.0762], [ 1.0390, -1.1479]], [[ 3.4570, 1.8634], [ 1.3805, -0.0762], [ 1.0390, -1.1479]], [[ 3.4570, 1.8634], [ 1.3805, -0.0762], [ 1.0390, -1.1479]], [[ 3.4570, 1.8634], [ 1.3805, -0.0762], [ 1.0390, -1.1479]], [[ 3.4570, 1.8634], [ 1.3805, -0.0762], [ 1.0390, -1.1479]]], grad_fn=<BinaryReduceBackward>)
>>> # Case 2: Unidirectional bipartite graph >>> u = [0, 1, 0, 0, 1] >>> v = [0, 1, 2, 3, 2] >>> g = dgl.bipartite((u, v)) >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32)) >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32)) >>> gatconv = GATConv((5,10), 2, 3) >>> res = gatconv(g, (u_feat, v_feat)) >>> res tensor([[[-0.6066, 1.0268], [-0.5945, -0.4801], [ 0.1594, 0.3825]], [[ 0.0268, 1.0783], [ 0.5041, -1.3025], [ 0.6568, 0.7048]], [[-0.2688, 1.0543], [-0.0315, -0.9016], [ 0.3943, 0.5347]], [[-0.6066, 1.0268], [-0.5945, -0.4801], [ 0.1594, 0.3825]]], grad_fn=<BinaryReduceBackward>)
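The output carries an extra head dimension \((N, H, D_{out})\); in practice the heads are usually concatenated or averaged before the next layer. A minimal sketch reusing the res computed above:
>>> h_cat = res.flatten(1)   # (N, H * D_out): concatenate heads
>>> h_avg = res.mean(1)      # (N, D_out): average heads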
forward(graph, feat)[source]¶
Compute graph attention network layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor or pair of torch.Tensor) – If a torch.Tensor is given, the input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes. If a pair of torch.Tensor is given, the pair must contain two tensors of shape \((N_{in}, D_{in_{src}})\) and \((N_{out}, D_{in_{dst}})\).
- Returns
The output feature of shape \((N, H, D_{out})\) where \(H\) is the number of heads, and \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
- Raises
DGLError – If there are 0-in-degree nodes in the input graph, a DGLError is raised since no message will be passed to those nodes, which would produce invalid output. The error can be ignored by setting the allow_zero_in_degree parameter to True.
EdgeConv¶
class dgl.nn.pytorch.conv.EdgeConv(in_feat, out_feat, batch_norm=False, allow_zero_in_degree=False)[source]¶
Bases: torch.nn.modules.module.Module
EdgeConv layer.
Introduced in “Dynamic Graph CNN for Learning on Point Clouds”. Can be described as follows:
\[h_i^{(l+1)} = \max_{j \in \mathcal{N}(i)} \mathrm{ReLU}( \Theta \cdot (h_j^{(l)} - h_i^{(l)}) + \Phi \cdot h_i^{(l)})\]where \(\mathcal{N}(i)\) is the set of neighbors of \(i\). \(\Theta\) and \(\Phi\) are linear layers.
- Parameters
in_feat (int) – Input feature size; i.e., the number of dimensions of \(h_j^{(l)}\).
out_feat (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
batch_norm (bool) – Whether to include batch normalization on messages. Default: False.
allow_zero_in_degree (bool, optional) – If there are 0-in-degree nodes in the graph, output for those nodes will be invalid since no message will be passed to them. This is harmful for some applications, causing silent performance regression. This module will raise a DGLError if it detects 0-in-degree nodes in the input graph. Setting this to True suppresses the check and lets the users handle it by themselves. Default: False.
Note
Zero in-degree nodes will lead to invalid output value. This is because no message will be passed to those nodes, so the aggregation function will be applied on empty input. A common practice to avoid this is to add a self-loop for each node in the graph if it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph >>> g = dgl.add_self_loop(g)
Calling add_self_loop will not work for some graphs, for example, heterogeneous graphs, since the edge type cannot be decided for self-loop edges. Set allow_zero_in_degree to True for those cases to unblock the code and handle zero-in-degree nodes manually. A common practice is to filter out the zero-in-degree nodes after the conv is applied.
Examples
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import EdgeConv
>>> # Case 1: Homogeneous graph >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> conv = EdgeConv(10, 2) >>> res = conv(g, feat) >>> res tensor([[-0.2347, 0.5849], [-0.2347, 0.5849], [-0.2347, 0.5849], [-0.2347, 0.5849], [-0.2347, 0.5849], [-0.2347, 0.5849]], grad_fn=<CopyReduceBackward>)
>>> # Case 2: Unidirectional bipartite graph >>> u = [0, 1, 0, 0, 1] >>> v = [0, 1, 2, 3, 2] >>> g = dgl.bipartite((u, v)) >>> u_fea = th.rand(2, 5) >>> v_fea = th.rand(4, 5) >>> conv = EdgeConv(5, 2, 3) >>> res = conv(g, (u_fea, v_fea)) >>> res tensor([[ 1.6375, 0.2085], [-1.1925, -1.2852], [ 0.2101, 1.3466], [ 0.2342, -0.9868]], grad_fn=<CopyReduceBackward>)
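EdgeConv is typically paired with a k-NN graph built from point coordinates (as in DGCNN); a minimal sketch, assuming dgl.knn_graph is available in your DGL version:
>>> x = th.rand(20, 3)            # 20 points in 3D
>>> knng = dgl.knn_graph(x, 4)    # connect each point to its 4 nearest neighbors
>>> conv = EdgeConv(3, 8)
>>> h = conv(knng, x)             # output of shape (20, 8)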
forward(g, feat)[source]¶
Forward computation.
- Parameters
g (DGLGraph) – The graph.
feat (Tensor or pair of tensors) –
\((N, D)\) where \(N\) is the number of nodes and \(D\) is the number of feature dimensions.
If a pair of tensors is given, the graph must be a uni-bipartite graph with only one edge type, and the two tensors must have the same dimensionality on all except the first axis.
- Returns
New node features.
- Return type
torch.Tensor
- Raises
DGLError – If there are 0-in-degree nodes in the input graph, a DGLError is raised since no message will be passed to those nodes, which would produce invalid output. The error can be ignored by setting the allow_zero_in_degree parameter to True.
SAGEConv¶
class dgl.nn.pytorch.conv.SAGEConv(in_feats, out_feats, aggregator_type, feat_drop=0.0, bias=True, norm=None, activation=None)[source]¶
Bases: torch.nn.modules.module.Module
GraphSAGE layer from paper Inductive Representation Learning on Large Graphs.
\[ \begin{align}\begin{aligned}h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate} \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)\\h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat} (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)\\h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{l})\end{aligned}\end{align} \]- Parameters
in_feats (int, or pair of ints) – Input feature size; i.e., the number of dimensions of \(h_i^{(l)}\). SAGEConv can be applied on homogeneous graphs and unidirectional bipartite graphs. If the layer applies on a unidirectional bipartite graph, in_feats specifies the input feature size on both the source and destination nodes. If a scalar is given, the source and destination node feature sizes take the same value. If the aggregator type is gcn, the feature sizes of source and destination nodes are required to be the same.
out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
feat_drop (float) – Dropout rate on features. Default: 0.
aggregator_type (str) – Aggregator type to use (mean, gcn, pool, lstm).
bias (bool) – If True, adds a learnable bias to the output. Default: True.
norm (callable activation function/layer or None, optional) – If not None, applies normalization to the updated node features.
activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default: None.
Examples
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import SAGEConv
>>> # Case 1: Homogeneous graph >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> conv = SAGEConv(10, 2, 'pool') >>> res = conv(g, feat) >>> res tensor([[-1.0888, -2.1099], [-1.0888, -2.1099], [-1.0888, -2.1099], [-1.0888, -2.1099], [-1.0888, -2.1099], [-1.0888, -2.1099]], grad_fn=<AddBackward0>)
>>> # Case 2: Unidirectional bipartite graph >>> u = [0, 1, 0, 0, 1] >>> v = [0, 1, 2, 3, 2] >>> g = dgl.bipartite((u, v)) >>> u_fea = th.rand(2, 5) >>> v_fea = th.rand(4, 10) >>> conv = SAGEConv((5, 10), 2, 'mean') >>> res = conv(g, (u_fea, v_fea)) >>> res tensor([[ 0.3163, 3.1166], [ 0.3866, 2.5398], [ 0.5873, 1.6597], [-0.2502, 2.8068]], grad_fn=<AddBackward0>)
forward(graph, feat)[source]¶
Compute GraphSAGE layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor or pair of torch.Tensor) – If a torch.Tensor is given, it represents the input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes. If a pair of torch.Tensor is given, the pair must contain two tensors of shape \((N_{in}, D_{in_{src}})\) and \((N_{out}, D_{in_{dst}})\).
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
SGConv¶
class dgl.nn.pytorch.conv.SGConv(in_feats, out_feats, k=1, cached=False, bias=True, norm=None, allow_zero_in_degree=False)[source]¶
Bases: torch.nn.modules.module.Module
Simplifying Graph Convolution layer from paper Simplifying Graph Convolutional Networks.
\[H^{K} = (\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})^K X \Theta\]where \(\tilde{A}\) is \(A\) + \(I\). Thus the graph input is expected to have self-loop edges added.
- Parameters
in_feats (int) – Number of input features; i.e., the number of dimensions of \(X\).
out_feats (int) – Number of output features; i.e., the number of dimensions of \(H^{K}\).
k (int) – Number of hops \(K\). Default: 1.
cached (bool) – If True, the module caches \((\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}})^K X\Theta\) at the first forward call. This parameter should only be set to True in a transductive learning setting.
bias (bool) – If True, adds a learnable bias to the output. Default: True.
norm (callable activation function/layer or None, optional) – If not None, applies normalization to the updated node features. Default: None.
allow_zero_in_degree (bool, optional) – If there are 0-in-degree nodes in the graph, output for those nodes will be invalid since no message will be passed to them. This is harmful for some applications, causing silent performance regression. This module will raise a DGLError if it detects 0-in-degree nodes in the input graph. Setting this to True suppresses the check and lets the users handle it by themselves. Default: False.
Note
Zero in-degree nodes will lead to invalid output value. This is because no message will be passed to those nodes, so the aggregation function will be applied on empty input. A common practice to avoid this is to add a self-loop for each node in the graph if it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph >>> g = dgl.add_self_loop(g)
Calling add_self_loop will not work for some graphs, for example, heterogeneous graphs, since the edge type cannot be decided for self-loop edges. Set allow_zero_in_degree to True for those cases to unblock the code and handle zero-in-degree nodes manually. A common practice is to filter out the zero-in-degree nodes after the conv is applied.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import SGConv >>> >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> conv = SGConv(10, 2, k=2, cached=True) >>> res = conv(g, feat) >>> res tensor([[-1.9441, -0.9343], [-1.9441, -0.9343], [-1.9441, -0.9343], [-2.7709, -1.3316], [-1.9297, -0.9273], [-1.9441, -0.9343]], grad_fn=<AddmmBackward>)
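Because there is no nonlinearity between hops, the propagation \((\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2})^K X\) in the formula above can be computed once and reused, which is the idea behind cached=True. A rough dense-tensor sketch of that idea (not DGL code, hypothetical sizes):
>>> import torch as th
>>> N, D, K = 6, 10, 2
>>> A_tilde = th.randint(0, 2, (N, N)).float() + th.eye(N)   # adjacency with self-loops
>>> d_inv_sqrt = th.diag(A_tilde.sum(1).pow(-0.5))
>>> S = d_inv_sqrt @ A_tilde @ d_inv_sqrt                    # normalized adjacency
>>> X = th.ones(N, D)
>>> for _ in range(K):
...     X = S @ X                                            # propagated features, can be cached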
forward(graph, feat)[source]¶
Compute Simplifying Graph Convolution layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
- Raises
DGLError – If there are 0-in-degree nodes in the input graph, a DGLError is raised since no message will be passed to those nodes, which would produce invalid output. The error can be ignored by setting the allow_zero_in_degree parameter to True.
Note
If cached is set to True, feat and graph should not change during training, or you will get wrong results.
APPNPConv¶
class dgl.nn.pytorch.conv.APPNPConv(k, alpha, edge_drop=0.0)[source]¶
Bases: torch.nn.modules.module.Module
Approximate Personalized Propagation of Neural Predictions layer from paper Predict then Propagate: Graph Neural Networks meet Personalized PageRank.
\[ \begin{align}\begin{aligned}H^{0} &= X\\H^{l+1} &= (1-\alpha)\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{l}\right) + \alpha H^{0}\end{aligned}\end{align} \]where \(\tilde{A}\) is \(A\) + \(I\).
- Parameters
k (int) – The number of iterations \(K\).
alpha (float) – The teleport probability \(\alpha\).
edge_drop (float, optional) – The dropout rate on edges that shape message passing. Default: 0.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import APPNPConv >>> >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 10) >>> conv = APPNPConv(k=3, alpha=0.5) >>> res = conv(g, feat) >>> res tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0303, 1.0303, 1.0303, 1.0303, 1.0303, 1.0303, 1.0303, 1.0303, 1.0303, 1.0303], [0.8643, 0.8643, 0.8643, 0.8643, 0.8643, 0.8643, 0.8643, 0.8643, 0.8643, 0.8643], [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])
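The update rule above is a simple power iteration; a rough dense-tensor sketch of it (not DGL code, hypothetical sizes):
>>> import torch as th
>>> N, D, K, alpha = 6, 10, 3, 0.5
>>> A_tilde = th.randint(0, 2, (N, N)).float() + th.eye(N)   # adjacency with self-loops
>>> d_inv_sqrt = th.diag(A_tilde.sum(1).pow(-0.5))
>>> S = d_inv_sqrt @ A_tilde @ d_inv_sqrt                    # normalized adjacency
>>> H0 = th.ones(N, D)
>>> H = H0
>>> for _ in range(K):
...     H = (1 - alpha) * (S @ H) + alpha * H0               # propagate and teleport back to H0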
forward(graph, feat)[source]¶
Compute APPNP layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The input feature of shape \((N, *)\). \(N\) is the number of nodes, and \(*\) could be of any shape.
- Returns
The output feature of shape \((N, *)\) where \(*\) should be the same as input shape.
- Return type
torch.Tensor
GINConv¶
class dgl.nn.pytorch.conv.GINConv(apply_func, aggregator_type, init_eps=0, learn_eps=False)[source]¶
Bases: torch.nn.modules.module.Module
Graph Isomorphism Network layer from paper How Powerful are Graph Neural Networks?.
\[h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} + \mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i) \right\}\right)\right)\]- Parameters
apply_func (callable activation function/layer or None) – If not None, apply this function to the updated node feature; this is the \(f_\Theta\) in the formula. A typical choice is a small MLP (see the sketch after the example below).
aggregator_type (str) – Aggregator type to use (sum, max or mean).
init_eps (float, optional) – Initial \(\epsilon\) value. Default: 0.
learn_eps (bool, optional) – If True, \(\epsilon\) will be a learnable parameter. Default: False.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import GINConv >>> >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 10) >>> lin = th.nn.Linear(10, 10) >>> conv = GINConv(lin, 'max') >>> res = conv(g, feat) >>> res tensor([[-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855, 0.8843, -0.8764], [-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855, 0.8843, -0.8764], [-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855, 0.8843, -0.8764], [-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855, 0.8843, -0.8764], [-0.4821, 0.0207, -0.7665, 0.5721, -0.4682, -0.2134, -0.5236, 1.2855, 0.8843, -0.8764], [-0.1804, 0.0758, -0.5159, 0.3569, -0.1408, -0.1395, -0.2387, 0.7773, 0.5266, -0.4465]], grad_fn=<AddmmBackward>)
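In practice apply_func is often a small MLP, as in the GIN paper; a minimal sketch reusing g and feat from the example above (the hidden size 16 is an arbitrary choice):
>>> mlp = th.nn.Sequential(
...     th.nn.Linear(10, 16),
...     th.nn.ReLU(),
...     th.nn.Linear(16, 10))
>>> conv = GINConv(mlp, 'sum', init_eps=0, learn_eps=True)
>>> res = conv(g, feat)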
forward(graph, feat)[source]¶
Compute Graph Isomorphism Network layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor or pair of torch.Tensor) – If a torch.Tensor is given, it is the input feature of shape \((N, D_{in})\) where \(D_{in}\) is the size of the input feature and \(N\) is the number of nodes. If a pair of torch.Tensor is given, the pair must contain two tensors of shape \((N_{in}, D_{in})\) and \((N_{out}, D_{in})\). If apply_func is not None, \(D_{in}\) should fit the input dimensionality requirement of apply_func.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is the output dimensionality of apply_func. If apply_func is None, \(D_{out}\) is the same as the input dimensionality.
- Return type
torch.Tensor
GatedGraphConv¶
class dgl.nn.pytorch.conv.GatedGraphConv(in_feats, out_feats, n_steps, n_etypes, bias=True)[source]¶
Bases: torch.nn.modules.module.Module
Gated Graph Convolution layer from paper Gated Graph Sequence Neural Networks.
\[ \begin{align}\begin{aligned}h_{i}^{0} &= [ x_i \| \mathbf{0} ]\\a_{i}^{t} &= \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t}\\h_{i}^{t+1} &= \mathrm{GRU}(a_{i}^{t}, h_{i}^{t})\end{aligned}\end{align} \]- Parameters
in_feats (int) – Input feature size; i.e., the number of dimensions of \(x_i\).
out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(t+1)}\).
n_steps (int) – Number of recurrent steps; i.e., the \(t\) in the above formula.
n_etypes (int) – Number of edge types.
bias (bool) – If True, adds a learnable bias to the output. Default: True.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import GatedGraphConv >>> >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 10) >>> conv = GatedGraphConv(10, 10, 2, 3) >>> etype = th.tensor([0,1,2,0,1,2]) >>> res = conv(g, feat, etype) >>> res tensor([[ 0.4652, 0.4458, 0.5169, 0.4126, 0.4847, 0.2303, 0.2757, 0.7721, 0.0523, 0.0857], [ 0.0832, 0.1388, -0.5643, 0.7053, -0.2524, -0.3847, 0.7587, 0.8245, 0.9315, 0.4063], [ 0.6340, 0.4096, 0.7692, 0.2125, 0.2106, 0.4542, -0.0580, 0.3364, -0.1376, 0.4948], [ 0.5551, 0.7946, 0.6220, 0.8058, 0.5711, 0.3063, -0.5454, 0.2272, -0.6931, -0.1607], [ 0.2644, 0.2469, -0.6143, 0.6008, -0.1516, -0.3781, 0.5878, 0.7993, 0.9241, 0.1835], [ 0.6393, 0.3447, 0.3893, 0.4279, 0.3342, 0.3809, 0.0406, 0.5030, 0.1342, 0.0425]], grad_fn=<AddBackward0>)
forward(graph, feat, etypes)[source]¶
Compute Gated Graph Convolution layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The input feature of shape \((N, D_{in})\) where \(N\) is the number of nodes of the graph and \(D_{in}\) is the input feature size.
etypes (torch.LongTensor) – The edge type tensor of shape \((E,)\) where \(E\) is the number of edges of the graph.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is the output feature size.
- Return type
torch.Tensor
GMMConv¶
class dgl.nn.pytorch.conv.GMMConv(in_feats, out_feats, dim, n_kernels, aggregator_type='sum', residual=False, bias=True, allow_zero_in_degree=False)[source]¶
Bases: torch.nn.modules.module.Module
The Gaussian Mixture Model Convolution layer from Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs.
\[ \begin{align}\begin{aligned}u_{ij} &= f(x_i, x_j), x_j \in \mathcal{N}(i)\\w_k(u) &= \exp\left(-\frac{1}{2}(u-\mu_k)^T \Sigma_k^{-1} (u - \mu_k)\right)\\h_i^{l+1} &= \mathrm{aggregate}\left(\left\{\frac{1}{K} \sum_{k}^{K} w_k(u_{ij}), \forall j\in \mathcal{N}(i)\right\}\right)\end{aligned}\end{align} \]where \(u\) denotes the pseudo-coordinates between a vertex and one of its neighbor, computed using function \(f\), \(\Sigma_k^{-1}\) and \(\mu_k\) are learnable parameters representing the covariance matrix and mean vector of a Gaussian kernel.
- Parameters
in_feats (int) – Number of input features; i.e., the number of dimensions of \(x_i\).
out_feats (int) – Number of output features; i.e., the number of dimensions of \(h_i^{(l+1)}\).
dim (int) – Dimensionality of pseudo-coordinates; i.e., the number of dimensions of \(u_{ij}\).
n_kernels (int) – Number of kernels \(K\).
aggregator_type (str) – Aggregator type (sum, mean, max). Default: sum.
residual (bool) – If True, use residual connection inside this layer. Default: False.
bias (bool) – If True, adds a learnable bias to the output. Default: True.
allow_zero_in_degree (bool, optional) – If there are 0-in-degree nodes in the graph, output for those nodes will be invalid since no message will be passed to them. This is harmful for some applications, causing silent performance regression. This module will raise a DGLError if it detects 0-in-degree nodes in the input graph. Setting this to True suppresses the check and lets the users handle it by themselves. Default: False.
Note
Zero in-degree nodes will lead to invalid output value. This is because no message will be passed to those nodes, so the aggregation function will be applied on empty input. A common practice to avoid this is to add a self-loop for each node in the graph if it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph >>> g = dgl.add_self_loop(g)
Calling add_self_loop will not work for some graphs, for example, heterogeneous graphs, since the edge type cannot be decided for self-loop edges. Set allow_zero_in_degree to True for those cases to unblock the code and handle zero-in-degree nodes manually. A common practice is to filter out the zero-in-degree nodes after the conv is applied.
Examples
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import GMMConv
>>> # Case 1: Homogeneous graph >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> conv = GMMConv(10, 2, 3, 2, 'mean') >>> pseudo = th.ones(12, 3) >>> res = conv(g, feat, pseudo) >>> res tensor([[-0.3462, -0.2654], [-0.3462, -0.2654], [-0.3462, -0.2654], [-0.3462, -0.2654], [-0.3462, -0.2654], [-0.3462, -0.2654]], grad_fn=<AddBackward0>)
>>> # Case 2: Unidirectional bipartite graph >>> u = [0, 1, 0, 0, 1] >>> v = [0, 1, 2, 3, 2] >>> g = dgl.bipartite((u, v)) >>> u_fea = th.rand(2, 5) >>> v_fea = th.rand(4, 10) >>> pseudo = th.ones(5, 3) >>> conv = GMMConv((10, 5), 2, 3, 2, 'mean') >>> res = conv(g, (u_fea, v_fea), pseudo) >>> res tensor([[-0.1107, -0.1559], [-0.1646, -0.2326], [-0.1377, -0.1943], [-0.1107, -0.1559]], grad_fn=<AddBackward0>)
forward(graph, feat, pseudo)[source]¶
Compute Gaussian Mixture Model Convolution layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – If a single tensor is given, the input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes. If a pair of tensors are given, the pair must contain two tensors of shape \((N_{in}, D_{in_{src}})\) and \((N_{out}, D_{in_{dst}})\).
pseudo (torch.Tensor) – The pseudo coordinate tensor of shape \((E, D_{u})\) where \(E\) is the number of edges of the graph and \(D_{u}\) is the dimensionality of pseudo coordinate.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is the output feature size.
- Return type
torch.Tensor
- Raises
DGLError – If there are 0-in-degree nodes in the input graph, a DGLError is raised since no message will be passed to those nodes, which would produce invalid output. The error can be ignored by setting the allow_zero_in_degree parameter to True.
ChebConv¶
class dgl.nn.pytorch.conv.ChebConv(in_feats, out_feats, k, activation=<function relu>, bias=True)[source]¶
Bases: torch.nn.modules.module.Module
Chebyshev Spectral Graph Convolution layer from paper Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering.
\[ \begin{align}\begin{aligned}h_i^{l+1} &= \sum_{k=0}^{K-1} W^{k, l}z_i^{k, l}\\Z^{0, l} &= H^{l}\\Z^{1, l} &= \tilde{L} \cdot H^{l}\\Z^{k, l} &= 2 \cdot \tilde{L} \cdot Z^{k-1, l} - Z^{k-2, l}\\\tilde{L} &= 2\left(I - \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\right)/\lambda_{max} - I\end{aligned}\end{align} \]where \(\tilde{A}\) is \(A\) + \(I\), \(W\) is learnable weight.
- Parameters
in_feats (int) – Dimension of input features; i.e., the number of dimensions of \(h_i^{(l)}\).
out_feats (int) – Dimension of output features \(h_i^{(l+1)}\).
k (int) – Chebyshev filter size \(K\).
activation (function, optional) – Activation function. Default: ReLU.
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import ChebConv >>> >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 10) >>> conv = ChebConv(10, 2, 2) >>> res = conv(g, feat) >>> res tensor([[ 0.6163, -0.1809], [ 0.6163, -0.1809], [ 0.6163, -0.1809], [ 0.9698, -1.5053], [ 0.3664, 0.7556], [-0.2370, 3.0164]], grad_fn=<AddBackward0>)
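The Chebyshev recurrence above can also be written out with dense tensors; a rough sketch (not DGL code), with \(\lambda_{max}\) fixed at 2 and hypothetical sizes:
>>> import torch as th
>>> N, D, K = 6, 10, 3
>>> A_tilde = th.randint(0, 2, (N, N)).float() + th.eye(N)   # adjacency with self-loops
>>> d_inv_sqrt = th.diag(A_tilde.sum(1).pow(-0.5))
>>> L_hat = 2 * (th.eye(N) - d_inv_sqrt @ A_tilde @ d_inv_sqrt) / 2.0 - th.eye(N)
>>> H = th.ones(N, D)
>>> Z = [H, L_hat @ H]                                       # Z^0 and Z^1
>>> for k in range(2, K):
...     Z.append(2 * (L_hat @ Z[-1]) - Z[-2])                # Z^k = 2*L_hat*Z^{k-1} - Z^{k-2}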
forward(graph, feat, lambda_max=None)[source]¶
Compute ChebNet layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes.
lambda_max (list or tensor or None, optional) – A list (tensor) of length \(B\), storing the largest eigenvalue of the normalized Laplacian of each individual graph in graph, where \(B\) is the batch size of the input graph. Default: None. If None, this method computes the list by calling dgl.laplacian_lambda_max.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
AGNNConv¶
class dgl.nn.pytorch.conv.AGNNConv(init_beta=1.0, learn_beta=True, allow_zero_in_degree=False)[source]¶
Bases: torch.nn.modules.module.Module
Attention-based Graph Neural Network layer from paper Attention-based Graph Neural Network for Semi-Supervised Learning.
\[H^{l+1} = P H^{l}\]where \(P\) is computed as:
\[P_{ij} = \mathrm{softmax}_i ( \beta \cdot \cos(h_i^l, h_j^l))\]where \(\beta\) is a single scalar parameter.
- Parameters
init_beta (float, optional) – The \(\beta\) in the formula, a single scalar parameter.
learn_beta (bool, optional) – If True, \(\beta\) will be a learnable parameter.
allow_zero_in_degree (bool, optional) – If there are 0-in-degree nodes in the graph, output for those nodes will be invalid since no message will be passed to them. This is harmful for some applications, causing silent performance regression. This module will raise a DGLError if it detects 0-in-degree nodes in the input graph. Setting this to True suppresses the check and lets the users handle it by themselves. Default: False.
Note
Zero in-degree nodes will lead to invalid output value. This is because no message will be passed to those nodes, so the aggregation function will be applied on empty input. A common practice to avoid this is to add a self-loop for each node in the graph if it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph >>> g = dgl.add_self_loop(g)
Calling add_self_loop will not work for some graphs, for example, heterogeneous graphs, since the edge type cannot be decided for self-loop edges. Set allow_zero_in_degree to True for those cases to unblock the code and handle zero-in-degree nodes manually. A common practice is to filter out the zero-in-degree nodes after the conv is applied.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import AGNNConv >>> >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> conv = AGNNConv() >>> res = conv(g, feat) >>> res tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], grad_fn=<BinaryReduceBackward>)
forward(graph, feat)[source]¶
Compute AGNN layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The input feature of shape \((N, *)\), where \(N\) is the number of nodes and \(*\) could be of any shape. If a pair of torch.Tensor is given, the pair must contain two tensors of shape \((N_{in}, *)\) and \((N_{out}, *)\); the \(*\) in the latter tensor must equal that of the former.
- Returns
The output feature of shape \((N, *)\) where \(*\) should be the same as input shape.
- Return type
torch.Tensor
- Raises
DGLError – If there are 0-in-degree nodes in the input graph, a DGLError is raised since no message will be passed to those nodes, which would produce invalid output. The error can be ignored by setting the allow_zero_in_degree parameter to True.
NNConv¶
class dgl.nn.pytorch.conv.NNConv(in_feats, out_feats, edge_func, aggregator_type='mean', residual=False, bias=True)[source]¶
Bases: torch.nn.modules.module.Module
Graph Convolution layer introduced in Neural Message Passing for Quantum Chemistry.
\[h_{i}^{l+1} = h_{i}^{l} + \mathrm{aggregate}\left(\left\{ f_\Theta (e_{ij}) \cdot h_j^{l}, j\in \mathcal{N}(i) \right\}\right)\]where \(e_{ij}\) is the edge feature, \(f_\Theta\) is a function with learnable parameters.
- Parameters
in_feats (int) – Input feature size; i.e., the number of dimensions of \(h_j^{(l)}\). NNConv can be applied on homogeneous graphs and unidirectional bipartite graphs. If the layer is to be applied on a unidirectional bipartite graph, in_feats specifies the input feature size on both the source and destination nodes. If a scalar is given, the source and destination node feature sizes take the same value.
out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
edge_func (callable activation function/layer) – Maps each edge feature to a vector of shape (in_feats * out_feats), used as the weight to compute messages. Also the \(f_\Theta\) in the formula.
aggregator_type (str) – Aggregator type to use (sum, mean or max).
residual (bool, optional) – If True, use residual connection. Default: False.
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.
Examples
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import NNConv
>>> # Case 1: Homogeneous graph >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> lin = th.nn.Linear(5, 20) >>> def edge_func(efeat): ... return lin(efeat) >>> efeat = th.ones(6+6, 5) >>> conv = NNConv(10, 2, edge_func, 'mean') >>> res = conv(g, feat, efeat) >>> res tensor([[-1.5243, -0.2719], [-1.5243, -0.2719], [-1.5243, -0.2719], [-1.5243, -0.2719], [-1.5243, -0.2719], [-1.5243, -0.2719]], grad_fn=<AddBackward0>)
>>> # Case 2: Unidirectional bipartite graph >>> u = [0, 1, 0, 0, 1] >>> v = [0, 1, 2, 3, 2] >>> g = dgl.bipartite((u, v)) >>> u_feat = th.tensor(np.random.rand(2, 10).astype(np.float32)) >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32)) >>> conv = NNConv(10, 2, edge_func, 'mean') >>> efeat = th.ones(5, 5) >>> res = conv(g, (u_feat, v_feat), efeat) >>> res tensor([[-0.6568, 0.5042], [ 0.9089, -0.5352], [ 0.1261, -0.0155], [-0.6568, 0.5042]], grad_fn=<AddBackward0>)
forward(graph, feat, efeat)[source]¶
Compute MPNN Graph Convolution layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor or pair of torch.Tensor) – The input feature of shape \((N, D_{in})\) where \(N\) is the number of nodes of the graph and \(D_{in}\) is the input feature size.
efeat (torch.Tensor) – The edge feature of shape \((E, *)\), which should fit the input shape requirement of edge_func. \(E\) is the number of edges of the graph.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is the output feature size.
- Return type
torch.Tensor
AtomicConv¶
class dgl.nn.pytorch.conv.AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling, features_to_use=None)[source]¶
Bases: torch.nn.modules.module.Module
Atomic Convolution Layer from paper Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity.
Denote the type of atom \(i\) by \(z_i\) and the distance between atom \(i\) and atom \(j\) by \(r_{ij}\).
Distance Transformation
An atomic convolution layer first transforms distances with radial filters and then performs a pooling operation.
For radial filter indexed by \(k\), it projects edge distances with
\[h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)\]If \(r_{ij} < c_k\),
\[f_{ij}^{k} = 0.5 \left(\cos\left(\frac{\pi r_{ij}}{c_k}\right) + 1\right),\]else,
\[f_{ij}^{k} = 0.\]Finally,
\[e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}\]Aggregation
For each type \(t\), each atom collects distance information from all neighbor atoms of type \(t\):
\[p_{i, t}^{k} = \sum_{j\in N(i)} e_{ij}^{k} * 1(z_j == t)\]Then concatenate the results for all RBF kernels and atom types.
- Parameters
interaction_cutoffs (float32 tensor of shape (K)) – \(c_k\) in the equations above. Roughly they can be considered as learnable cutoffs and two atoms are considered as connected if the distance between them is smaller than the cutoffs. K for the number of radial filters.
rbf_kernel_means (float32 tensor of shape (K)) – \(r_k\) in the equations above. K for the number of radial filters.
rbf_kernel_scaling (float32 tensor of shape (K)) – \(\gamma_k\) in the equations above. K for the number of radial filters.
features_to_use (None or float tensor of shape (T)) – In the original paper, these are atomic numbers to consider, representing the types of atoms. T for the number of types of atomic numbers. Default to None.
Note
This convolution operation is designed for molecular graphs in Chemistry, but it might be possible to extend it to more general graphs.
There seems to be an inconsistency about the definition of \(e_{ij}^{k}\) in the paper and the author’s implementation. We follow the author’s implementation. In the paper, \(e_{ij}^{k}\) was defined as \(\exp(-\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})\).
\(\gamma_{k}\), \(r_k\) and \(c_k\) are all learnable.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import AtomicConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 1) >>> edist = th.ones(6, 1) >>> interaction_cutoffs = th.ones(3).float() * 2 >>> rbf_kernel_means = th.ones(3).float() >>> rbf_kernel_scaling = th.ones(3).float() >>> conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling) >>> res = conv(g, feat, edist) >>> res tensor([[0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000], [1.0000, 1.0000, 1.0000], [0.5000, 0.5000, 0.5000], [0.0000, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
forward(graph, feat, distances)[source]¶
Apply the atomic convolution layer.
- Parameters
graph (DGLGraph) – Topology based on which message passing is performed.
feat (Float32 tensor of shape \((V, 1)\)) – Initial node features, which are atomic numbers in the paper. \(V\) for the number of nodes.
distances (Float32 tensor of shape \((E, 1)\)) – Distance between end nodes of edges. E for the number of edges.
- Returns
Updated node representations. \(V\) for the number of nodes, \(K\) for the number of radial filters, and \(T\) for the number of types of atomic numbers.
- Return type
Float32 tensor of shape \((V, K * T)\)
CFConv¶
class dgl.nn.pytorch.conv.CFConv(node_in_feats, edge_in_feats, hidden_feats, out_feats)[source]¶
Bases: torch.nn.modules.module.Module
CFConv in SchNet.
SchNet is introduced in SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
It combines node and edge features in message passing and updates node representations.
\[h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} h_j^{l} \circ W^{(l)}e_{ij}\]where \(\circ\) represents element-wise multiplication and \(\text{SSP}\) is the shifted softplus:
\[\text{SSP}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) - \log(\text{shift})\]
- Parameters
node_in_feats (int) – Size of input node features.
edge_in_feats (int) – Size of input edge features.
hidden_feats (int) – Size of hidden representations.
out_feats (int) – Size of output node representations.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import CFConv >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> nfeat = th.ones(6, 10) >>> efeat = th.ones(6, 5) >>> conv = CFConv(10, 5, 3, 2) >>> res = conv(g, nfeat, efeat) >>> res tensor([[-0.1209, -0.2289], [-0.1209, -0.2289], [-0.1209, -0.2289], [-0.1135, -0.2338], [-0.1209, -0.2289], [-0.1283, -0.2240]], grad_fn=<SubBackward0>)
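The shifted softplus above is easy to write directly; a minimal sketch, with \(\beta=1\) and \(\text{shift}=2\) assumed here (common choices in SchNet implementations):
>>> import math
>>> import torch.nn.functional as F
>>> def ssp(x, beta=1.0, shift=2.0):
...     # (1 / beta) * log(1 + exp(beta * x)) - log(shift)
...     return F.softplus(x, beta=beta) - math.log(shift)
>>> y = ssp(th.randn(4))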
forward(g, node_feats, edge_feats)[source]¶
Performs message passing and updates node representations.
- Parameters
g (DGLGraph) – The graph.
node_feats (float32 tensor of shape (V, node_in_feats)) – Input node features, V for the number of nodes.
edge_feats (float32 tensor of shape (E, edge_in_feats)) – Input edge features, E for the number of edges.
- Returns
Updated node representations.
- Return type
float32 tensor of shape (V, out_feats)
DotGatConv¶
class dgl.nn.pytorch.conv.DotGatConv(in_feats, out_feats, allow_zero_in_degree=False)[source]¶
Bases: torch.nn.modules.module.Module
Apply dot product version of self attention in GCN.
\[h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i, j} h_j^{(l)}\]where \(\alpha_{ij}\) is the attention score between node \(i\) and node \(j\):
\[ \begin{align}\begin{aligned}\alpha_{i, j} &= \mathrm{softmax_i}(e_{ij}^{l})\\e_{ij}^{l} &= ({W_i^{(l)} h_i^{(l)}})^T \cdot {W_j^{(l)} h_j^{(l)}}\end{aligned}\end{align} \]where \(W_i\) and \(W_j\) transform node \(i\)'s and node \(j\)'s features into the same dimension, so that the dot product can be used to compute the similarity of the node features.
- Parameters
in_feats (int, or pair of ints) – Input feature size; i.e., the number of dimensions of \(h_i^{(l)}\). DotGatConv can be applied on homogeneous graphs and unidirectional bipartite graphs. If the layer is to be applied to a unidirectional bipartite graph, in_feats specifies the input feature size on both the source and destination nodes. If a scalar is given, the source and destination node feature sizes take the same value.
out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
allow_zero_in_degree (bool, optional) – If there are 0-in-degree nodes in the graph, output for those nodes will be invalid since no message will be passed to them. This is harmful for some applications, causing silent performance regression. This module will raise a DGLError if it detects 0-in-degree nodes in the input graph. Setting this to True suppresses the check and lets the users handle it by themselves. Default: False.
Note
Zero in-degree nodes will lead to invalid output value. This is because no message will be passed to those nodes, so the aggregation function will be applied on empty input. A common practice to avoid this is to add a self-loop for each node in the graph if it is homogeneous, which can be achieved by:
>>> g = ... # a DGLGraph >>> g = dgl.add_self_loop(g)
Calling add_self_loop will not work for some graphs, for example, heterogeneous graphs, since the edge type cannot be decided for self-loop edges. Set allow_zero_in_degree to True for those cases to unblock the code and handle zero-in-degree nodes manually. A common practice is to filter out the zero-in-degree nodes after the conv is applied.
Examples
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import DotGatConv
>>> # Case 1: Homogeneous graph >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.add_self_loop(g) >>> feat = th.ones(6, 10) >>> gatconv = DotGatConv(10, 2) >>> res = gatconv(g, feat) >>> res tensor([[-0.6958, -0.8752], [-0.6958, -0.8752], [-0.6958, -0.8752], [-0.6958, -0.8752], [-0.6958, -0.8752], [-0.6958, -0.8752]], grad_fn=<CopyReduceBackward>)
>>> # Case 2: Unidirectional bipartite graph >>> u = [0, 1, 0, 0, 1] >>> v = [0, 1, 2, 3, 2] >>> g = dgl.bipartite((u, v)) >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32)) >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32)) >>> gatconv = DotGatConv((5,10), 2) >>> res = gatconv(g, (u_feat, v_feat)) >>> res tensor([[ 0.4718, 0.0864], [ 0.7099, -0.0335], [ 0.5869, 0.0284], [ 0.4718, 0.0864]], grad_fn=<CopyReduceBackward>)
forward(graph, feat)[source]¶
Apply dot product version of self attention in GCN.
- Parameters
graph (DGLGraph or bipartite graph) – The graph.
feat (torch.Tensor or pair of torch.Tensor) – If a torch.Tensor is given, the input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes. If a pair of torch.Tensor is given, the pair must contain two tensors of shape \((N_{in}, D_{in_{src}})\) and \((N_{out}, D_{in_{dst}})\).
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
- Raises
DGLError – If there are 0-in-degree nodes in the input graph, a DGLError is raised since no message will be passed to those nodes, which would produce invalid output. The error can be ignored by setting the allow_zero_in_degree parameter to True.
Dense Conv Layers¶
DenseGraphConv¶
class dgl.nn.pytorch.conv.DenseGraphConv(in_feats, out_feats, norm='both', bias=True, activation=None)[source]¶
Bases: torch.nn.modules.module.Module
Graph Convolutional Network layer where the graph structure is given by an adjacency matrix. We recommend using this module when applying graph convolution on dense graphs.
- Parameters
in_feats (int) – Input feature size; i.e., the number of dimensions of \(h_j^{(l)}\).
out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
norm (str, optional) – How to apply the normalizer. If 'right', divide the aggregated messages by each node's in-degree, which is equivalent to averaging the received messages. If 'none', no normalization is applied. Default is 'both', where the \(c_{ij}\) in the paper is applied.
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.
activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default: None.
Notes
Zero in-degree nodes will lead to all-zero output. A common practice to avoid this is to add a self-loop for each node in the graph, which can be achieved by setting the diagonal of the adjacency matrix to be 1.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import DenseGraphConv >>> >>> feat = th.ones(6, 10) >>> adj = th.tensor([[0., 0., 1., 0., 0., 0.], ... [1., 0., 0., 0., 0., 0.], ... [0., 1., 0., 0., 0., 0.], ... [0., 0., 1., 0., 0., 1.], ... [0., 0., 0., 1., 0., 0.], ... [0., 0., 0., 0., 0., 0.]]) >>> conv = DenseGraphConv(10, 2) >>> res = conv(adj, feat) >>> res tensor([[0.2159, 1.9027], [0.3053, 2.6908], [0.3053, 2.6908], [0.3685, 3.2481], [0.3053, 2.6908], [0.0000, 0.0000]], grad_fn=<AddBackward0>)
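Following the note above, self-loops can be added by putting 1s on the diagonal before calling the layer; a short sketch reusing adj and feat from the example (whose diagonal is currently zero):
>>> adj = adj + th.eye(6)   # set the diagonal to 1, i.e. add a self-loop per node
>>> res = conv(adj, feat)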
See also
GraphConv
forward(adj, feat)[source]¶
Compute (Dense) Graph Convolution layer.
- Parameters
adj (torch.Tensor) – The adjacency matrix of the graph to apply Graph Convolution on. When applied to a unidirectional bipartite graph, adj should be of shape \((N_{out}, N_{in})\); when applied to a homogeneous graph, adj should be of shape \((N, N)\). In both cases, a row represents a destination node while a column represents a source node.
feat (torch.Tensor) – The input feature.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
DenseSAGEConv¶
class dgl.nn.pytorch.conv.DenseSAGEConv(in_feats, out_feats, feat_drop=0.0, bias=True, norm=None, activation=None)[source]¶
Bases: torch.nn.modules.module.Module
GraphSAGE layer where the graph structure is given by an adjacency matrix. We recommend using this module when applying GraphSAGE on dense graphs.
Note that only the gcn aggregator is supported in DenseSAGEConv.
- Parameters
in_feats (int) – Input feature size; i.e, the number of dimensions of \(h_i^{(l)}\).
out_feats (int) – Output feature size; i.e, the number of dimensions of \(h_i^{(l+1)}\).
feat_drop (float, optional) – Dropout rate on features. Default: 0.
bias (bool) – If True, adds a learnable bias to the output. Default:
True
.norm (callable activation function/layer or None, optional) – If not None, applies normalization to the updated node features.
activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default:
None
.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import DenseSAGEConv >>> >>> feat = th.ones(6, 10) >>> adj = th.tensor([[0., 0., 1., 0., 0., 0.], ... [1., 0., 0., 0., 0., 0.], ... [0., 1., 0., 0., 0., 0.], ... [0., 0., 1., 0., 0., 1.], ... [0., 0., 0., 1., 0., 0.], ... [0., 0., 0., 0., 0., 0.]]) >>> conv = DenseSAGEConv(10, 2) >>> res = conv(adj, feat) >>> res tensor([[1.0401, 2.1008], [1.0401, 2.1008], [1.0401, 2.1008], [1.0401, 2.1008], [1.0401, 2.1008], [1.0401, 2.1008]], grad_fn=<AddmmBackward>)
See also
SAGEConv
-
forward
(adj, feat)[source]¶ Compute (Dense) Graph SAGE layer.
- Parameters
adj (torch.Tensor) – The adjacency matrix of the graph to apply SAGE Convolution on. When applied to a unidirectional bipartite graph, adj should be of shape \((N_{out}, N_{in})\); when applied to a homogeneous graph, adj should be of shape \((N, N)\). In both cases, a row represents a destination node while a column represents a source node.
feat (torch.Tensor or a pair of torch.Tensor) – If a torch.Tensor is given, the input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes. If a pair of torch.Tensor is given, the pair must contain two tensors of shape \((N_{in}, D_{in})\) and \((N_{out}, D_{in})\). A sketch of this case is given below.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
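A minimal sketch of the pair-of-tensors (unidirectional bipartite) input mentioned above (the sizes below are illustrative):

>>> import torch as th
>>> from dgl.nn import DenseSAGEConv
>>>
>>> # 2 source nodes and 3 destination nodes; rows are destinations, columns are sources
>>> adj = th.tensor([[1., 0.],
...                  [0., 1.],
...                  [1., 1.]])
>>> feat_src = th.ones(2, 5)
>>> feat_dst = th.ones(3, 5)
>>> conv = DenseSAGEConv(5, 2)
>>> conv(adj, (feat_src, feat_dst)).shape   # one output row per destination node
torch.Size([3, 2])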
DenseChebConv¶
-
class
dgl.nn.pytorch.conv.
DenseChebConv
(in_feats, out_feats, k, bias=True)[source]¶ Bases:
torch.nn.modules.module.Module
Chebyshev Spectral Graph Convolution layer from paper Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering.
We recommend using this module when applying ChebConv on dense graphs.
- Parameters
in_feats (int) – Dimension of input features \(h_i^{(l)}\).
out_feats (int) – Dimension of output features \(h_i^{(l+1)}\).
k (int) – Chebyshev filter size.
activation (function, optional) – Activation function. Default: ReLU.
bias (bool, optional) – If True, adds a learnable bias to the output. Default:
True
.
Example
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import DenseChebConv >>> >>> feat = th.ones(6, 10) >>> adj = th.tensor([[0., 0., 1., 0., 0., 0.], ... [1., 0., 0., 0., 0., 0.], ... [0., 1., 0., 0., 0., 0.], ... [0., 0., 1., 0., 0., 1.], ... [0., 0., 0., 1., 0., 0.], ... [0., 0., 0., 0., 0., 0.]]) >>> conv = DenseChebConv(10, 2, 2) >>> res = conv(adj, feat) >>> res tensor([[-3.3516, -2.4797], [-3.3516, -2.4797], [-3.3516, -2.4797], [-4.5192, -3.0835], [-2.5259, -2.0527], [-0.5327, -1.0219]], grad_fn=<AddBackward0>)
See also
ChebConv
-
forward
(adj, feat, lambda_max=None)[source]¶ Compute (Dense) Chebyshev Spectral Graph Convolution layer.
- Parameters
adj (torch.Tensor) – The adjacency matrix of the graph to apply Graph Convolution on, should be of shape \((N, N)\), where a row represents the destination and a column represents the source.
feat (torch.Tensor) – The input feature of shape \((N, D_{in})\) where \(D_{in}\) is size of input feature, \(N\) is the number of nodes.
lambda_max (float or None, optional) – A float value indicating the largest eigenvalue of the given graph. Default: None. A sketch of precomputing it is given below.
- Returns
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is size of output feature.
- Return type
torch.Tensor
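A sketch of precomputing lambda_max as the largest eigenvalue of the scaled Laplacian \(I - D^{-1/2} A D^{-1/2}\). It assumes adj is a symmetric dense adjacency, conv and feat are defined as in the example above, and a recent PyTorch with torch.linalg is available:

>>> import torch as th
>>>
>>> deg = adj.sum(dim=1).clamp(min=1)                        # node degrees, clamped to avoid division by zero
>>> d_inv_sqrt = th.diag(deg.pow(-0.5))
>>> laplacian = th.eye(adj.shape[0]) - d_inv_sqrt @ adj @ d_inv_sqrt
>>> lambda_max = th.linalg.eigvalsh(laplacian).max().item()  # largest eigenvalue
>>> res = conv(adj, feat, lambda_max=lambda_max)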
Global Pooling Layers¶
Torch modules for graph global pooling.
SumPooling¶
-
class
dgl.nn.pytorch.glob.
SumPooling
[source]¶ Bases:
torch.nn.modules.module.Module
Apply sum pooling over the nodes in a graph.
\[r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k\]Notes
Input: Could be one graph, or a batch of graphs. If using a batch of graphs, make sure nodes in all graphs have the same feature size, and concatenate the node features together as the input.
Examples
The following example uses PyTorch backend.
>>> import dgl >>> import torch as th >>> from dgl.nn.pytorch.glob import SumPooling >>> >>> g1 = dgl.DGLGraph() >>> g1.add_nodes(2) >>> g1_node_feats = th.ones(2,5) >>> >>> g2 = dgl.DGLGraph() >>> g2.add_nodes(3) >>> g2_node_feats = th.ones(3,5) >>> >>> sumpool = SumPooling()
Case 1: Input a single graph
>>> sumpool(g1, g1_node_feats) tensor([[2., 2., 2., 2., 2.]])
Case 2: Input a batch of graphs
Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.
>>> batch_g = dgl.batch([g1, g2]) >>> batch_f = th.cat([g1_node_feats, g2_node_feats]) >>> >>> sumpool(batch_g, batch_f) tensor([[2., 2., 2., 2., 2.], [3., 3., 3., 3., 3.]])
-
forward
(graph, feat)[source]¶ Compute sum pooling.
- Parameters
graph (DGLGraph) – a DGLGraph or a batch of DGLGraphs
feat (torch.Tensor) – The input feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph, and \(D\) means the size of features.
- Returns
The output feature with shape \((B, D)\), where \(B\) refers to the batch size of input graphs.
- Return type
torch.Tensor
-
AvgPooling¶
-
class
dgl.nn.pytorch.glob.
AvgPooling
[source]¶ Bases:
torch.nn.modules.module.Module
Apply average pooling over the nodes in a graph.
\[r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k\]Notes
Input: Could be one graph, or a batch of graphs. If using a batch of graphs, make sure nodes in all graphs have the same feature size, and concatenate the node features together as the input.
Examples
The following example uses PyTorch backend.
>>> import dgl >>> import torch as th >>> from dgl.nn.pytorch.glob import AvgPooling >>> >>> g1 = dgl.DGLGraph() >>> g1.add_nodes(2) >>> g1_node_feats = th.ones(2,5) >>> >>> g2 = dgl.DGLGraph() >>> g2.add_nodes(3) >>> g2_node_feats = th.ones(3,5) >>> >>> avgpool = AvgPooling()
Case 1: Input a single graph
>>> avgpool(g1, g1_node_feats) tensor([[1., 1., 1., 1., 1.]])
Case 2: Input a batch of graphs
Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.
>>> batch_g = dgl.batch([g1, g2]) >>> batch_f = th.cat([g1_node_feats, g2_node_feats]) >>> >>> avgpool(batch_g, batch_f) tensor([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]])
-
forward
(graph, feat)[source]¶ Compute average pooling.
- Parameters
graph (DGLGraph) – A DGLGraph or a batch of DGLGraphs.
feat (torch.Tensor) – The input feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph, and \(D\) means the size of features.
- Returns
The output feature with shape \((B, D)\), where \(B\) refers to the batch size of input graphs.
- Return type
torch.Tensor
-
MaxPooling¶
-
class
dgl.nn.pytorch.glob.
MaxPooling
[source]¶ Bases:
torch.nn.modules.module.Module
Apply max pooling over the nodes in a graph.
\[r^{(i)} = \max_{k=1}^{N_i}\left( x^{(i)}_k \right)\]Notes
Input: Could be one graph, or a batch of graphs. If using a batch of graphs, make sure nodes in all graphs have the same feature size, and concatenate the node features together as the input.
Examples
The following example uses PyTorch backend.
>>> import dgl >>> import torch as th >>> from dgl.nn.pytorch.glob import MaxPooling >>> >>> g1 = dgl.DGLGraph() >>> g1.add_nodes(2) >>> g1_node_feats = th.ones(2,5) >>> >>> g2 = dgl.DGLGraph() >>> g2.add_nodes(3) >>> g2_node_feats = th.ones(3,5) >>> >>> maxpool = MaxPooling()
Case 1: Input a single graph
>>> maxpool(g1, g1_node_feats) tensor([[1., 1., 1., 1., 1.]])
Case 2: Input a batch of graphs
Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.
>>> batch_g = dgl.batch([g1, g2]) >>> batch_f = th.cat([g1_node_feats, g2_node_feats]) >>> >>> maxpool(batch_g, batch_f) tensor([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]])
-
forward
(graph, feat)[source]¶ Compute max pooling.
- Parameters
graph (DGLGraph) – A DGLGraph or a batch of DGLGraphs.
feat (torch.Tensor) – The input feature with shape \((N, *)\), where \(N\) is the number of nodes in the graph.
- Returns
The output feature with shape \((B, *)\), where \(B\) refers to the batch size.
- Return type
torch.Tensor
-
SortPooling¶
-
class
dgl.nn.pytorch.glob.
SortPooling
(k)[source]¶ Bases:
torch.nn.modules.module.Module
Apply Sort Pooling (An End-to-End Deep Learning Architecture for Graph Classification) over the nodes in a graph.
- Parameters
k (int) – The number of nodes to hold for each graph.
Notes
Input: Could be one graph, or a batch of graphs. If using a batch of graphs, make sure nodes in all graphs have the same feature size, and concatenate the node features together as the input.
Examples
>>> import dgl >>> import torch as th >>> from dgl.nn.pytorch.glob import SortPooling >>> >>> g1 = dgl.DGLGraph() >>> g1.add_nodes(2) >>> g1_node_feats = th.ones(2,5) >>> >>> g2 = dgl.DGLGraph() >>> g2.add_nodes(3) >>> g2_node_feats = th.ones(3,5) >>> >>> sortpool = SortPooling(k=2)
Case 1: Input a single graph
>>> sortpool(g1, g1_node_feats) tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
Case 2: Input a batch of graphs
Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.
>>> batch_g = dgl.batch([g1, g2]) >>> batch_f = th.cat([g1_node_feats, g2_node_feats]) >>> >>> sortpool(batch_g, batch_f) tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
-
forward
(graph, feat)[source]¶ Compute sort pooling.
- Parameters
graph (DGLGraph) – A DGLGraph or a batch of DGLGraphs.
feat (torch.Tensor) – The input feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph, and \(D\) means the size of features.
- Returns
The output feature with shape \((B, k * D)\), where \(B\) refers to the batch size of input graphs.
- Return type
torch.Tensor
GlobalAttentionPooling¶
-
class
dgl.nn.pytorch.glob.
GlobalAttentionPooling
(gate_nn, feat_nn=None)[source]¶ Bases:
torch.nn.modules.module.Module
Apply Global Attention Pooling (Gated Graph Sequence Neural Networks) over the nodes in a graph.
\[r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate} \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)\]- Parameters
gate_nn (torch.nn.Module) – A neural network that computes attention scores for each feature.
feat_nn (torch.nn.Module, optional) – A neural network applied to each feature before combining them with attention scores.
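A minimal usage sketch (the graph, feature size, and gate network below are illustrative):

>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> from dgl.nn.pytorch.glob import GlobalAttentionPooling
>>>
>>> g = dgl.graph(([0, 1], [1, 0]))
>>> feat = th.ones(2, 5)
>>> gate_nn = nn.Linear(5, 1)        # produces one attention score per node
>>> gap = GlobalAttentionPooling(gate_nn)
>>> gap(g, feat).shape               # one pooled vector per graph
torch.Size([1, 5])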
-
forward
(graph, feat)[source]¶ Compute global attention pooling.
- Parameters
graph (DGLGraph) – A DGLGraph or a batch of DGLGraphs.
feat (torch.Tensor) – The input feature with shape \((N, D)\) where \(N\) is the number of nodes in the graph, and \(D\) means the size of features.
- Returns
The output feature with shape \((B, D)\), where \(B\) refers to the batch size.
- Return type
torch.Tensor
Set2Set¶
-
class
dgl.nn.pytorch.glob.
Set2Set
(input_dim, n_iters, n_layers)[source]¶ Bases:
torch.nn.modules.module.Module
For each individual graph in the batch, set2set computes
\[\begin{aligned}
q_t &= \mathrm{LSTM} (q^*_{t-1})\\
\alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t)\\
r_t &= \sum_{i=1}^N \alpha_{i,t} x_i\\
q^*_t &= q_t \Vert r_t
\end{aligned}\]
for this graph.
- Parameters
input_dim (int) – The size of each input sample.
n_iters (int) – The number of iterations.
n_layers (int) – The number of recurrent layers.
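A minimal usage sketch (the graph and sizes below are illustrative):

>>> import dgl
>>> import torch as th
>>> from dgl.nn.pytorch.glob import Set2Set
>>>
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
>>> feat = th.ones(3, 5)
>>> s2s = Set2Set(input_dim=5, n_iters=2, n_layers=1)
>>> s2s(g, feat).shape   # q_t is concatenated with r_t, so the feature size doubles
torch.Size([1, 10])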
-
forward
(graph, feat)[source]¶ Compute set2set pooling.
- Parameters
graph (DGLGraph) – The input graph.
feat (torch.Tensor) – The input feature with shape \((N, D)\) where \(N\) is the number of nodes in the graph, and \(D\) means the size of features.
- Returns
The output feature with shape \((B, D)\), where \(B\) refers to the batch size, and \(D\) means the size of features.
- Return type
torch.Tensor
SetTransformerEncoder¶
-
class
dgl.nn.pytorch.glob.
SetTransformerEncoder
(d_model, n_heads, d_head, d_ff, n_layers=1, block_type='sab', m=None, dropouth=0.0, dropouta=0.0)[source]¶ Bases:
torch.nn.modules.module.Module
The Encoder module in Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks.
- Parameters
d_model (int) – The hidden size of the model.
n_heads (int) – The number of heads.
d_head (int) – The hidden size of each head.
d_ff (int) – The kernel size in FFN (Positionwise Feed-Forward Network) layer.
n_layers (int) – The number of layers.
block_type (str) – Building block type: ‘sab’ (Set Attention Block) or ‘isab’ (Induced Set Attention Block).
m (int or None) – The number of induced vectors in ISAB Block. Set to None if block type is ‘sab’.
dropouth (float) – The dropout rate of each sublayer.
dropouta (float) – The dropout rate of attention heads.
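A minimal usage sketch (the graph and hyper-parameters below are illustrative; the node feature size is assumed to match d_model):

>>> import dgl
>>> import torch as th
>>> from dgl.nn.pytorch.glob import SetTransformerEncoder
>>>
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
>>> feat = th.ones(3, 16)
>>> enc = SetTransformerEncoder(d_model=16, n_heads=4, d_head=8, d_ff=32)
>>> enc(g, feat).shape   # per-node output with the same shape as the input
torch.Size([3, 16])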
-
forward
(graph, feat)[source]¶ Compute the Encoder part of Set Transformer.
- Parameters
graph (DGLGraph) – The input graph.
feat (torch.Tensor) – The input feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph.
- Returns
The output feature with shape \((N, D)\).
- Return type
torch.Tensor
SetTransformerDecoder¶
-
class
dgl.nn.pytorch.glob.
SetTransformerDecoder
(d_model, num_heads, d_head, d_ff, n_layers, k, dropouth=0.0, dropouta=0.0)[source]¶ Bases:
torch.nn.modules.module.Module
The Decoder module in Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks.
- Parameters
d_model (int) – Hidden size of the model.
num_heads (int) – The number of heads.
d_head (int) – Hidden size of each head.
d_ff (int) – Kernel size in FFN (Positionwise Feed-Forward Network) layer.
n_layers (int) – The number of layers.
k (int) – The number of seed vectors in PMA (Pooling by Multihead Attention) layer.
dropouth (float) – Dropout rate of each sublayer.
dropouta (float) – Dropout rate of attention heads.
-
forward
(graph, feat)[source]¶ Compute the decoder part of Set Transformer.
- Parameters
graph (DGLGraph) – The input graph.
feat (torch.Tensor) – The input feature with shape \((N, D)\), where \(N\) is the number of nodes in the graph, and \(D\) means the size of features.
- Returns
The output feature with shape \((B, D)\), where \(B\) refers to the batch size.
- Return type
torch.Tensor
Heterogeneous Graph Convolution Module¶
HeteroGraphConv¶
-
class
dgl.nn.pytorch.
HeteroGraphConv
(mods, aggregate='sum')[source]¶ Bases:
torch.nn.modules.module.Module
A generic module for computing convolution on heterogeneous graphs.
The heterograph convolution applies sub-modules to their associated relation graphs, which read features from the source nodes and write the updated features to the destination nodes. If multiple relations have the same destination node types, their results are aggregated by the specified method.
If the relation graph has no edge, the corresponding module will not be called.
Examples
Create a heterograph with three types of relations and nodes.
>>> import dgl >>> g = dgl.heterograph({ ... ('user', 'follows', 'user') : edges1, ... ('user', 'plays', 'game') : edges2, ... ('store', 'sells', 'game') : edges3})
Create a
HeteroGraphConv
that applies different convolution modules to different relations. Note that the modules for'follows'
and'plays'
do not share weights.>>> import dgl.nn.pytorch as dglnn >>> conv = dglnn.HeteroGraphConv({ ... 'follows' : dglnn.GraphConv(...), ... 'plays' : dglnn.GraphConv(...), ... 'sells' : dglnn.SAGEConv(...)}, ... aggregate='sum')
Call forward with some
'user'
features. This computes new features for both'user'
and'game'
nodes.>>> import torch as th >>> h1 = {'user' : th.randn((g.number_of_nodes('user'), 5))} >>> h2 = conv(g, h1) >>> print(h2.keys()) dict_keys(['user', 'game'])
Call forward with both
'user'
and'store'
features. Because both the'plays'
and'sells'
relations will update the'game'
features, their results are aggregated by the specified method (i.e., summation here).>>> f1 = {'user' : ..., 'store' : ...} >>> f2 = conv(g, f1) >>> print(f2.keys()) dict_keys(['user', 'game'])
Call forward with some
'store'
features. This only computes new features for'game'
nodes.>>> g1 = {'store' : ...} >>> g2 = conv(g, g1) >>> print(g2.keys()) dict_keys(['game'])
Calling forward with a pair of inputs is also allowed, and each submodule will be invoked with a pair of inputs.
>>> x_src = {'user' : ..., 'store' : ...} >>> x_dst = {'user' : ..., 'game' : ...} >>> y_dst = conv(g, (x_src, x_dst)) >>> print(y_dst.keys()) dict_keys(['user', 'game'])
- Parameters
mods (dict[str, nn.Module]) – Modules associated with every edge type. The forward function of each module must have a DGLHeteroGraph object as the first argument, and its second argument is either a tensor object representing the node features or a pair of tensor objects representing the source and destination node features.
aggregate (str, callable, optional) –
Method for aggregating node features generated by different relations. Allowed string values are ‘sum’, ‘max’, ‘min’, ‘mean’, ‘stack’. The ‘stack’ aggregation is performed along the second dimension, whose order is deterministic. Users can also customize the aggregator by providing a callable instance. For example, aggregation by summation is equivalent to the following:
def my_agg_func(tensors, dsttype):
    # tensors: a list of tensors to aggregate
    # dsttype: string name of the destination node type
    #          for which the aggregation is performed
    stacked = torch.stack(tensors, dim=0)
    return torch.sum(stacked, dim=0)
-
forward
(g, inputs, mod_args=None, mod_kwargs=None)[source]¶ Forward computation
Invoke the forward function with each module and aggregate their results.
- Parameters
g (DGLHeteroGraph) – Graph data.
inputs (dict[str, Tensor] or pair of dict[str, Tensor]) – Input node features.
mod_args (dict[str, tuple[any]], optional) – Extra positional arguments for the sub-modules.
mod_kwargs (dict[str, dict[str, any]], optional) – Extra keyword arguments for the sub-modules (see the sketch below).
- Returns
Output representations for every type of node.
- Return type
dict[str, Tensor]
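A sketch of routing extra keyword arguments to a specific sub-module, reusing the heterograph g from the examples above. It assumes, as in recent DGL releases, that GraphConv.forward accepts an optional external weight tensor; the sizes are illustrative:

>>> import torch as th
>>> import dgl.nn.pytorch as dglnn
>>>
>>> conv = dglnn.HeteroGraphConv({
...     'follows' : dglnn.GraphConv(5, 5, weight=False),
...     'plays' : dglnn.GraphConv(5, 5),
...     'sells' : dglnn.SAGEConv(5, 5, 'mean')},
...     aggregate='sum')
>>> ext_weight = th.randn(5, 5)
>>> h = {'user' : th.randn((g.number_of_nodes('user'), 5))}
>>> # the external weight is forwarded only to the 'follows' sub-module
>>> out = conv(g, h, mod_kwargs={'follows' : {'weight' : ext_weight}})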
Utility Modules¶
Sequential¶
-
class
dgl.nn.pytorch.utils.
Sequential
(*args)[source]¶ Bases:
torch.nn.modules.container.Sequential
A sequential container for stacking graph neural network modules.
DGL supports two modes: sequentially applying GNN modules on 1) the same graph or 2) a list of given graphs. In the second case, the number of graphs must equal the number of modules inside this container.
- Parameters
*args – Sub-modules of torch.nn.Module that will be added to the container in the order in which they are passed to the constructor.
Examples
The following example uses PyTorch backend.
Mode 1: sequentially apply GNN modules on the same graph
>>> import torch >>> import dgl >>> import torch.nn as nn >>> import dgl.function as fn >>> from dgl.nn.pytorch import Sequential >>> class ExampleLayer(nn.Module): >>> def __init__(self): >>> super().__init__() >>> def forward(self, graph, n_feat, e_feat): >>> with graph.local_scope(): >>> graph.ndata['h'] = n_feat >>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) >>> n_feat += graph.ndata['h'] >>> graph.apply_edges(fn.u_add_v('h', 'h', 'e')) >>> e_feat += graph.edata['e'] >>> return n_feat, e_feat >>> >>> g = dgl.DGLGraph() >>> g.add_nodes(3) >>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2]) >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer()) >>> n_feat = torch.rand(3, 4) >>> e_feat = torch.rand(9, 4) >>> net(g, n_feat, e_feat) (tensor([[39.8597, 45.4542, 25.1877, 30.8086], [40.7095, 45.3985, 25.4590, 30.0134], [40.7894, 45.2556, 25.5221, 30.4220]]), tensor([[80.3772, 89.7752, 50.7762, 60.5520], [80.5671, 89.3736, 50.6558, 60.6418], [80.4620, 89.5142, 50.3643, 60.3126], [80.4817, 89.8549, 50.9430, 59.9108], [80.2284, 89.6954, 50.0448, 60.1139], [79.7846, 89.6882, 50.5097, 60.6213], [80.2654, 90.2330, 50.2787, 60.6937], [80.3468, 90.0341, 50.2062, 60.2659], [80.0556, 90.2789, 50.2882, 60.5845]]))
Mode 2: sequentially apply GNN modules on different graphs
>>> import torch >>> import dgl >>> import torch.nn as nn >>> import dgl.function as fn >>> import networkx as nx >>> from dgl.nn.pytorch import Sequential >>> class ExampleLayer(nn.Module): >>> def __init__(self): >>> super().__init__() >>> def forward(self, graph, n_feat): >>> with graph.local_scope(): >>> graph.ndata['h'] = n_feat >>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) >>> n_feat += graph.ndata['h'] >>> return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1) >>> >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)) >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)) >>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)) >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer()) >>> n_feat = torch.rand(32, 4) >>> net([g1, g2, g3], n_feat) tensor([[209.6221, 225.5312, 193.8920, 220.1002], [250.0169, 271.9156, 240.2467, 267.7766], [220.4007, 239.7365, 213.8648, 234.9637], [196.4630, 207.6319, 184.2927, 208.7465]])
KNNGraph¶
-
class
dgl.nn.pytorch.factory.
KNNGraph
(k)[source]¶ Bases:
torch.nn.modules.module.Module
Layer that transforms one point set into a graph, or a batch of point sets with the same number of points into a union of those graphs.
The KNNGraph is implemented in the following steps:
Compute an \(N \times N\) matrix of pairwise distances for all points.
Pick the k points with the smallest distance for each point as their k-nearest neighbors.
Construct a graph in which each point is a node with incoming edges from its k-nearest neighbors.
The overall computational complexity is \(O(N^2(\log N + D))\).
If a batch of point sets is provided, the point \(j\) in point set \(i\) is mapped to graph node ID: \(i \times M + j\), where \(M\) is the number of nodes in each point set.
The predecessors of each node are the k-nearest neighbors of the corresponding point.
- Parameters
k (int) – The number of neighbors.
Notes
The nearest neighbors found for a node include the node itself.
Examples
The following example uses PyTorch backend.
>>> import torch >>> from dgl.nn.pytorch.factory import KNNGraph >>> >>> kg = KNNGraph(2) >>> x = torch.tensor([[0,1], [1,2], [1,3], [100, 101], [101, 102], [50, 50]]) >>> g = kg(x) >>> print(g.edges()) (tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]), tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
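A sketch of the batched case, assuming (as the node-ID mapping above suggests) that a 3-D tensor holding a batch of point sets of equal size is accepted:

>>> import torch
>>> from dgl.nn.pytorch.factory import KNNGraph
>>>
>>> # two point sets, each with 3 points in 2-D
>>> x = torch.tensor([[[0., 1.], [1., 2.], [1., 3.]],
...                   [[100., 101.], [101., 102.], [50., 50.]]])
>>> kg = KNNGraph(2)
>>> bg = kg(x)
>>> bg.number_of_nodes()   # point j of set i becomes node i * 3 + j
6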
SegmentedKNNGraph¶
-
class
dgl.nn.pytorch.factory.
SegmentedKNNGraph
(k)[source]¶ Bases:
torch.nn.modules.module.Module
Layer that transforms one point set into a graph, or a batch of point sets with different numbers of points into a union of those graphs.
If a batch of point sets is provided, then the point \(j\) in the point set \(i\) is mapped to graph node ID: \(\sum_{p<i} |V_p| + j\), where \(|V_p|\) means the number of points in the point set \(p\).
The predecessors of each node are the k-nearest neighbors of the corresponding point.
- Parameters
k (int) – The number of neighbors.
Notes
The nearest neighbors found for a node include the node itself.
Examples
The following example uses PyTorch backend.
>>> import torch >>> from dgl.nn.pytorch.factory import SegmentedKNNGraph >>> >>> kg = SegmentedKNNGraph(2) >>> x = torch.tensor([[0,1], ... [1,2], ... [1,3], ... [100, 101], ... [101, 102], ... [50, 50], ... [24,25], ... [25,24]]) >>> g = kg(x, [3,3,2]) >>> print(g.edges()) (tensor([0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7]), tensor([0, 0, 1, 2, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 6, 7])) >>>
-
forward
(x, segs)[source]¶ Forward computation.
- Parameters
x (Tensor) – \((M, D)\) where \(M\) means the total number of points in all point sets, and \(D\) means the size of features.
segs (iterable of int) – \((N)\) integers, where \(N\) is the number of point sets. The elements must sum up to \(M\), and every element must be at least \(k\).
- Returns
A DGLGraph without features.
- Return type
DGLGraph