PathEncoder

class dgl.nn.pytorch.gt.PathEncoder(max_len, feat_dim, num_heads=1)

Bases: torch.nn.modules.module.Module

Path Encoder, as introduced in the Edge Encoding scheme of "Do Transformers Really Perform Bad for Graph Representation?"

This is a learnable path embedding module that encodes the shortest path between each pair of nodes as an attention bias.
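
Following the Graphormer paper, take a node pair \((i, j)\) whose shortest path has edges \(e_1, \dots, e_N\). Its bias is \(c_{ij} = \frac{1}{N}\sum_{n=1}^{N} x_{e_n} (w_n^E)^\top\), where \(x_{e_n}\) is the feature of the n-th edge. There is one learnable vector \(w_n^E\) per path position, and here also one per head. The framework-free sketch below computes this bias for a single pair. It rests on three assumptions: paths longer than max_len keep only their first max_len edges, an empty path yields a zero bias, and the helper name `path_bias` with its argument layout is hypothetical and not part of DGL.

```python
def dot(u, v):
    """Inner product of two equal-length feature vectors."""
    return sum(a * b for a, b in zip(u, v))


def path_bias(path_feats, weights):
    """Hypothetical helper: path-encoding bias for one node pair.

    path_feats: edge feature vectors x_{e_1}, ..., x_{e_N} along the shortest path.
    weights: weights[n][h] is the learnable vector for path position n and head h;
             len(weights) plays the role of max_len.
    Returns one bias value per head.
    """
    max_len = len(weights)
    num_heads = len(weights[0])
    kept = path_feats[:max_len]          # truncate paths longer than max_len
    n = max(len(kept), 1)                # avoid dividing by zero for empty paths
    return [
        sum(dot(x, weights[k][h]) for k, x in enumerate(kept)) / n
        for h in range(num_heads)
    ]


# Two-edge path, d = 2, max_len = 2, one head.
feats = [[1.0, 0.0], [0.0, 2.0]]
w = [[[1.0, 1.0]], [[1.0, 1.0]]]
print(path_bias(feats, w))  # [1.5]
```

Averaging over the path length keeps long and short paths on a comparable scale. Each edge contributes through its own position-specific weight.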

Parameters
  • max_len (int) – Maximum number of edges in each path to be encoded. Longer paths are truncated: edges at positions max_len and beyond (0-based) are dropped.

  • feat_dim (int) – Dimension of edge features in the input graph.

  • num_heads (int, optional) – Number of attention heads when multi-head attention is applied. Default: 1.
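
The truncation rule for max_len can be sketched on plain lists of edge IDs. Positions 0 to max_len - 1 are kept and later edges are dropped. Padding shorter paths with -1 up to max_len is an assumption here, chosen to mirror the -1 padding produced by shortest_dist. The helper name `truncate_path` is hypothetical.

```python
def truncate_path(edge_ids, max_len, pad=-1):
    """Hypothetical helper: keep the first max_len edge IDs of a path.

    Edges at positions >= max_len are dropped; shorter paths are padded
    with `pad` so every path has exactly max_len entries.
    """
    kept = list(edge_ids[:max_len])
    return kept + [pad] * (max_len - len(kept))


print(truncate_path([4, 7, 2], max_len=2))  # [4, 7]
print(truncate_path([5], max_len=2))        # [5, -1]
```

On tensors, slicing such as path[:, :, :2] in the example below performs the truncation for all node pairs at once.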

Examples

>>> import torch as th
>>> import dgl
>>> from dgl.nn import PathEncoder
>>> from dgl import shortest_dist
>>> g = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> edata = th.rand(8, 16)
>>> # shortest_dist pads paths with -1 (e.g. for unreachable node pairs),
>>> # so append a zero row to make edata[-1] a zero-padding feature.
>>> edata = th.cat(
...     (edata, th.zeros(1, 16)), dim=0
... )
>>> dist, path = shortest_dist(g, root=None, return_paths=True)
>>> path_data = edata[path[:, :, :2]]
>>> path_encoder = PathEncoder(2, 16, num_heads=8)
>>> out = path_encoder(dist.unsqueeze(0), path_data.unsqueeze(0))
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
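
The zero-row trick in the example relies on Python-style indexing, which treats -1 as the last row. After a zero vector is appended, every -1 padding entry in a path gathers that zero row, so padded positions add nothing to the bias. Below is a list-based sketch of the same gather. The helper name `gather_path_feats` is hypothetical.

```python
def gather_path_feats(edata, path, feat_dim):
    """Hypothetical helper: look up edge features for a padded path.

    edata: one feature vector per edge ID.
    path: edge IDs, padded with -1.
    A zero row is appended (to a copy) so that index -1 selects it.
    """
    table = edata + [[0.0] * feat_dim]
    return [table[e] for e in path]


edata = [[1.0, 2.0], [3.0, 4.0]]
print(gather_path_feats(edata, [1, -1], feat_dim=2))  # [[3.0, 4.0], [0.0, 0.0]]
```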

forward(dist, path_data)
Parameters
  • dist (Tensor) – Shortest path distance matrix of the batched graph with zero padding, of shape \((B, N, N)\), where \(B\) is the batch size of the batched graph, and \(N\) is the maximum number of nodes.

  • path_data (Tensor) – Edge features along the shortest paths with zero padding, of shape \((B, N, N, L, d)\), where \(L\) is the maximum length of the shortest paths, and \(d\) is feat_dim.
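
Graphs in a batch can have different node counts. In that case, each per-graph distance matrix is zero-padded to the largest node count \(N\) before stacking into shape \((B, N, N)\). Here is a list-based sketch of that padding, using the hypothetical helper `pad_dist_batch`. path_data is described as zero-padded in the same way along its two node dimensions.

```python
def pad_dist_batch(dists):
    """Hypothetical helper: zero-pad square distance matrices to a common size.

    dists: list of B matrices, each an n_b x n_b list of lists.
    Returns a B x N x N nested list, where N = max n_b.
    """
    n_max = max(len(d) for d in dists)
    batch = []
    for d in dists:
        pad = n_max - len(d)
        rows = [list(row) + [0] * pad for row in d]    # pad columns
        rows += [[0] * n_max for _ in range(pad)]      # pad rows
        batch.append(rows)
    return batch


small = [[0, 1], [1, 0]]
large = [[0, 1, 2], [1, 0, 1], [2, 1, 0]]
print(pad_dist_batch([small, large])[0])  # [[0, 1, 0], [1, 0, 0], [0, 0, 0]]
```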

Returns

Attention bias as path encoding, of shape \((B, N, N, H)\), where \(B\) is the batch size of the input graph, \(N\) is the maximum number of nodes, and \(H\) is num_heads.

Return type

torch.Tensor
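
In Graphormer-style attention, this bias is added to the scaled dot-product scores before the softmax. The module returns heads last, \((B, N, N, H)\). If attention scores are laid out as \((B, H, N, N)\), the bias must be transposed first, which in PyTorch is bias.permute(0, 3, 1, 2). The list-based sketch below assumes that score layout. The helper name `biased_attention` is hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def biased_attention(scores, bias):
    """Hypothetical helper: add a path-encoding bias to attention scores.

    scores: already-scaled logits indexed [b][h][i][j], shape (B, H, N, N).
    bias:   values indexed [b][i][j][h], shape (B, N, N, H), the layout
            returned by PathEncoder.
    Returns attention weights of shape (B, H, N, N), softmax-normalized over j.
    """
    out = []
    for b, per_graph in enumerate(scores):
        heads = []
        for h, mat in enumerate(per_graph):
            heads.append([
                softmax([mat[i][j] + bias[b][i][j][h] for j in range(len(row))])
                for i, row in enumerate(mat)
            ])
        out.append(heads)
    return out


# One graph, one head, two nodes: equal raw scores, bias favors j = 1 for i = 0.
scores = [[[[0.0, 0.0], [0.0, 0.0]]]]
bias = [[[[0.0], [math.log(3.0)]], [[0.0], [0.0]]]]
weights = biased_attention(scores, bias)
print(weights[0][0][0])  # approximately [0.25, 0.75]
```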