GatedGraphConv
class dgl.nn.pytorch.conv.GatedGraphConv(in_feats, out_feats, n_steps, n_etypes, bias=True)
Bases: torch.nn.modules.module.Module
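Gated Graph Convolution layer from Gated Graph Sequence Neural Networks.
\[
\begin{aligned}
h_{i}^{0} &= [\, x_i \,\|\, \mathbf{0} \,] \\
a_{i}^{t} &= \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t} \\
h_{i}^{t+1} &= \mathrm{GRU}(a_{i}^{t}, h_{i}^{t})
\end{aligned}
\]
Here \(\mathcal{N}(i)\) is the set of nodes with an edge into node \(i\), \(e_{ij}\) is the type of the edge from \(j\) to \(i\), and \(\mathbf{0}\) is a zero vector that pads \(x_i\) to out_feats dimensions.
Parameters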
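in_feats (int) – Input feature size, i.e., the number of dimensions of \(x_i\). It must not exceed out_feats, because \(h_i^{0}\) is \(x_i\) zero-padded to out_feats dimensions.
out_feats (int) – Output feature size, i.e., the number of dimensions of \(h_i^{t}\).
n_steps (int) – Number of recurrent steps, i.e., how many times the update from \(h_i^{t}\) to \(h_i^{t+1}\) is applied.
n_etypes (int) – Number of edge types; each edge type \(e\) has its own weight matrix \(W_e\).
bias (bool) – If True, adds a learnable bias to the output. Default: True.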
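To make the update rule concrete, here is a minimal from-scratch sketch in plain PyTorch. It is not DGL's implementation: the class name GatedGraphPropagation is hypothetical, edges are passed as COO index tensors src and dst instead of a DGLGraph, and with bias=True each nn.Linear also adds a per-edge-type bias to the messages, which the formula does not show.

```python
import torch
import torch.nn as nn


class GatedGraphPropagation(nn.Module):
    """Hypothetical sketch of the gated graph update rule (not DGL's code)."""

    def __init__(self, in_feats, out_feats, n_steps, n_etypes, bias=True):
        super().__init__()
        # h^0 = [x || 0] only works if x fits inside out_feats columns.
        assert in_feats <= out_feats
        self.out_feats = out_feats
        self.n_steps = n_steps
        # One weight matrix W_e per edge type e.
        self.linears = nn.ModuleList(
            [nn.Linear(out_feats, out_feats, bias=bias) for _ in range(n_etypes)]
        )
        self.gru = nn.GRUCell(out_feats, out_feats, bias=bias)

    def forward(self, src, dst, etypes, x):
        n = x.shape[0]
        # h^0 = [x || 0]: zero-pad each input row up to out_feats columns.
        h = torch.cat([x, x.new_zeros(n, self.out_feats - x.shape[1])], dim=1)
        for _ in range(self.n_steps):
            # Message on edge j -> i is W_{e_ij} h_j, computed per edge type.
            msgs = x.new_zeros(src.shape[0], self.out_feats)
            for e, linear in enumerate(self.linears):
                mask = etypes == e
                msgs[mask] = linear(h[src[mask]])
            # a_i^t: sum of incoming messages, scattered onto destination nodes.
            a = x.new_zeros(n, self.out_feats).index_add(0, dst, msgs)
            # h_i^{t+1} = GRU(a_i^t, h_i^t)
            h = self.gru(a, h)
        return h


# Same graph, features, and edge types as the DGL example in this section.
src = torch.tensor([0, 1, 2, 3, 2, 5])
dst = torch.tensor([1, 2, 3, 4, 0, 3])
etypes = torch.tensor([0, 1, 2, 0, 1, 2])
layer = GatedGraphPropagation(10, 10, n_steps=2, n_etypes=3)
out = layer(src, dst, etypes, torch.ones(6, 10))
print(out.shape)  # torch.Size([6, 10])
```

Summing with index_add over destination indices is what turns per-edge messages into the per-node aggregate \(a_i^t\); a node with no incoming edges (node 5 here) simply receives a zero vector before the GRU step.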
Example
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import GatedGraphConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = GatedGraphConv(10, 10, 2, 3)
>>> etype = th.tensor([0,1,2,0,1,2])
>>> res = conv(g, feat, etype)
>>> res
tensor([[ 0.4652,  0.4458,  0.5169,  0.4126,  0.4847,  0.2303,  0.2757,  0.7721,
          0.0523,  0.0857],
        [ 0.0832,  0.1388, -0.5643,  0.7053, -0.2524, -0.3847,  0.7587,  0.8245,
          0.9315,  0.4063],
        [ 0.6340,  0.4096,  0.7692,  0.2125,  0.2106,  0.4542, -0.0580,  0.3364,
         -0.1376,  0.4948],
        [ 0.5551,  0.7946,  0.6220,  0.8058,  0.5711,  0.3063, -0.5454,  0.2272,
         -0.6931, -0.1607],
        [ 0.2644,  0.2469, -0.6143,  0.6008, -0.1516, -0.3781,  0.5878,  0.7993,
          0.9241,  0.1835],
        [ 0.6393,  0.3447,  0.3893,  0.4279,  0.3342,  0.3809,  0.0406,  0.5030,
          0.1342,  0.0425]], grad_fn=<AddBackward0>)
forward(graph, feat, etypes=None)
Compute the Gated Graph Convolution layer.
Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The input feature of shape \((N, D_{in})\), where \(N\) is the number of nodes of the graph and \(D_{in}\) is the input feature size.
etypes (torch.LongTensor, or None) – The edge type tensor of shape \((E,)\), where \(E\) is the number of edges of the graph; entry \(k\) is the type ID of edge \(k\), an integer in \([0, \text{n\_etypes})\). When there is only one edge type (n_etypes=1), this argument can be omitted.
Returns
The output feature of shape \((N, D_{out})\), where \(D_{out}\) is the output feature size.
Return type
torch.Tensor