LapPosEncoder

class dgl.nn.pytorch.gt.LapPosEncoder(model_type, num_layer, k, dim, n_head=1, batch_norm=False, num_post_layer=0)

Bases: torch.nn.modules.module.Module

Laplacian Positional Encoder (LPE), as introduced in GraphGPS: General Powerful Scalable Graph Transformers

This module learns a Laplacian positional encoding using either a Transformer or a DeepSet encoder.

Parameters
  • model_type (str) – Encoder model type for LPE. Must be either “Transformer” or “DeepSet”.

  • num_layer (int) – Number of layers in the Transformer/DeepSet encoder.

  • k (int) – Number of eigenvectors to use, taken from the smallest non-trivial eigenvalues.

  • dim (int) – Output size of the final Laplacian encoding.

  • n_head (int, optional) – Number of heads in the Transformer encoder. Default: 1.

  • batch_norm (bool, optional) – If True, apply batch normalization to the raw Laplacian positional encoding. Default: False.

  • num_post_layer (int, optional) – If num_post_layer > 0, apply an MLP with num_post_layer layers after pooling. Default: 0.
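
To make the roles of k, dim, and the pooling step more concrete, the following is a minimal NumPy sketch of a DeepSet-style encoder. It is an illustration under stated assumptions, not DGL's implementation: the sizes, the random weights `W1` and `W2`, the helper `deepset_encode`, and the choice of sum pooling are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

N, k, dim = 5, 3, 8  # hypothetical sizes: nodes, eigenpairs per node, output width

# Stand-ins for the real inputs: one (eigenvector entry, eigenvalue) pair
# per node and per frequency, stacked into an array of shape (N, k, 2).
eigvecs = rng.standard_normal((N, k))
eigvals = np.tile(np.sort(rng.random(k)), (N, 1))
pairs = np.stack([eigvecs, eigvals], axis=2)

# Randomly initialised weights play the role of the learned parameters.
W1 = rng.standard_normal((2, dim))
W2 = rng.standard_normal((dim, dim))


def deepset_encode(pairs):
    """Embed each pair independently, then sum over the k pairs (assumed pooling)."""
    hidden = np.maximum(pairs @ W1, 0.0)   # (N, k, dim), ReLU
    hidden = np.maximum(hidden @ W2, 0.0)  # (N, k, dim)
    return hidden.sum(axis=1)              # (N, dim)


pos_encoding = deepset_encode(pairs)
print(pos_encoding.shape)  # (5, 8)

# Sum pooling makes the result independent of the order of the k pairs.
perm = rng.permutation(k)
assert np.allclose(pos_encoding, deepset_encode(pairs[:, perm, :]))
```

Each node ends up with a single vector of length dim, regardless of how the k eigenpairs are ordered.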

Example

>>> import dgl
>>> from dgl import LapPE
>>> from dgl.nn import LapPosEncoder
>>> transform = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
>>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
>>> g = transform(g)
>>> eigvals, eigvecs = g.ndata['eigval'], g.ndata['eigvec']
>>> transformer_encoder = LapPosEncoder(
...     model_type="Transformer", num_layer=3, k=5, dim=16, n_head=4
... )
>>> pos_encoding = transformer_encoder(eigvals, eigvecs)
>>> deepset_encoder = LapPosEncoder(
...     model_type="DeepSet", num_layer=3, k=5, dim=16, num_post_layer=2
... )
>>> pos_encoding = deepset_encoder(eigvals, eigvecs)

forward(eigvals, eigvecs)
Parameters
  • eigvals (Tensor) – Laplacian eigenvalues of shape \((N, k)\): the same \(k\) eigenvalues are repeated for each of the \(N\) nodes. Can be obtained with LapPE.

  • eigvecs (Tensor) – Laplacian eigenvectors of shape \((N, k)\). Can be obtained with LapPE.

Returns

The Laplacian positional encodings of shape \((N, d)\), where \(N\) is the number of nodes in the input graph and \(d\) is dim.

Return type

Tensor
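
The \((N, k)\) layout of the two inputs can be illustrated without DGL. The sketch below builds them with NumPy for the 5-node graph from the example above, assuming the symmetric normalized Laplacian \(I - D^{-1/2} A D^{-1/2}\) and \(k = 3\). The transform used in practice may differ in normalization, eigenvector sign, or padding, so treat this only as a picture of the shapes.

```python
import numpy as np

# Edge list of the 5-node graph from the example (both directions present).
src = [0, 1, 2, 3, 4, 2, 3, 1, 4, 0]
dst = [2, 3, 1, 4, 0, 0, 1, 2, 3, 4]
N, k = 5, 3  # k = 3 is an assumption; this graph has 4 non-trivial eigenpairs

# Dense adjacency matrix and symmetric normalized Laplacian (assumed form).
A = np.zeros((N, N))
A[src, dst] = 1.0
deg = A.sum(axis=1)
D_inv_sqrt = np.diag(deg ** -0.5)
L = np.eye(N) - D_inv_sqrt @ A @ D_inv_sqrt

# eigh returns eigenvalues in ascending order; index 0 is the trivial one (0).
w, V = np.linalg.eigh(L)

eigvecs = V[:, 1:k + 1]                # (N, k): one row of entries per node
eigvals = np.tile(w[1:k + 1], (N, 1))  # (N, k): the same k values in every row

print(eigvecs.shape, eigvals.shape)  # (5, 3) (5, 3)
```

Repeating the eigenvalues row-wise gives every node its own copy, so each eigenvector entry can be paired with its eigenvalue.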