GatedGraphConv

class dgl.nn.pytorch.conv.GatedGraphConv(in_feats, out_feats, n_steps, n_etypes, bias=True)[source]

Bases: torch.nn.modules.module.Module

Gated Graph Convolution layer from Gated Graph Sequence Neural Networks

\[
\begin{aligned}
h_{i}^{(0)} &= [\, x_i \,\|\, \mathbf{0} \,]\\
a_{i}^{(t)} &= \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{(t)}\\
h_{i}^{(t+1)} &= \mathrm{GRU}\left(a_{i}^{(t)}, h_{i}^{(t)}\right)
\end{aligned}
\]
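The recurrence reads as follows: the input feature is zero-padded to the output size, each step aggregates neighbor messages through an edge-type-specific weight matrix, and a GRU cell updates the node state. Below is a minimal sketch of this rule in plain PyTorch, assuming a single edge type and a dense adjacency matrix; the class name GatedGraphSketch and the dense formulation are illustrative only, not DGL's implementation.

import torch
import torch.nn as nn

class GatedGraphSketch(nn.Module):
    """Illustrative dense restatement of the recurrence above (single edge type)."""
    def __init__(self, in_feats, out_feats, n_steps):
        super().__init__()
        # h_i^(0) = [x_i || 0] requires in_feats <= out_feats
        assert in_feats <= out_feats
        self.out_feats = out_feats
        self.n_steps = n_steps
        self.weight = nn.Linear(out_feats, out_feats, bias=False)  # W_e for the single edge type
        self.gru = nn.GRUCell(out_feats, out_feats)

    def forward(self, adj, x):
        # h_i^(0) = [x_i || 0]: pad the input features with zeros up to out_feats
        h = torch.cat([x, x.new_zeros(x.size(0), self.out_feats - x.size(1))], dim=1)
        for _ in range(self.n_steps):
            a = adj @ self.weight(h)  # a_i^(t) = sum_{j in N(i)} W h_j^(t)
            h = self.gru(a, h)        # h_i^(t+1) = GRU(a_i^(t), h_i^(t))
        return h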
Parameters
  • in_feats (int) – Input feature size; i.e., the number of dimensions of \(x_i\).

  • out_feats (int) – Output feature size; i.e., the number of dimensions of \(h_i^{(t+1)}\).

  • n_steps (int) – Number of recurrent steps; i.e., the \(t\) in the above formula.

  • n_etypes (int) – Number of edge types.

  • bias (bool) – If True, adds a learnable bias to the output. Default: True.

Example

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import GatedGraphConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = GatedGraphConv(10, 10, 2, 3)
>>> etype = th.tensor([0,1,2,0,1,2])
>>> res = conv(g, feat, etype)
>>> res
tensor([[ 0.4652,  0.4458,  0.5169,  0.4126,  0.4847,  0.2303,  0.2757,  0.7721,
        0.0523,  0.0857],
        [ 0.0832,  0.1388, -0.5643,  0.7053, -0.2524, -0.3847,  0.7587,  0.8245,
        0.9315,  0.4063],
        [ 0.6340,  0.4096,  0.7692,  0.2125,  0.2106,  0.4542, -0.0580,  0.3364,
        -0.1376,  0.4948],
        [ 0.5551,  0.7946,  0.6220,  0.8058,  0.5711,  0.3063, -0.5454,  0.2272,
        -0.6931, -0.1607],
        [ 0.2644,  0.2469, -0.6143,  0.6008, -0.1516, -0.3781,  0.5878,  0.7993,
        0.9241,  0.1835],
        [ 0.6393,  0.3447,  0.3893,  0.4279,  0.3342,  0.3809,  0.0406,  0.5030,
        0.1342,  0.0425]], grad_fn=<AddBackward0>)
forward(graph, feat, etypes=None)[source]

Compute the output of the Gated Graph Convolution layer.

Parameters
  • graph (DGLGraph) – The graph.

  • feat (torch.Tensor) – The input feature of shape \((N, D_{in})\) where \(N\) is the number of nodes of the graph and \(D_{in}\) is the input feature size.

  • etypes (torch.LongTensor or None) – The edge type tensor of shape \((E,)\) where \(E\) is the number of edges of the graph. When there is only one edge type, this argument can be omitted.

Returns

The output feature of shape \((N, D_{out})\) where \(D_{out}\) is the output feature size.

Return type

torch.Tensor
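
For instance, when the layer is constructed with a single edge type, etypes can be omitted. This is a minimal sketch reusing g and feat from the example above; the output values are not shown because the weights are randomly initialized, so only the shape is checked.

>>> conv1 = GatedGraphConv(10, 10, n_steps=2, n_etypes=1)
>>> res = conv1(g, feat)
>>> res.shape
torch.Size([6, 10])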

reset_parameters()[source]

Reinitialize learnable parameters.

Note

The model parameters are initialized using Glorot uniform initialization and the bias is initialized to zero.
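
A minimal usage sketch, reusing the conv module from the example above: calling reset_parameters() re-draws the weights with Glorot uniform initialization and zeroes the bias in place.

>>> conv.reset_parameters()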