# Understand Graph Attention Network

**Authors:** Hao Zhang, Mufei Li, Minjie Wang, Zheng Zhang

From Graph Convolutional Network (GCN), we learned that combining local graph structure and node-level features yields good performance on node classification tasks. However, the way GCN aggregates neighbor information depends on the graph structure, which may hurt its generalizability.

One workaround is to simply average over all neighbor node features, as in GraphSAGE. Graph Attention Network (GAT) proposes an alternative: weighting neighbor features with a feature-dependent, structure-free normalization, in the style of attention.

The goals of this tutorial:

- Explain what a Graph Attention Network is.
- Demonstrate how it can be implemented in DGL.
- Understand the attention learned.
- Introduce inductive learning.

## Introducing Attention to GCN

The key difference between GAT and GCN is how the information from the one-hop neighborhood is aggregated.

For GCN, a graph convolution operation produces the normalized sum of the node features of neighbors:

\[h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} \frac{1}{c_{ij}} W^{(l)}h^{(l)}_j\right)\]

where \(\mathcal{N}(i)\) is the set of one-hop neighbors of node \(v_i\) (to include \(v_i\) in the set, simply add a self-loop to each node), \(c_{ij}=\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}\) is a normalization constant based on graph structure, \(\sigma\) is an activation function (GCN uses ReLU), and \(W^{(l)}\) is a shared weight matrix for node-wise feature transformation. A variant proposed in GraphSAGE uses the same update rule but sets \(c_{ij}=|\mathcal{N}(i)|\).
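To make the two normalization constants concrete, the sketch below applies both to a small dense adjacency matrix in plain PyTorch. The toy graph and tensor names are assumptions for illustration, and \(W^{(l)}\) and \(\sigma\) are left out so that only the effect of \(c_{ij}\) is visible.

```python
import torch

# Toy undirected path graph 0 - 1 - 2 with self-loops (assumed for illustration).
adj = torch.tensor([[1., 1., 0.],
                    [1., 1., 1.],
                    [0., 1., 1.]])
h = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])  # one 2-d feature vector per node

deg = adj.sum(dim=1)  # |N(i)|, counting the self-loop

# GCN: c_ij = sqrt(|N(i)|) * sqrt(|N(j)|)
c_gcn = torch.sqrt(deg).unsqueeze(1) * torch.sqrt(deg).unsqueeze(0)
h_gcn = (adj / c_gcn) @ h

# GraphSAGE-style mean: c_ij = |N(i)|
h_mean = (adj / deg.unsqueeze(1)) @ h
```

In both cases the weight on each neighbor is fixed by node degrees alone and never looks at the features, which is the property GAT changes.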

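GAT introduces the attention mechanism as a substitute for the statically normalized convolution operation. Below are the equations to compute the node embedding \(h_i^{(l+1)}\) of layer \(l+1\) from the embeddings of layer \(l\):

\[\begin{aligned}
z_i^{(l)} &= W^{(l)}h_i^{(l)}, &(1) \\
e_{ij}^{(l)} &= \text{LeakyReLU}\left({\vec a^{(l)}}^{T}\left(z_i^{(l)}||z_j^{(l)}\right)\right), &(2) \\
\alpha_{ij}^{(l)} &= \frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}\exp(e_{ik}^{(l)})}, &(3) \\
h_i^{(l+1)} &= \sigma\left(\sum_{j\in \mathcal{N}(i)} \alpha^{(l)}_{ij} z^{(l)}_j\right). &(4)
\end{aligned}\]

Explanations: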

- Equation (1) is a linear transformation of the lower layer embedding \(h_i^{(l)}\) and \(W^{(l)}\) is its learnable weight matrix.
- Equation (2) computes a pair-wise *unnormalized* attention score between two neighbors. It first concatenates the \(z\) embeddings of the two nodes, where \(||\) denotes concatenation, then takes the dot product of the result with a learnable weight vector \(\vec a^{(l)}\), and finally applies a LeakyReLU. This form of attention is usually called *additive attention*, in contrast with the dot-product attention in the Transformer model.
- Equation (3) applies a softmax to normalize the attention scores on each node’s incoming edges.
- Equation (4) is similar to GCN. The embeddings from neighbors are aggregated together, scaled by the attention scores.
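Before turning to DGL, the four steps can be sketched with a dense adjacency matrix in plain PyTorch. This is an illustrative sketch only: the toy graph, the random tensors `W` and `a`, and the dense masking approach are assumptions, and it is far less efficient than the message-passing version for sparse graphs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, in_dim, out_dim = 4, 5, 3
# adj[i, j] = True means node j is a neighbor of node i (self-loops included).
adj = torch.tensor([[1, 1, 0, 0],
                    [1, 1, 1, 0],
                    [0, 1, 1, 1],
                    [0, 0, 1, 1]], dtype=torch.bool)

h = torch.randn(num_nodes, in_dim)
W = torch.randn(out_dim, in_dim)   # stands in for W^{(l)}
a = torch.randn(2 * out_dim)       # stands in for the attention vector a^{(l)}

# Equation (1): z_i = W h_i
z = h @ W.t()

# Equation (2): e_ij = LeakyReLU(a^T [z_i || z_j]) for every pair (i, j)
z_i = z.unsqueeze(1).expand(-1, num_nodes, -1)
z_j = z.unsqueeze(0).expand(num_nodes, -1, -1)
e = F.leaky_relu(torch.cat([z_i, z_j], dim=-1) @ a)

# Equation (3): softmax over each node's neighbors only
e = e.masked_fill(~adj, float('-inf'))
alpha = torch.softmax(e, dim=1)

# Equation (4): attention-weighted sum of neighbor embeddings
# (the activation sigma is applied separately and omitted here)
h_next = alpha @ z
```

Each row of `alpha` sums to one and is zero for non-neighbors, so `h_next[i]` is a convex combination of the transformed embeddings of node \(i\)'s neighbors.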

There are other details from the paper, such as dropout and skip connections. For the purpose of simplicity, we omit them in this tutorial and leave the link to the full example at the end for interested readers.

In essence, GAT is just a different aggregation function that applies attention over the features of neighbors, instead of a simple mean aggregation.

## GAT in DGL

Let’s first get an overall impression of how a `GATLayer` module is implemented in DGL. Don’t worry, we will break down the four equations above one by one.

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')
```

### Equation (1)

The first one is simple. Linear transformation is very common and can be easily implemented in PyTorch using `torch.nn.Linear`.

### Equation (2)

The unnormalized attention score \(e_{ij}\) is calculated using the embeddings of adjacent nodes \(i\) and \(j\). This suggests that the attention scores can be viewed as edge data, which can be calculated by the `apply_edges` API. The argument to `apply_edges` is an **Edge UDF**, which is defined as below:

```
def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
    a = self.attn_fc(z2)
    return {'e': F.leaky_relu(a)}
```

Here, the dot product with the learnable weight vector \(\vec a^{(l)}\) is again implemented with a PyTorch linear transformation, `attn_fc`. Note that `apply_edges` will **batch** all the edge data in one tensor, so `cat` and `attn_fc` here are applied to all the edges in parallel.
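The effect of batching can be imitated without DGL by gathering node embeddings with index tensors. In the sketch below, `src` and `dst` are hypothetical edge lists for a toy graph, and `z[src]` / `z[dst]` play the role of `edges.src['z']` / `edges.dst['z']`; it illustrates the idea rather than how DGL implements it internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

out_dim = 3
z = torch.randn(4, out_dim)            # transformed node embeddings, one row per node
src = torch.tensor([0, 1, 1, 2, 3])    # source node of each edge (assumed toy graph)
dst = torch.tensor([1, 0, 2, 3, 2])    # destination node of each edge

attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

# Gather per-edge features: row k holds the embeddings for edge k.
z_src, z_dst = z[src], z[dst]                   # each (num_edges, out_dim)
z2 = torch.cat([z_src, z_dst], dim=1)           # (num_edges, 2 * out_dim)
e = F.leaky_relu(attn_fc(z2))                   # (num_edges, 1), all edges at once

# Same scores computed one edge at a time, for comparison.
e_loop = torch.stack([
    F.leaky_relu(attn_fc(torch.cat([z[s], z[d]])))
    for s, d in zip(src.tolist(), dst.tolist())
])
```

Both paths produce the same scores; the batched form simply replaces the Python loop with a single matrix multiplication.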

### Equation (3) & (4)

Similar to GCN, the `update_all` API is used to trigger message passing on all the nodes. The message function sends out two tensors: the transformed `z` embedding of the source node and the unnormalized attention score `e` on each edge. The reduce function then performs two tasks:

- Normalize the attention scores using softmax (equation (3)).
- Aggregate neighbor embeddings weighted by the attention scores (equation (4)).

Both tasks first fetch data from the mailbox and then manipulate it on the second dimension (`dim=1`), on which the messages are batched.

```
def reduce_func(self, nodes):
    # reduce UDF for equation (3) & (4)
    # equation (3)
    alpha = F.softmax(nodes.mailbox['e'], dim=1)
    # equation (4)
    h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
    return {'h': h}
```
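The shape bookkeeping in the reduce function is easier to follow with stand-in tensors. The sketch below assumes a mailbox laid out as (nodes, messages per node, feature size), which is what reducing over `dim=1` implies; the tensors are random placeholders rather than real DGL mailboxes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, num_msgs, feat = 2, 3, 4
# Stand-ins for nodes.mailbox['e'] and nodes.mailbox['z'] for a batch of
# nodes that each received the same number of messages.
mailbox_e = torch.randn(num_nodes, num_msgs, 1)
mailbox_z = torch.randn(num_nodes, num_msgs, feat)

alpha = F.softmax(mailbox_e, dim=1)          # normalize across messages
h = torch.sum(alpha * mailbox_z, dim=1)      # (num_nodes, feat)
```

The trailing dimension of size 1 in `alpha` lets it broadcast against the feature dimension of `mailbox_z`, so each message is scaled by a single scalar weight before the sum.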

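### Multi-head Attention

Analogous to multiple channels in ConvNet, GAT introduces **multi-head attention** to enrich the model capacity and to stabilize the learning process. Each attention head has its own parameters and their outputs can be merged in two ways:

\[\text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)\]

or

\[\text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)\]

where \(K\) is the number of heads. The authors suggest using concatenation for intermediate layers and averaging for the final layer.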

We can use the single-head `GATLayer` defined above as the building block for the `MultiHeadGATLayer` below:

```
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average over the head dimension (dim=0);
            # without dim, torch.mean would collapse everything to a scalar
            return torch.mean(torch.stack(head_outs), dim=0)
```
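The two merge strategies differ in output shape, which matters when wiring layers together. The sketch below uses random tensors as stand-ins for the per-head outputs (an assumption for illustration) and averages over the stacked head dimension with `dim=0`.

```python
import torch

torch.manual_seed(0)

num_nodes, out_dim, num_heads = 5, 8, 2
# Stand-ins for the per-head outputs of GATLayer.
head_outs = [torch.randn(num_nodes, out_dim) for _ in range(num_heads)]

h_cat = torch.cat(head_outs, dim=1)                     # (5, 16)
h_avg = torch.mean(torch.stack(head_outs), dim=0)       # (5, 8)
```

Concatenation multiplies the feature size by the number of heads, while averaging keeps it at `out_dim`. Note that calling `torch.mean` on the stacked tensor without a `dim` argument reduces over every element and returns a single scalar, so the explicit `dim=0` is needed to keep one averaged embedding per node.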

### Put everything together

Now, we can define a two-layer GAT model:

```
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h
```

We then load the Cora dataset using DGL’s built-in data module.

```
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx


def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.ByteTensor(data.train_mask)
    g = data.graph
    # add self loop
    g.remove_edges_from(nx.selfloop_edges(g))
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    return g, features, labels, mask
```

The training loop is exactly the same as in the GCN tutorial.

```
import time
import numpy as np

g, features, labels, mask = load_cora_data()

# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))
```

Out:

```
/pytorch/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3257: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9460 | Time(s) nan
Epoch 00001 | Loss 1.9453 | Time(s) nan
Epoch 00002 | Loss 1.9445 | Time(s) nan
Epoch 00003 | Loss 1.9437 | Time(s) 0.2247
Epoch 00004 | Loss 1.9429 | Time(s) 0.2198
Epoch 00005 | Loss 1.9421 | Time(s) 0.2179
Epoch 00006 | Loss 1.9413 | Time(s) 0.2174
Epoch 00007 | Loss 1.9405 | Time(s) 0.2168
Epoch 00008 | Loss 1.9397 | Time(s) 0.2164
Epoch 00009 | Loss 1.9388 | Time(s) 0.2168
Epoch 00010 | Loss 1.9380 | Time(s) 0.2170
Epoch 00011 | Loss 1.9371 | Time(s) 0.2167
Epoch 00012 | Loss 1.9362 | Time(s) 0.2166
Epoch 00013 | Loss 1.9353 | Time(s) 0.2169
Epoch 00014 | Loss 1.9344 | Time(s) 0.2171
Epoch 00015 | Loss 1.9335 | Time(s) 0.2174
Epoch 00016 | Loss 1.9325 | Time(s) 0.2177
Epoch 00017 | Loss 1.9316 | Time(s) 0.2176
Epoch 00018 | Loss 1.9306 | Time(s) 0.2183
Epoch 00019 | Loss 1.9296 | Time(s) 0.2185
Epoch 00020 | Loss 1.9286 | Time(s) 0.2186
Epoch 00021 | Loss 1.9276 | Time(s) 0.2187
Epoch 00022 | Loss 1.9265 | Time(s) 0.2189
Epoch 00023 | Loss 1.9254 | Time(s) 0.2189
Epoch 00024 | Loss 1.9244 | Time(s) 0.2190
Epoch 00025 | Loss 1.9233 | Time(s) 0.2190
Epoch 00026 | Loss 1.9221 | Time(s) 0.2192
Epoch 00027 | Loss 1.9210 | Time(s) 0.2192
Epoch 00028 | Loss 1.9198 | Time(s) 0.2192
Epoch 00029 | Loss 1.9187 | Time(s) 0.2193
```
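The loop above only reports the training loss. A minimal sketch of how accuracy could be measured on a held-out set of nodes follows; `masked_accuracy` and the toy tensors are hypothetical names introduced here, and a boolean mask is assumed.

```python
import torch


def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose arg-max prediction matches the label.

    Hypothetical helper for illustration; not part of the tutorial code.
    """
    pred = logits[mask].argmax(dim=1)
    return (pred == labels[mask]).float().mean().item()


# Toy inputs standing in for the model output, labels, and a test mask.
logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.2, 0.4], [0.2, 0.9]])
labels = torch.tensor([0, 1, 1, 1])
mask = torch.tensor([True, True, True, False])
acc = masked_accuracy(logits, labels, mask)
```

With a trained model, the same helper could be called under `torch.no_grad()` with the model's logits and a test mask in place of the toy tensors.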

## Visualizing and Understanding Attention Learnt

### Cora

The following table summarizes the model performance on Cora reported in the GAT paper and obtained with the DGL implementations.

| Model | Accuracy |
|---|---|
| GCN (paper) | \(81.4\pm 0.5\%\) |
| GCN (dgl) | \(82.05\pm 0.33\%\) |
| GAT (paper) | \(83.0\pm 0.7\%\) |
| GAT (dgl) | \(83.69\pm 0.529\%\) |

*What kind of attention distribution has our model learnt?*

Because the attention weight \(\alpha_{ij}\) is associated with edges, we can visualize it by coloring edges. Below we pick a subgraph of Cora and plot the attention weights of the last `GATLayer`. The nodes are colored according to their labels, whereas the edges are colored according to the magnitude of the attention weights, which can be read off the colorbar on the right.

You can see that the model seems to learn different attention weights. To understand the distribution more thoroughly, we measure the entropy of the attention distribution. For any node \(i\), \(\{\alpha_{ij}\}_{j\in\mathcal{N}(i)}\) forms a discrete probability distribution over all its neighbors, with the entropy given by

\[H\left(\{\alpha_{ij}\}_{j\in\mathcal{N}(i)}\right)=-\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\log\alpha_{ij}\]

Intuitively, a low entropy means a high degree of concentration, and vice versa; an entropy of 0 means all attention is on one source node. The uniform distribution has the highest entropy, \(\log(|\mathcal{N}(i)|)\). Ideally, we want the model to learn a distribution of lower entropy (i.e., one or two neighbors are much more important than the others).
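A small numeric check makes these bounds tangible. The sketch below uses made-up attention weights for a node with four neighbors; `attention_entropy` is a hypothetical helper using the natural logarithm.

```python
import math


def attention_entropy(alpha):
    """Shannon entropy (natural log) of one node's attention weights."""
    # Terms with zero weight contribute nothing to the sum.
    return -sum(a * math.log(a) for a in alpha if a > 0)


uniform = [0.25, 0.25, 0.25, 0.25]   # four neighbors, equal attention
peaked = [0.97, 0.01, 0.01, 0.01]    # almost all attention on one neighbor
one_hot = [1.0, 0.0, 0.0, 0.0]       # all attention on one neighbor

h_uniform = attention_entropy(uniform)   # equals log(4), the maximum
h_peaked = attention_entropy(peaked)     # lies between 0 and log(4)
h_one_hot = attention_entropy(one_hot)   # equals 0, the minimum
```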

Note that since nodes can have different degrees, the maximum entropy will also be different. Therefore, we plot the aggregated histogram of entropy values of all nodes in the entire graph. Below are the attention histograms learned by each attention head.

As a reference, here is the histogram if all the nodes have uniform attention weight distribution.

One can see that **the attention values learned are quite similar to a uniform distribution** (i.e., all neighbors are equally important). This partially explains why the performance of GAT is close to that of GCN on Cora (according to the authors’ reported result, the accuracy difference averaged over 100 runs is less than 2%); attention does not matter much since it barely differentiates between neighbors.

*Does that mean the attention mechanism is not useful?* No! A different
dataset exhibits an entirely different pattern, as we show next.

### Protein-Protein Interaction (PPI) networks

The PPI dataset used here consists of \(24\) graphs corresponding to different human tissues. Nodes can have up to \(121\) kinds of labels, so the label of a node is represented as a binary tensor of size \(121\). The task is to predict node labels.

We use \(20\) graphs for training, \(2\) for validation and \(2\) for test. The average number of nodes per graph is \(2372\). Each node has \(50\) features that are composed of positional gene sets, motif gene sets and immunological signatures. Critically, test graphs remain completely unobserved during training, a setting called “inductive learning”.

We compare the performance of GAT and GCN for \(10\) random runs on this task and use hyperparameter search on the validation set to find the best model.

| Model | F1 Score (micro) |
|---|---|
| GAT | \(0.975 \pm 0.006\) |
| GCN | \(0.509 \pm 0.025\) |
| Paper | \(0.973 \pm 0.002\) |

The table above is the result of this experiment, where we use micro F1 score to evaluate the model performance.

Note

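Below is the calculation process of the micro F1 score:

\[precision=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t}+FP_{t})}\]

\[recall=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t}+FN_{t})}\]

\[F1_{micro}=2\,\frac{precision \cdot recall}{precision+recall}\]

- \(TP_{t}\) is the number of nodes that both have and are predicted to have label \(t\).
- \(FP_{t}\) is the number of nodes that do not have but are predicted to have label \(t\).
- \(FN_{t}\) is the number of nodes that have label \(t\) but are not predicted to have it.
- \(n\) is the number of labels, i.e. \(121\) in our case.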
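The definitions above translate into a few lines of plain Python. This is a minimal sketch with toy data: `micro_f1` is a hypothetical helper, and the inputs are assumed to be 0/1 matrices with one row per node and one column per label.

```python
def micro_f1(y_true, y_pred):
    """Micro-averaged F1 for multi-label 0/1 matrices (rows: nodes, cols: labels)."""
    tp = fp = fn = 0
    # Pool the counts over all nodes and all labels before forming the ratios.
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            tp += int(t == 1 and p == 1)
            fp += int(t == 0 and p == 1)
            fn += int(t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Three nodes, three labels (toy data for illustration).
y_true = [[1, 0, 1], [0, 1, 0], [1, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 1], [1, 1, 0]]
score = micro_f1(y_true, y_pred)  # TP=4, FP=1, FN=1
```

Because the counts are pooled across labels before dividing, frequent labels influence the micro score more than rare ones.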

During training, we use `BCEWithLogitsLoss` as the loss function. The learning curves of GAT and GCN are presented below; what is evident is the dramatic performance advantage of GAT over GCN.
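For readers unfamiliar with this loss, the sketch below shows how `BCEWithLogitsLoss` treats each of the \(121\) labels as an independent binary problem. The random logits and targets are placeholders, and thresholding the logits at zero is one common way to obtain predictions, assumed here for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_nodes, num_labels = 6, 121
logits = torch.randn(num_nodes, num_labels)                       # raw model outputs
targets = torch.randint(0, 2, (num_nodes, num_labels)).float()    # 0/1 label matrix

loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, targets)

# Equivalent to a sigmoid followed by binary cross-entropy, averaged over all entries.
manual = nn.functional.binary_cross_entropy(torch.sigmoid(logits), targets)

# Thresholding the logits at 0 (probability 0.5) gives the predicted label set.
pred = (logits > 0).float()
```

Fusing the sigmoid into the loss is numerically more stable than applying it separately, which is why the model can output raw logits.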

As before, we can get a statistical understanding of the attention learned by plotting the histogram of the node-wise attention entropy. Below are the attention histograms learned by different attention layers.

*Attention learnt in layer 1:*

*Attention learnt in layer 2:*

*Attention learnt in final layer:*

Again, comparing with uniform distribution:

Clearly, **GAT does learn sharp attention weights**! There is a clear pattern over the layers as well: **the attention gets sharper in higher layers**.

Unlike the Cora dataset where GAT’s gain is lukewarm at best, for PPI there is a significant performance gap between GAT and other GNN variants compared in the GAT paper (at least 20%), and the attention distributions between the two clearly differ. While this deserves further research, one immediate conclusion is that GAT’s advantage lies perhaps more in its ability to handle a graph with more complex neighborhood structure.

## What’s Next?

So far, we demonstrated how to use DGL to implement GAT. There are some missing details such as dropout, skip connections and hyper-parameter tuning, which are common practices and do not involve DGL-related concepts. We refer interested readers to the full example.

- See the optimized full example here.
- Stay tuned for our next tutorial about how to speed up GAT models by parallelizing multiple attention heads and SPMV optimization.
