NNConv
- class dgl.nn.pytorch.conv.NNConv(in_feats, out_feats, edge_func, aggregator_type='mean', residual=False, bias=True)
Bases: Module
Graph Convolution layer from Neural Message Passing for Quantum Chemistry
\[h_{i}^{l+1} = h_{i}^{l} + \mathrm{aggregate}\left(\left\{ f_\Theta (e_{ij}) \cdot h_j^{l}, j\in \mathcal{N}(i) \right\}\right)\]
where \(e_{ij}\) is the edge feature and \(f_\Theta\) is a function with learnable parameters.
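The aggregate term in the formula can be sketched in plain NumPy. This is only an illustrative reference, not DGL's implementation: the function name `nnconv_aggregate` is hypothetical, the graph is assumed to be homogeneous and given as source/destination index lists, the output of `edge_func` is assumed to reshape row-major into one `(in_feats, out_feats)` matrix per edge, and the residual term and bias are left out.

```python
import numpy as np

def nnconv_aggregate(src, dst, h, efeat, edge_func, out_feats, aggregator="mean"):
    """Hypothetical reference for aggregate({f_Theta(e_ij) . h_j}); not DGL code."""
    src, dst = np.asarray(src), np.asarray(dst)
    num_nodes, in_feats = h.shape
    # f_Theta(e_ij): one (in_feats x out_feats) weight matrix per edge (assumed layout).
    weights = np.asarray(edge_func(efeat)).reshape(len(src), in_feats, out_feats)
    # Message on edge j -> i: the source feature h_j times that edge's matrix.
    messages = np.einsum("ei,eio->eo", h[src], weights)
    reduce_fn = {"sum": np.sum, "mean": np.mean, "max": np.max}[aggregator]
    out = np.zeros((num_nodes, out_feats))
    for i in range(num_nodes):
        incoming = messages[dst == i]
        if len(incoming) > 0:  # nodes without in-edges keep a zero aggregate here
            out[i] = reduce_fn(incoming, axis=0)
    return out

# Tiny graph with edges 0 -> 1 and 2 -> 1, in_feats = 2, out_feats = 1.
h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
efeat = np.array([[1.0], [2.0]])
toy_edge_func = lambda e: e @ np.array([[1.0, 2.0]])  # stand-in for f_Theta
out = nnconv_aggregate([0, 2], [1, 1], h, efeat, toy_edge_func, out_feats=1)
```

Only node 1 has incoming edges in this toy graph, so it is the only row of `out` that receives a non-zero aggregate.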
- Parameters:
in_feats (int) – Input feature size; i.e., the number of dimensions of \(h_j^{(l)}\). NNConv can be applied to a homogeneous graph or a unidirectional bipartite graph. If the layer is applied to a unidirectional bipartite graph, in_feats specifies the input feature size on both the source and destination nodes. If a scalar is given, the source and destination node feature sizes take the same value.
out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(l+1)}\).
edge_func (callable activation function/layer) – Maps each edge feature to a vector of shape (in_feats * out_feats), which is used as the weight to compute messages. This is the \(f_\Theta\) in the formula.
aggregator_type (str) – Aggregator type to use (sum, mean or max). Default: mean.
residual (bool, optional) – If True, use residual connection. Default: False.
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.
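The shape requirement on edge_func can be checked before building the layer. The sketch below uses NumPy rather than PyTorch so that it runs without DGL. The helper name `edge_weights` is hypothetical, and the row-major reshape into `(in_feats, out_feats)` matrices is an assumption about the layout. The numbers mirror the homogeneous example in this page: 12 edges, 5-dimensional edge features, in_feats = 10, out_feats = 2, so edge_func must return 20 values per edge.

```python
import numpy as np

def edge_weights(edge_func, efeat, in_feats, out_feats):
    """Hypothetical helper: validate edge_func output and view it as per-edge matrices."""
    out = np.asarray(edge_func(efeat))
    expected = (efeat.shape[0], in_feats * out_feats)
    if out.shape != expected:
        raise ValueError(f"edge_func returned shape {out.shape}, expected {expected}")
    # Assumed layout: each flat vector becomes an (in_feats, out_feats) matrix.
    return out.reshape(-1, in_feats, out_feats)

rng = np.random.default_rng(0)
proj = rng.standard_normal((5, 20))   # stand-in for a linear layer mapping 5 -> 20
efeat = np.ones((12, 5))              # 6 original edges + 6 self-loops
w = edge_weights(lambda e: e @ proj, efeat, in_feats=10, out_feats=2)
```

An edge_func whose output width differs from `in_feats * out_feats` makes the helper raise `ValueError`, which is the kind of mismatch that is easier to catch here than inside a forward pass.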
Examples
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import NNConv

>>> # Case 1: Homogeneous graph
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> lin = th.nn.Linear(5, 20)
>>> def edge_func(efeat):
...     return lin(efeat)
>>> efeat = th.ones(6+6, 5)
>>> conv = NNConv(10, 2, edge_func, 'mean')
>>> res = conv(g, feat, efeat)
>>> res
tensor([[-1.5243, -0.2719],
        [-1.5243, -0.2719],
        [-1.5243, -0.2719],
        [-1.5243, -0.2719],
        [-1.5243, -0.2719],
        [-1.5243, -0.2719]], grad_fn=<AddBackward0>)

>>> # Case 2: Unidirectional bipartite graph
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.heterograph({('_N', '_E', '_N'):(u, v)})
>>> u_feat = th.tensor(np.random.rand(2, 10).astype(np.float32))
>>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
>>> conv = NNConv(10, 2, edge_func, 'mean')
>>> efeat = th.ones(5, 5)
>>> res = conv(g, (u_feat, v_feat), efeat)
>>> res
tensor([[-0.6568,  0.5042],
        [ 0.9089, -0.5352],
        [ 0.1261, -0.0155],
        [-0.6568,  0.5042]], grad_fn=<AddBackward0>)
- forward(graph, feat, efeat)
Compute MPNN Graph Convolution layer.
- Parameters:
graph (DGLGraph) – The graph.
feat (torch.Tensor or pair of torch.Tensor) – The input feature of shape \((N, D_{in})\) where \(N\) is the number of nodes of the graph and \(D_{in}\) is the input feature size.
efeat (torch.Tensor) – The edge feature of shape \((E, *)\), which should fit the input shape requirement of edge_func. \(E\) is the number of edges of the graph.
- Returns:
The output feature of shape \((N, D_{out})\) where \(D_{out}\) is the output feature size.
- Return type:
torch.Tensor