Spatial Encoder, as introduced in "Do Transformers Really Perform Bad for Graph Representation?"
This module is a learnable spatial embedding that encodes the shortest-path distance between each node pair as an attention bias.
>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder
>>> from dgl import shortest_dist

>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> n1, n2 = g1.num_nodes(), g2.num_nodes()
>>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
>>> dist = -th.ones((2, 4, 4), dtype=th.long)
>>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
>>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
>>> out = spatial_encoder(dist)
>>> print(out.shape)
torch.Size([2, 4, 4, 8])
dist (Tensor) – Shortest path distance of the batched graph with -1 padding, a tensor of shape \((B, N, N)\), where \(B\) is the batch size of the batched graph, and \(N\) is the maximum number of nodes.
Return attention bias as spatial encoding of shape \((B, N, N, H)\), where \(H\) is the number of attention heads (`num_heads`).
- Return type
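One way such a distance-to-bias lookup could work is sketched below in plain PyTorch. This is an illustration under stated assumptions, not necessarily DGL's implementation: the class name `SpatialEncoderSketch` is hypothetical, distances larger than `max_dist` are assumed to be clamped to `max_dist`, and -1 entries (padding or unreachable pairs) are assumed to map to a zero-initialized padding row.

```python
import torch
import torch.nn as nn


class SpatialEncoderSketch(nn.Module):
    """Hypothetical sketch: one learnable bias per (distance, head)."""

    def __init__(self, max_dist, num_heads=1):
        super().__init__()
        self.max_dist = max_dist
        # Row 0 is reserved for -1 (padding / unreachable) and starts at zero;
        # rows 1..max_dist+1 hold the embeddings for distances 0..max_dist.
        self.embedding_table = nn.Embedding(
            max_dist + 2, num_heads, padding_idx=0
        )

    def forward(self, dist):
        # Assumption: clamp long distances to max_dist, then shift by one
        # so that -1 lands on the padding row.
        idx = torch.clamp(dist, min=-1, max=self.max_dist) + 1
        return self.embedding_table(idx)  # shape (B, N, N, H)


# One graph with two mutually reachable nodes, padded to N = 3.
dist = torch.tensor([[[0, 1, -1],
                      [1, 0, -1],
                      [-1, -1, -1]]])
encoder = SpatialEncoderSketch(max_dist=2, num_heads=8)
bias = encoder(dist)
```

In this sketch, `bias` has shape \((1, 3, 3, 8)\), and every entry whose distance is -1 receives the all-zero padding vector, so padded positions contribute no bias at initialization.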