ChebConv
class dgl.nn.pytorch.conv.ChebConv(in_feats, out_feats, k, activation=<function relu>, bias=True)
Bases: torch.nn.modules.module.Module
Chebyshev spectral graph convolution layer from "Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering".
\[
\begin{aligned}
h_i^{l+1} &= \sum_{k=0}^{K-1} W^{k, l} z_i^{k, l} \\
Z^{0, l} &= H^{l} \\
Z^{1, l} &= \tilde{L} \cdot H^{l} \\
Z^{k, l} &= 2 \cdot \tilde{L} \cdot Z^{k-1, l} - Z^{k-2, l} \\
\tilde{L} &= 2\left(I - \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\right)/\lambda_{max} - I
\end{aligned}
\]
where \(\tilde{A} = A + I\), \(W\) is a learnable weight, and \(z_i^{k, l}\) is the \(i\)-th row of \(Z^{k, l}\).
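The recurrence above can be sketched with dense NumPy arrays. This is only an illustration of the formulas as written on this page, not DGL's implementation: the function names `cheb_basis` and `cheb_conv` are hypothetical, the adjacency matrix is assumed to be dense and symmetric, and \(\tilde{D}\) is assumed to be the degree matrix of \(\tilde{A}\).

```python
import numpy as np


def cheb_basis(adj, feat, k, lambda_max=2.0):
    """Return [Z^0, ..., Z^{k-1}] for a dense adjacency matrix (hypothetical helper)."""
    n = adj.shape[0]
    eye = np.eye(n)
    a_tilde = adj + eye                          # A~ = A + I
    d_inv_sqrt = 1.0 / np.sqrt(a_tilde.sum(axis=1))  # assumed: D~ is the degree matrix of A~
    norm_adj = d_inv_sqrt[:, None] * a_tilde * d_inv_sqrt[None, :]
    # L~ = 2 (I - D~^{-1/2} A~ D~^{-1/2}) / lambda_max - I
    l_tilde = 2.0 * (eye - norm_adj) / lambda_max - eye

    zs = [feat]                                  # Z^0 = H
    if k > 1:
        zs.append(l_tilde @ feat)                # Z^1 = L~ H
    for _ in range(2, k):
        # Z^k = 2 L~ Z^{k-1} - Z^{k-2}
        zs.append(2.0 * (l_tilde @ zs[-1]) - zs[-2])
    return zs


def cheb_conv(adj, feat, weights, lambda_max=2.0):
    """Sum Z^k W^k over k; `weights` is a list of K arrays of shape (D_in, D_out)."""
    zs = cheb_basis(adj, feat, len(weights), lambda_max)
    return sum(z @ w for z, w in zip(zs, weights))
```

Here features are stored as rows, so the product \(W^{k, l} z_i^{k, l}\) becomes `z @ w`. Activation and bias are left out to keep the sketch focused on the recurrence.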
- Parameters
in_feats (int) – Dimension of input features, i.e., the number of dimensions of \(h_i^{(l)}\).
out_feats (int) – Dimension of output features \(h_i^{(l+1)}\).
k (int) – Chebyshev filter size \(K\).
activation (function, optional) – Activation function. Default: ReLU.
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.
Example
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import ChebConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = ChebConv(10, 2, 2)
>>> res = conv(g, feat)
>>> res
tensor([[ 0.6163, -0.1809],
        [ 0.6163, -0.1809],
        [ 0.6163, -0.1809],
        [ 0.9698, -1.5053],
        [ 0.3664,  0.7556],
        [-0.2370,  3.0164]], grad_fn=<AddBackward0>)
forward(graph, feat, lambda_max=None)
Compute ChebNet layer.
- Parameters
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The input feature of shape \((N, D_{in})\), where \(D_{in}\) is the size of the input feature and \(N\) is the number of nodes.
lambda_max (list or tensor or None, optional) – A list (or tensor) of length \(B\) that stores the largest eigenvalue of the normalized Laplacian of each individual graph in graph, where \(B\) is the batch size of the input graph. Default: None. If None, this method sets the value to 2. Use dgl.laplacian_lambda_max() to compute this value.
- Returns
The output feature of shape \((N, D_{out})\), where \(D_{out}\) is the size of the output feature.
- Return type
torch.Tensor
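To illustrate what lambda_max represents, the following NumPy sketch computes the largest eigenvalue of the symmetric normalized Laplacian \(I - D^{-1/2} A D^{-1/2}\) for a single graph. It is not how dgl.laplacian_lambda_max() is implemented: the helper name `normalized_laplacian_lambda_max` is hypothetical, and the adjacency matrix is assumed to be dense, symmetric, and without self-loops. For such graphs the eigenvalues lie in \([0, 2]\), which is why 2 works as a safe upper bound when no value is supplied.

```python
import numpy as np


def normalized_laplacian_lambda_max(adj):
    """Largest eigenvalue of I - D^{-1/2} A D^{-1/2} (hypothetical helper)."""
    adj = np.asarray(adj, dtype=float)
    deg = adj.sum(axis=1)
    # Leave the scaling at 0 for isolated nodes to avoid dividing by zero.
    d_inv_sqrt = np.zeros_like(deg)
    nonzero = deg > 0
    d_inv_sqrt[nonzero] = deg[nonzero] ** -0.5
    lap = np.eye(adj.shape[0]) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    # eigvalsh applies because the matrix is symmetric.
    return float(np.linalg.eigvalsh(lap).max())


# A triangle (complete graph on 3 nodes) and a single edge.
triangle = np.ones((3, 3)) - np.eye(3)
edge = np.array([[0, 1], [1, 0]])
lambdas = [normalized_laplacian_lambda_max(a) for a in (triangle, edge)]
```

For a batched graph, one such value per component graph would be collected into the list of length \(B\) that forward expects.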