EGNNConv
- class dgl.nn.pytorch.conv.EGNNConv(in_size, hidden_size, out_size, edge_feat_size=0)
Bases: torch.nn.Module
Equivariant Graph Convolutional Layer from "E(n) Equivariant Graph Neural Networks".
\[ \begin{aligned}
m_{ij} &= \phi_e(h_i^l, h_j^l, \|x_i^l - x_j^l\|^2, a_{ij}) \\
x_i^{l+1} &= x_i^l + C\sum_{j\in\mathcal{N}(i)}(x_i^l - x_j^l)\,\phi_x(m_{ij}) \\
m_i &= \sum_{j\in\mathcal{N}(i)} m_{ij} \\
h_i^{l+1} &= \phi_h(h_i^l, m_i)
\end{aligned} \]
where \(h_i\), \(x_i\), and \(a_{ij}\) are the node features, coordinate features, and edge features, respectively; \(\mathcal{N}(i)\) is the set of neighbors of node \(i\); \(\phi_e\), \(\phi_x\), and \(\phi_h\) are two-layer MLPs; and \(C = 1/|\mathcal{N}(i)|\) is a normalization constant.
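To make the data flow of these equations concrete, here is a minimal pure-Python sketch of one layer that follows them term by term. It is not DGL's implementation. The function `egnn_layer`, its `(j, i)` edge convention (a message flows from `j` to `i`), and the toy `phi_e`, `phi_x`, `phi_h` stand-ins (fixed functions, not trained two-layer MLPs) are all hypothetical, and using \(C = 0\) for a node with no neighbors is an assumption the formula does not cover.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Vec = List[float]


def egnn_layer(
    h: Sequence[Vec],
    x: Sequence[Vec],
    edges: Sequence[Tuple[int, int]],
    phi_e: Callable[[Vec], Vec],
    phi_x: Callable[[Vec], float],
    phi_h: Callable[[Vec], Vec],
    msg_size: int,
    a: Optional[Dict[Tuple[int, int], Vec]] = None,
) -> Tuple[List[Vec], List[Vec]]:
    """One EGNN update; each edge (j, i) sends a message from node j to node i."""
    n, dim = len(h), len(x[0])
    m_agg = [[0.0] * msg_size for _ in range(n)]  # m_i = sum_j m_ij
    x_agg = [[0.0] * dim for _ in range(n)]  # sum_j (x_i - x_j) * phi_x(m_ij)
    degree = [0] * n  # |N(i)|
    for j, i in edges:
        diff = [xi - xj for xi, xj in zip(x[i], x[j])]  # x_i - x_j
        dist2 = sum(d * d for d in diff)  # ||x_i - x_j||^2
        edge_in = list(h[i]) + list(h[j]) + [dist2] + (list(a[(j, i)]) if a else [])
        m_ij = phi_e(edge_in)
        w = phi_x(m_ij)  # one scalar weight per edge
        for k in range(msg_size):
            m_agg[i][k] += m_ij[k]
        for k in range(dim):
            x_agg[i][k] += diff[k] * w
        degree[i] += 1
    h_new, x_new = [], []
    for i in range(n):
        c = 1.0 / degree[i] if degree[i] else 0.0  # assumption: C = 0 if N(i) is empty
        x_new.append([x[i][k] + c * x_agg[i][k] for k in range(dim)])
        h_new.append(phi_h(list(h[i]) + m_agg[i]))
    return h_new, x_new


# Toy stand-ins for the learned MLPs (hypothetical fixed functions, not trained weights).
def phi_e(v: Vec) -> Vec:
    s = sum(v)
    return [math.tanh(s), math.tanh(0.5 * s)]


def phi_x(m: Vec) -> float:
    return 0.1 * sum(m)


def phi_h(v: Vec) -> Vec:
    return [math.tanh(sum(v))]


def rot90(p: Vec) -> Vec:
    # rotate a 2-D point by 90 degrees counter-clockwise
    return [-p[1], p[0]]


h = [[1.0], [0.5], [-0.2]]
x = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
edges = [(1, 0), (2, 0), (0, 1), (0, 2)]
shift = [3.0, -1.0]

h1, x1 = egnn_layer(h, x, edges, phi_e, phi_x, phi_h, msg_size=2)
# translate every coordinate: h should not change, x should move by `shift`
h2, x2 = egnn_layer(h, [[p + s for p, s in zip(pt, shift)] for pt in x],
                    edges, phi_e, phi_x, phi_h, msg_size=2)
# rotate every coordinate: h should not change, x should rotate the same way
h3, x3 = egnn_layer(h, [rot90(pt) for pt in x], edges, phi_e, phi_x, phi_h, msg_size=2)
```

Because \(\phi_e\) sees the coordinates only through the squared distance, translating or rotating all inputs leaves \(h^{l+1}\) unchanged and transforms \(x^{l+1}\) in the same way; the last three calls exercise exactly that property.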
- Parameters:
in_size (int) – Input feature size, i.e., the size of \(h_i^l\).
hidden_size (int) – Hidden feature size, i.e., the width of the hidden layer in each of the two-layer MLPs \(\phi_e\), \(\phi_x\), and \(\phi_h\).
out_size (int) – Output feature size, i.e., the size of \(h_i^{l+1}\).
edge_feat_size (int, optional) – Edge feature size, i.e., the size of \(a_{ij}\). Default: 0.
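The constructor sizes fix the widths of the three MLPs. The sketch below derives them from the update equations, assuming that \(\phi_e\) produces messages \(m_{ij}\) of width hidden_size (the message width is not stated here). \(\phi_x\) has to output a single scalar, because it multiplies \(x_i^l - x_j^l\) for any coordinate dimension \(h_x\). The helper `egnn_mlp_shapes` is hypothetical and not part of DGL.

```python
from typing import Dict, Tuple


def egnn_mlp_shapes(in_size: int, hidden_size: int, out_size: int,
                    edge_feat_size: int = 0) -> Dict[str, Tuple[int, int, int]]:
    """(input width, hidden width, output width) of phi_e, phi_x and phi_h."""
    msg_size = hidden_size  # assumption: width of m_ij and hence of m_i
    return {
        # input is [h_i, h_j, ||x_i - x_j||^2, a_ij]
        "phi_e": (2 * in_size + 1 + edge_feat_size, hidden_size, msg_size),
        # output is one scalar that scales x_i - x_j
        "phi_x": (msg_size, hidden_size, 1),
        # input is [h_i, m_i], output is h_i^{l+1}
        "phi_h": (in_size + msg_size, hidden_size, out_size),
    }


# sizes from the class example: EGNNConv(10, 10, 10, 2)
shapes = egnn_mlp_shapes(10, 10, 10, 2)
```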
Example
>>> import dgl
>>> import torch as th
>>> from dgl.nn import EGNNConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> node_feat, coord_feat, edge_feat = th.ones(6, 10), th.ones(6, 3), th.ones(6, 2)
>>> conv = EGNNConv(10, 10, 10, 2)
>>> h, x = conv(g, node_feat, coord_feat, edge_feat)
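For the example graph, the normalization constant \(C = 1/|\mathcal{N}(i)|\) can be worked out by hand. This sketch assumes that \(\mathcal{N}(i)\) is the set of source nodes of edges pointing into \(i\), following DGL's convention of passing messages along edge direction; `in_neighbours` is a hypothetical helper. Node 5 has no incoming edges, so the formula leaves its \(C\) undefined and the sketch omits it.

```python
from typing import Dict, List

# edge list of the example graph: edge k goes from src[k] to dst[k]
src = [0, 1, 2, 3, 2, 5]
dst = [1, 2, 3, 4, 0, 3]
num_nodes = 6


def in_neighbours(src: List[int], dst: List[int], num_nodes: int) -> Dict[int, List[int]]:
    """N(i): source nodes of all edges whose destination is i (assumed convention)."""
    return {i: [s for s, d in zip(src, dst) if d == i] for i in range(num_nodes)}


N = in_neighbours(src, dst, num_nodes)
# C_i = 1 / |N(i)|, skipping nodes with no in-neighbours (C is undefined there)
C = {i: 1.0 / len(nbrs) for i, nbrs in N.items() if nbrs}
```

Under this assumption, node 3 averages two coordinate messages (from nodes 2 and 5), while every other node with an incoming edge has exactly one neighbor.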
- forward(graph, node_feat, coord_feat, edge_feat=None)
Description
Compute the output of the EGNN layer for the given graph and features.
- param graph:
The graph.
- type graph:
DGLGraph
- param node_feat:
The input feature of shape \((N, h_n)\). \(N\) is the number of nodes, and \(h_n\) must be the same as in_size.
- type node_feat:
torch.Tensor
- param coord_feat:
The coordinate feature of shape \((N, h_x)\). \(N\) is the number of nodes, and \(h_x\) can be any positive integer.
- type coord_feat:
torch.Tensor
- param edge_feat:
The edge feature of shape \((M, h_e)\). \(M\) is the number of edges, and \(h_e\) must be the same as edge_feat_size.
- type edge_feat:
torch.Tensor, optional
- returns:
node_feat_out (torch.Tensor) – The output node feature of shape \((N, h_n')\), where \(h_n'\) is the same as out_size.
coord_feat_out (torch.Tensor) – The output coordinate feature of shape \((N, h_x)\), where \(h_x\) is the same as the input coordinate feature dimension.
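The input and output shape rules of forward can be summarized as a small check. The helper `egnn_forward_shapes` below is hypothetical (it is not a DGL function) and works on shape tuples rather than tensors. It ignores edge_feat when edge_feat_size is 0, which is an assumption about that case rather than documented behavior.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def egnn_forward_shapes(in_size: int, out_size: int, edge_feat_size: int,
                        num_edges: int, node_feat: Shape, coord_feat: Shape,
                        edge_feat: Optional[Shape] = None) -> Tuple[Shape, Shape]:
    """Check input shapes against the documented contract and return output shapes."""
    n, h_n = node_feat
    if h_n != in_size:
        raise ValueError(f"node_feat width {h_n} != in_size {in_size}")
    n_x, h_x = coord_feat
    if n_x != n or h_x < 1:
        raise ValueError("coord_feat must have shape (N, h_x) with h_x >= 1")
    # assumption: edge features are only checked when the layer expects them
    if edge_feat_size > 0 and edge_feat != (num_edges, edge_feat_size):
        raise ValueError(f"edge_feat must have shape ({num_edges}, {edge_feat_size})")
    # node_feat_out is (N, out_size); coord_feat_out keeps the input (N, h_x)
    return (n, out_size), (n, h_x)


# shapes from the class example: 6 nodes, 6 edges, EGNNConv(10, 10, 10, 2)
h_shape, x_shape = egnn_forward_shapes(10, 10, 2, num_edges=6, node_feat=(6, 10),
                                       coord_feat=(6, 3), edge_feat=(6, 2))

try:
    egnn_forward_shapes(10, 10, 0, num_edges=6, node_feat=(6, 8), coord_feat=(6, 3))
    mismatch_detected = False
except ValueError:
    mismatch_detected = True  # node_feat width 8 does not match in_size 10
```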