TWIRLSConv
- class dgl.nn.pytorch.conv.TWIRLSConv(input_d, output_d, hidden_d, prop_step, num_mlp_before=1, num_mlp_after=1, norm='none', precond=True, alp=0, lam=1, attention=False, tau=0.2, T=-1, p=1, use_eta=False, attn_bef=False, dropout=0.0, attn_dropout=0.0, inp_dropout=0.0)
Bases: Module
Convolution combined with iteratively reweighted least squares (IRLS), from Graph Neural Networks Inspired by Classical Iterative Algorithms.
- Parameters:
input_d (int) – Number of input features.
output_d (int) – Number of output features.
hidden_d (int) – Size of hidden layers.
prop_step (int) – Number of propagation steps.
num_mlp_before (int) – Number of MLP layers before propagation. Default: 1.
num_mlp_after (int) – Number of MLP layers after propagation. Default: 1.
norm (str) – The type of normalization layer used inside the MLP layers. Can be 'batch', 'layer' or 'none'. Default: 'none'.
precond (bool) – If True, use preconditioning and the unnormalized Laplacian; otherwise, use no preconditioning and the normalized Laplacian. Default: True.
alp (float) – The \(\alpha\) in the paper. If equal to \(0\), it is set automatically based on the other hyperparameters. Default: 0.
lam (float) – The \(\lambda\) in the paper. Default: 1.
attention (bool) – If True, add an attention layer inside the propagation steps. Default: False.
tau (float) – The \(\tau\) in the paper. Default: 0.2.
T (float) – The \(T\) in the paper. If < 0, \(T\) is set to infinity. Default: -1.
p (float) – The \(p\) in the paper. Default: 1.
use_eta (bool) – If True, add a learnable weight on each dimension in attention. Default: False.
attn_bef (bool) – If True, add another attention layer before propagation. Default: False.
dropout (float) – The dropout rate in the MLP layers. Default: 0.0.
attn_dropout (float) – The dropout rate of attention values. Default: 0.0.
inp_dropout (float) – The dropout rate on input features. Default: 0.0.
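The precond option above switches between the unnormalized and the normalized graph Laplacian. As a reminder of the difference between the two, here is a minimal plain-Python sketch using the standard textbook definitions \(L = D - A\) and \(L = I - D^{-1/2} A D^{-1/2}\). The 3-node path graph and the variable names are illustrative assumptions; this is not DGL's code, and how TWIRLSConv builds or uses these matrices internally is not specified here.

```python
import math

# Hypothetical 3-node undirected path graph 0 - 1 - 2 (illustration only).
A = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
n = len(A)
deg = [sum(row) for row in A]  # node degrees: [1, 2, 1]

# Unnormalized Laplacian: L = D - A
L_unnorm = [
    [(deg[i] if i == j else 0) - A[i][j] for j in range(n)]
    for i in range(n)
]

# Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}
L_norm = [
    [(1.0 if i == j else 0.0) - A[i][j] / math.sqrt(deg[i] * deg[j])
     for j in range(n)]
    for i in range(n)
]

print(L_unnorm)  # [[1, -1, 0], [-1, 2, -1], [0, -1, 1]]
```

The diagonal of the unnormalized Laplacian carries the node degrees, while the normalized form has ones on the diagonal and rescales each off-diagonal entry by the degrees of its two endpoints.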
Note
add_self_loop will be automatically called before propagation.
Example
>>> import dgl
>>> from dgl.nn import TWIRLSConv
>>> import torch as th

>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = TWIRLSConv(10, 2, 128, prop_step=64)
>>> res = conv(g, feat)
>>> res.size()
torch.Size([6, 2])
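The note above says that add_self_loop is called before propagation. As a rough plain-Python illustration of what that means for the example graph (a sketch only, not DGL's implementation; the helper name with_self_loops is hypothetical), adding self-loops amounts to appending one i -> i edge per node to the edge list:

```python
# Edge list of the example graph above: src[k] -> dst[k]
src = [0, 1, 2, 3, 2, 5]
dst = [1, 2, 3, 4, 0, 3]
num_nodes = 6


def with_self_loops(src, dst, num_nodes):
    """Return new (src, dst) lists with one i -> i edge appended per node.

    Hypothetical helper for illustration; it does not check for
    self-loops that are already present.
    """
    loops = list(range(num_nodes))
    return src + loops, dst + loops


new_src, new_dst = with_self_loops(src, dst, num_nodes)
print(len(new_src))  # 12 edges: 6 original + 6 self-loops
```

Because the layer adds these edges itself, the caller passes the graph as-is, as in the example above.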
- forward(graph, feat)
Description
Run TWIRLS forward.
- Parameters:
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The initial node features.
- Returns:
The output feature.
- Return type:
torch.Tensor
Note
Input shape: \((N, \text{input_d})\) where \(N\) is the number of nodes.
Output shape: \((N, \text{output_d})\).