SpatialEncoder
class dgl.nn.pytorch.graph_transformer.SpatialEncoder(max_dist, num_heads=1)

Bases: torch.nn.modules.module.Module
Spatial encoder, as introduced in Do Transformers Really Perform Bad for Graph Representation? It is a learnable embedding module that encodes the shortest-path distance between each pair of nodes as a bias term for attention.
Parameters

- max_dist (int) – Upper bound of the shortest-path distance to be encoded. Distances between node pairs are clamped into the range [0, max_dist].
- num_heads (int, optional) – Number of attention heads for which a bias is produced. Default: 1.
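Conceptually, the module can be thought of as an embedding table indexed by clamped shortest-path distance: one learnable bias vector per distance bucket and head. The ToySpatialEncoder below is a minimal sketch of that idea, assuming an extra bucket for unreachable pairs; it is illustrative only, not DGL's actual implementation (dgl.shortest_dist is a real DGL utility, the rest is hypothetical).

import torch as th
import torch.nn as nn
import dgl

class ToySpatialEncoder(nn.Module):
    """Illustrative only: embeds clamped shortest-path distances."""
    def __init__(self, max_dist, num_heads=1):
        super().__init__()
        self.max_dist = max_dist
        # One bias vector per distance bucket, plus an assumed extra
        # bucket for unreachable node pairs.
        self.embedding = nn.Embedding(max_dist + 2, num_heads)

    def forward(self, g):
        # dgl.shortest_dist returns an (N, N) matrix of hop counts,
        # with -1 marking unreachable node pairs.
        dist = dgl.shortest_dist(g)
        dist = dist.clamp(max=self.max_dist)
        dist[dist < 0] = self.max_dist + 1  # unreachable bucket
        return self.embedding(dist).unsqueeze(0)  # (1, N, N, num_heads)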
Examples
>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder
>>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
>>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
>>> g = dgl.graph((u, v))
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
>>> out = spatial_encoder(g)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
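Per the Graphormer paper, the returned tensor is meant to be added to the attention logits before the softmax. A brief sketch of that step, where the scores tensor is a stand-in for query-key dot products:

>>> attn_bias = out.permute(0, 3, 1, 2)  # (1, 8, 4, 4): (B, H, N, N)
>>> scores = th.randn(1, 8, 4, 4)        # placeholder attention logits
>>> attn = th.softmax(scores + attn_bias, dim=-1)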
forward(g)

Parameters
g (DGLGraph) – A DGLGraph to be encoded, which must be homogeneous.
Returns

Attention bias as spatial encoding, of shape \((B, N, N, H)\), where \(B\) is the batch size of the input graph, \(N\) is the maximum number of nodes, and \(H\) is num_heads.

Return type

torch.Tensor
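For batched input, \(B\) and \(N\) come from dgl.batch; given the shape described above, encodings of graphs smaller than the largest graph in the batch are presumably padded to \(N\) nodes. A sketch under that assumption:

>>> g1 = dgl.graph((th.tensor([0, 1]), th.tensor([1, 0])))        # 2 nodes
>>> g2 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))  # 3 nodes
>>> bg = dgl.batch([g1, g2])
>>> out = SpatialEncoder(max_dist=2, num_heads=4)(bg)
>>> print(out.shape)  # B=2 graphs, N=3 (largest graph), H=4 heads
torch.Size([2, 3, 3, 4])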