Prepare Data
In this section, we prepare the data for the Graphormer model introduced earlier. Any dataset of DGLGraph objects can be fed to the model through a standard PyTorch DataLoader. The key step is to define a collate function that groups the features of multiple graphs into padded batches. An example collate function is shown below:
```python
import dgl
import torch as th


def collate(graphs):
    # Compute shortest-path features; this can also be done in advance.
    for g in graphs:
        spd, path = dgl.shortest_dist(g, root=None, return_paths=True)
        g.ndata["spd"] = spd
        g.ndata["path"] = path

    num_graphs = len(graphs)
    num_nodes = [g.num_nodes() for g in graphs]
    max_num_nodes = max(num_nodes)
    # Maximum number of edges kept per shortest path.
    max_len = 5

    attn_mask = th.zeros(num_graphs, max_num_nodes, max_num_nodes)
    node_feat = []
    in_degree, out_degree = [], []
    path_data = []
    # shortest_dist returns -1 for unreachable node pairs. Padded nodes are
    # unreachable from every other node, so distances involving padded
    # nodes are padded with -1 as well.
    dist = -th.ones(
        (num_graphs, max_num_nodes, max_num_nodes), dtype=th.long
    )

    for i in range(num_graphs):
        # Binary mask where invalid (padded) key positions are set to 1 (True).
        # Only columns are masked, so no row is left with all positions
        # invalid. The full Graphormer example prepends a virtual node at
        # index 0, which shifts this start index to num_nodes[i] + 1.
        attn_mask[i, :, num_nodes[i] :] = 1
        # +1 to distinguish padded (non-existent) nodes from real nodes.
        node_feat.append(graphs[i].ndata["feat"] + 1)
        # +1 so that 0 is reserved for padding; degrees are capped at 512.
        in_degree.append(
            th.clamp(graphs[i].in_degrees() + 1, min=0, max=512)
        )
        out_degree.append(
            th.clamp(graphs[i].out_degrees() + 1, min=0, max=512)
        )

        # Truncate or pad every path to the same length max_len.
        path = graphs[i].ndata["path"]
        path_len = path.size(dim=2)
        # shortest_path has shape [n, n, max_len].
        if path_len >= max_len:
            shortest_path = path[:, :, :max_len]
        else:
            p1d = (0, max_len - path_len)
            # Use the same -1 padding as shortest_dist for invalid edge IDs.
            shortest_path = th.nn.functional.pad(path, p1d, "constant", -1)
        # Pad both node dimensions up to max_num_nodes with -1.
        pad_num_nodes = max_num_nodes - num_nodes[i]
        p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)
        shortest_path = th.nn.functional.pad(shortest_path, p3d, "constant", -1)

        # +1 to distinguish padded (non-existent) edges from real edges.
        edata = graphs[i].edata["feat"] + 1
        # shortest_dist pads non-existent edges (at the end of shortest
        # paths) with edge ID -1, so append one all-zero row: index -1
        # then selects it as the feature of every padded edge.
        edata = th.cat(
            (
                edata,
                th.zeros(1, edata.shape[1], dtype=edata.dtype, device=edata.device),
            ),
            dim=0,
        )
        path_data.append(edata[shortest_path])

        dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata["spd"]

    # Pad node features with 0.
    node_feat = th.nn.utils.rnn.pad_sequence(node_feat, batch_first=True)
    # Pad degrees with 0.
    in_degree = th.nn.utils.rnn.pad_sequence(in_degree, batch_first=True)
    out_degree = th.nn.utils.rnn.pad_sequence(out_degree, batch_first=True)
    return (
        node_feat,
        in_degree,
        out_degree,
        attn_mask,
        th.stack(path_data),
        dist,
    )
```
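Two indexing details in the collate function are easy to misread: the order of the padding pairs passed to `torch.nn.functional.pad`, and how edge ID -1 selects an all-zero feature row. The standalone sketch below reproduces both on toy tensors; the values are made up for illustration and do not come from any dataset.

```python
import torch as th
import torch.nn.functional as F

# Toy edge features for a graph with 3 edges and 2 feature columns,
# already shifted by +1 so that 0 can stand for "no edge".
edata = th.tensor([[1, 2], [3, 4], [5, 6]])
# Append an all-zero row; index -1 now selects this padding row.
edata = th.cat((edata, th.zeros(1, edata.shape[1], dtype=edata.dtype)), dim=0)

# Toy path tensor of shape [n, n, max_len] = [2, 2, 3] holding edge IDs,
# with -1 marking positions past the end of a path.
path = th.tensor([
    [[-1, -1, -1], [0, -1, -1]],
    [[1, 2, -1], [-1, -1, -1]],
])

# F.pad reads pairs starting from the last dimension:
# (dim2_before, dim2_after, dim1_before, dim1_after, dim0_before, dim0_after).
# Here dims 0 and 1 (the node dimensions) grow from 2 to 3 at the end.
pad_num_nodes = 1
padded = F.pad(path, (0, 0, 0, pad_num_nodes, 0, pad_num_nodes), "constant", -1)

# Gather edge features along every path; -1 entries map to the zero row.
path_feat = edata[padded]
print(padded.shape)               # torch.Size([3, 3, 3])
print(path_feat.shape)            # torch.Size([3, 3, 3, 2])
print(path_feat[1, 0].tolist())   # [[3, 4], [5, 6], [0, 0]]
```

Appending a single zero row avoids a separate mask for padded edges: every -1, whether it comes from `shortest_dist` or from the extra padding, resolves to the same padding feature.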
The collate function above omits details such as the addition of a virtual node. For more details, please refer to the Graphormer example.
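To feed batches to the model, a collate function is passed to a PyTorch DataLoader through its `collate_fn` argument. The sketch below shows that wiring without DGL. As an assumption to keep it self-contained, each "graph" is reduced to its node-feature tensor, and `pad_collate` is a hypothetical, stripped-down collate that only pads node features and builds the attention mask. With real data, the list would hold DGLGraph objects and `collate_fn=collate` would be used instead.

```python
import torch as th
from torch.utils.data import DataLoader

th.manual_seed(0)

# Hypothetical stand-in dataset: each "graph" is only its integer node-feature
# tensor of shape [num_nodes, feat_dim]. With DGL these would be DGLGraphs.
graphs = [th.randint(0, 10, (n, 3)) for n in (4, 2, 5, 3)]


def pad_collate(batch):
    """Hypothetical minimal collate: pad node features and build a mask."""
    num_nodes = [x.shape[0] for x in batch]
    max_num_nodes = max(num_nodes)
    # True marks key positions that belong to padded nodes.
    attn_mask = th.zeros(len(batch), max_num_nodes, max_num_nodes, dtype=th.bool)
    for i, n in enumerate(num_nodes):
        attn_mask[i, :, n:] = True
    # +1 so that 0 is reserved for padded nodes, mirroring the collate above.
    node_feat = th.nn.utils.rnn.pad_sequence(
        [x + 1 for x in batch], batch_first=True
    )
    return node_feat, attn_mask


loader = DataLoader(graphs, batch_size=2, shuffle=False, collate_fn=pad_collate)
for node_feat, attn_mask in loader:
    print(node_feat.shape, attn_mask.shape)
# torch.Size([2, 4, 3]) torch.Size([2, 4, 4])
# torch.Size([2, 5, 3]) torch.Size([2, 5, 5])
```

Each batch is padded only to the largest graph in that batch, so the padded size can differ from one batch to the next.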