TWIRLSConv

class dgl.nn.pytorch.conv.TWIRLSConv(input_d, output_d, hidden_d, prop_step, num_mlp_before=1, num_mlp_after=1, norm='none', precond=True, alp=0, lam=1, attention=False, tau=0.2, T=-1, p=1, use_eta=False, attn_bef=False, dropout=0.0, attn_dropout=0.0, inp_dropout=0.0)

Bases: torch.nn.modules.module.Module
Convolution together with iteratively reweighted least squares (IRLS), from the paper "Graph Neural Networks Inspired by Classical Iterative Algorithms".
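For orientation, the attention-free case can be written as below. This is our reading of the paper, not DGL's code, under these assumptions: \(f(X;W)\) is the output of the MLP before propagation, \(A\) and \(D\) are the adjacency and degree matrices of the graph with self-loops, \(L = D - A\) is the unnormalized Laplacian, and \(\tilde{D} = \lambda D + I\) is the preconditioner. DGL may scale or normalize individual terms differently, and the reweighted (attention) variant is not shown.

```latex
% Energy minimized by the unrolled propagation steps
\ell(Y) = \lVert Y - f(X;W) \rVert_F^2 + \lambda \operatorname{tr}\left(Y^\top L Y\right)

% One gradient step on \ell (the constant factor 2 is absorbed into \alpha)
Y^{(k+1)} = Y^{(k)} - \alpha \left[ (\lambda L + I)\, Y^{(k)} - f(X;W) \right]

% The same step preconditioned by \tilde{D}^{-1}, using (\lambda L + I) = \tilde{D} - \lambda A
Y^{(k+1)} = (1 - \alpha)\, Y^{(k)} + \alpha\, \tilde{D}^{-1} \left[ \lambda A Y^{(k)} + f(X;W) \right]
```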
Parameters
input_d (int) – Number of input features.
output_d (int) – Number of output features.
hidden_d (int) – Size of hidden layers.
prop_step (int) – Number of propagation steps.
num_mlp_before (int) – Number of MLP layers before propagation. Default: 1.
num_mlp_after (int) – Number of MLP layers after propagation. Default: 1.
norm (str) – The type of normalization layer inside the MLP layers. Can be 'batch', 'layer' or 'none'. Default: 'none'.
precond (bool) – If True, use preconditioning and the unnormalized Laplacian; otherwise, use no preconditioning and the normalized Laplacian. Default: True.
alp (float) – The \(\alpha\) in the paper. If equal to \(0\), it is decided automatically based on the other hyperparameters. Default: 0.
lam (float) – The \(\lambda\) in the paper. Default: 1.
attention (bool) – If True, add an attention layer inside propagation. Default: False.
tau (float) – The \(\tau\) in the paper. Default: 0.2.
T (float) – The \(T\) in the paper. If < 0, \(T\) is set to \(\infty\). Default: -1.
p (float) – The \(p\) in the paper. Default: 1.
use_eta (bool) – If True, add a learnable weight on each dimension in attention. Default: False.
attn_bef (bool) – If True, add another attention layer before propagation. Default: False.
dropout (float) – The dropout rate in MLP layers. Default: 0.0.
attn_dropout (float) – The dropout rate of attention values. Default: 0.0.
inp_dropout (float) – The dropout rate on input features. Default: 0.0.
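The sketch below shows how prop_step, alp, lam and precond could interact in the attention-free case, using dense tensors. twirls_propagate is a hypothetical helper, not part of DGL, and it rests on two assumptions: propagation starts from \(Y^{(0)} = f(X;W)\), and each step is a gradient step on \(\lVert Y - f(X;W)\rVert_F^2 + \lambda \operatorname{tr}(Y^\top L Y)\), preconditioned by \((\lambda D + I)^{-1}\) when precond is True. It does not reproduce DGL's normalization details or its automatic choice of alp.

```python
import torch


def twirls_propagate(adj: torch.Tensor, f: torch.Tensor, prop_step: int,
                     alp: float, lam: float, precond: bool = True) -> torch.Tensor:
    """Hypothetical sketch: unrolled gradient steps without attention.

    adj: dense (N, N) symmetric adjacency, assumed to already include self-loops
         (so every degree is >= 1).
    f:   (N, d) base predictions f(X; W), e.g. the output of the MLP before propagation.
    """
    deg = adj.sum(dim=1)  # node degrees, shape (N,)
    y = f.clone()  # assumed initialization Y^(0) = f(X; W)
    if precond:
        # Preconditioned step on the unnormalized Laplacian L = D - A,
        # with D~ = lam * D + I:  Y <- (1 - alp) Y + alp * D~^{-1} (lam * A Y + F)
        d_tilde_inv = 1.0 / (lam * deg + 1.0)
        for _ in range(prop_step):
            y = (1 - alp) * y + alp * d_tilde_inv[:, None] * (lam * adj @ y + f)
    else:
        # Plain gradient step on the normalized Laplacian I - D^{-1/2} A D^{-1/2}:
        # Y <- Y - alp * ((lam * L_norm + I) Y - F)
        d_inv_sqrt = deg.pow(-0.5)
        a_norm = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
        for _ in range(prop_step):
            lap_y = y - a_norm @ y  # L_norm @ Y
            y = y - alp * (lam * lap_y + y - f)
    return y
```

With precond=True and alp=1, each step is a Jacobi-style splitting iteration for \((I + \lambda L)Y = f(X;W)\), so enough steps approach the exact solution \((I + \lambda L)^{-1} f(X;W)\). With precond=False, the eigenvalues of the normalized Laplacian lie in \([0, 2]\), so keeping alp below \(2 / (1 + 2\lambda)\) guarantees the plain step converges.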
Note
add_self_loop is called automatically before propagation.

Example
>>> import dgl
>>> from dgl.nn import TWIRLSConv
>>> import torch as th
>>> g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
>>> feat = th.ones(6, 10)
>>> conv = TWIRLSConv(10, 2, 128, prop_step=64)
>>> res = conv(g, feat)
>>> res.size()
torch.Size([6, 2])
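In the example graph, node 5 has no incoming edge (it only appears as a source). The torch-only sketch below uses plain edge-list tensors rather than DGL and shows how appending one (i, i) edge per node gives every node a nonzero in-degree. We assume, without having checked DGL's source, that this is one reason for the automatic add_self_loop call: degree normalizations such as \(D^{-1}\) or \(D^{-1/2}\) would otherwise divide by zero.

```python
import torch

# Edge list of the example graph: edge k goes from src[k] to dst[k]
src = torch.tensor([0, 1, 2, 3, 2, 5])
dst = torch.tensor([1, 2, 3, 4, 0, 3])
num_nodes = 6

# In-degree per node; node 5 never appears in dst
in_deg = torch.bincount(dst, minlength=num_nodes)  # tensor([1, 1, 1, 2, 1, 0])

# Append one self-loop (i, i) for every node
loops = torch.arange(num_nodes)
src_sl = torch.cat([src, loops])
dst_sl = torch.cat([dst, loops])
in_deg_sl = torch.bincount(dst_sl, minlength=num_nodes)  # tensor([2, 2, 2, 3, 2, 1])
```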
forward(graph, feat)
Run TWIRLS forward.
Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The initial node features.
Returns
The output node features.
Return type
torch.Tensor
Note
Input shape: \((N, \text{input\_d})\), where \(N\) is the number of nodes.
Output shape: \((N, \text{output\_d})\).
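A shape-only sketch of the data flow, reusing the sizes from the example (6 nodes, input_d=10, hidden_d=128, output_d=2). It assumes that the MLP before propagation maps input_d to hidden_d and the MLP after propagation maps hidden_d to output_d; a single nn.Linear stands in for each MLP, and a random row-stochastic matrix stands in for the prop_step propagation steps. None of these stand-ins are DGL internals; they only illustrate that propagation mixes node rows without changing the feature width.

```python
import torch
from torch import nn

input_d, hidden_d, output_d, num_nodes = 10, 128, 2, 6

# Hypothetical stand-ins for the MLPs before and after propagation
mlp_before = nn.Linear(input_d, hidden_d)
mlp_after = nn.Linear(hidden_d, output_d)

feat = torch.ones(num_nodes, input_d)  # (N, input_d)
h = mlp_before(feat)                   # (N, hidden_d)

# Stand-in for propagation: each output row is a weighted average of node rows
mix = torch.softmax(torch.randn(num_nodes, num_nodes), dim=1)
h = mix @ h                            # still (N, hidden_d)

out = mlp_after(h)                     # (N, output_d)
print(out.shape)                       # torch.Size([6, 2])
```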