DGL
1.0.x


Stochastic Training of GNN for Link Prediction

This tutorial shows how to train a multi-layer GraphSAGE model for link prediction on ogbn-arxiv, a dataset provided by the Open Graph Benchmark (OGB). The dataset contains about 170 thousand nodes and 1 million edges.

By the end of this tutorial, you will be able to

  • Train a GNN model for link prediction on a single GPU with DGL’s neighbor sampling components.

This tutorial assumes that you have read the Introduction of Neighbor Sampling for GNN Training and the Training GNN with Neighbor Sampling for Node Classification tutorials.

Link Prediction Overview

Link prediction requires the model to predict the probability that an edge exists. This tutorial does so by computing a dot product between the representations of the edge's two incident nodes.

\[\hat{y}_{u\sim v} = \sigma(h_u^T h_v)\]

It then minimizes the following binary cross entropy loss.

\[\mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)\]

This is the same formulation as in the earlier full-graph link prediction tutorial.
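As a sanity check of the loss above, the plain-Python sketch below evaluates it directly from raw scores s = h_u^T h_v and compares the result with a standard numerically stable rewrite, max(s, 0) - s*y + log(1 + exp(-|s|)). Working from raw scores (logits) is why the training loop in this tutorial calls F.binary_cross_entropy_with_logits instead of applying a sigmoid first; note that PyTorch averages over examples by default, whereas the formula sums. The helper names are hypothetical.

```python
import math


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def bce_from_probs(scores, labels):
    # Direct transcription of the loss: -sum(y*log(y_hat) + (1-y)*log(1-y_hat)).
    total = 0.0
    for s, y in zip(scores, labels):
        p = sigmoid(s)
        total += y * math.log(p) + (1 - y) * math.log(1 - p)
    return -total


def bce_with_logits(scores, labels):
    # Numerically stable equivalent that never computes log(sigmoid(s)) explicitly.
    return sum(
        max(s, 0.0) - s * y + math.log1p(math.exp(-abs(s)))
        for s, y in zip(scores, labels)
    )


# Two positive pairs (label 1) followed by three negative pairs (label 0).
scores = [2.0, 0.5, -1.0, 0.3, -3.0]
labels = [1, 1, 0, 0, 0]
direct = bce_from_probs(scores, labels)
stable = bce_with_logits(scores, labels)
print(math.isclose(direct, stable))  # True
```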

Loading Dataset

This tutorial loads the dataset from the ogb package as in the previous tutorial.

import os
os.environ['DGLBACKEND'] = 'pytorch'
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu"  # change to 'cuda' for GPU

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
print(graph)
print(node_labels)

node_features = graph.ndata["feat"]
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print("Number of classes:", num_classes)

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
valid_nids = idx_split["valid"]
test_nids = idx_split["test"]

Out:

Graph(num_nodes=169343, num_edges=2332486,
      ndata_schemes={'year': Scheme(shape=(1,), dtype=torch.int64), 'feat': Scheme(shape=(128,), dtype=torch.float32)}
      edata_schemes={})
tensor([[ 4],
        [ 5],
        [28],
        ...,
        [10],
        [ 4],
        [ 1]])
Number of classes: 40

Defining Neighbor Sampler and Data Loader in DGL

Unlike the full-graph link prediction tutorial, a common practice for training GNNs on large graphs is to iterate over the edges in minibatches, since computing the probability for all edges at once is usually infeasible. For each minibatch of edges, you compute the output representations of their incident nodes using neighbor sampling and the GNN, in a similar fashion to the large-scale node classification tutorial.

DGL provides dgl.dataloading.as_edge_prediction_sampler, which turns a node-wise neighbor sampler into one that produces minibatches of edges for edge classification or link prediction tasks.

To perform link prediction, you also need to specify a negative sampler. DGL provides built-in negative samplers such as dgl.dataloading.negative_sampler.Uniform. Here, this tutorial draws 5 negative examples uniformly at random for each positive example.

negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)
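Conceptually, for every positive edge (u, v) a uniform negative sampler keeps the source u and pairs it with k destination nodes drawn uniformly at random from all nodes. The plain-Python sketch below mimics that behaviour on a toy edge list; uniform_negative_pairs is a hypothetical helper, not DGL's implementation, and like any uniform sampler it may occasionally produce a pair that is actually a true edge.

```python
import random


def uniform_negative_pairs(src, dst, num_nodes, k, seed=0):
    """Hypothetical sketch: k corrupted pairs (u, v') per positive edge (u, v)."""
    rng = random.Random(seed)
    neg_src, neg_dst = [], []
    for u, _v in zip(src, dst):
        for _ in range(k):
            neg_src.append(u)                         # keep the source node
            neg_dst.append(rng.randrange(num_nodes))  # uniformly random destination
    return neg_src, neg_dst


# Toy minibatch of 3 positive edges on a 10-node graph, 5 negatives per positive.
src, dst = [0, 1, 2], [1, 2, 3]
neg_src, neg_dst = uniform_negative_pairs(src, dst, num_nodes=10, k=5)
print(len(neg_src))  # 15
print(neg_src[:6])   # [0, 0, 0, 0, 0, 1]
```

With k=5 and 1024 positive edges per minibatch, this scheme yields 5 x 1024 = 5120 negative pairs per minibatch.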

After defining the negative sampler, you can define the edge data loader with neighbor sampling. To create a DataLoader for link prediction, wrap a neighbor sampler with as_edge_prediction_sampler together with the negative sampler created above, and pass the result to the DataLoader.

sampler = dgl.dataloading.NeighborSampler([4, 4])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler
)
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DataLoader.
    graph,  # The graph
    torch.arange(graph.number_of_edges()),  # The edges to iterate over
    sampler,  # The edge prediction sampler (neighbor + negative sampling)
    device=device,  # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,  # Batch size
    shuffle=True,  # Whether to shuffle the edges for every epoch
    drop_last=False,  # Whether to drop the last incomplete batch
    num_workers=0,  # Number of sampler processes
)
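Because the graph was built with dgl.add_reverse_edges, each training edge and its reverse twin both stay in the graph while neighbors are sampled, so the model can see the very edge it is asked to predict. as_edge_prediction_sampler accepts exclude="reverse_id" together with a reverse_eids tensor to drop each minibatch edge and its reverse from neighbor sampling; this tutorial's code does not do this. The plain-Python sketch below builds such a mapping, assuming (as the DGL user guide's link prediction example does) that the reverse of edge i is stored at ID i + E, where E is the original edge count; reverse_edge_ids is a hypothetical helper.

```python
def reverse_edge_ids(num_edges_with_reverse):
    """Map each edge ID to its reverse twin, assuming reverse edges were appended
    after the E original edges, so edge i and edge i + E are mutual reverses."""
    half = num_edges_with_reverse // 2
    return list(range(half, 2 * half)) + list(range(0, half))


# Toy directed graph with E = 3 edges, then its reverses appended in the same order.
src = [0, 1, 2]
dst = [1, 2, 0]
all_src = src + dst
all_dst = dst + src
rev = reverse_edge_ids(len(all_src))
print(rev)  # [3, 4, 5, 0, 1, 2]
# Every edge (u, v) is paired with an edge (v, u).
print(all((all_src[r], all_dst[r]) == (all_dst[i], all_src[i])
          for i, r in enumerate(rev)))  # True
```

As a tensor, this mapping would be passed as dgl.dataloading.as_edge_prediction_sampler(sampler, exclude="reverse_id", reverse_eids=..., negative_sampler=negative_sampler), where sampler is the NeighborSampler. The outputs in this tutorial were produced without this exclusion.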

You can peek at one minibatch from train_dataloader to see what it yields.

input_nodes, pos_graph, neg_graph, mfgs = next(iter(train_dataloader))
print("Number of input nodes:", len(input_nodes))
print(
    "Positive graph # nodes:",
    pos_graph.number_of_nodes(),
    "# edges:",
    pos_graph.number_of_edges(),
)
print(
    "Negative graph # nodes:",
    neg_graph.number_of_nodes(),
    "# edges:",
    neg_graph.number_of_edges(),
)
print(mfgs)

Out:

Number of input nodes: 56949
Positive graph # nodes: 6881 # edges: 1024
Negative graph # nodes: 6881 # edges: 5120
[Block(num_src_nodes=56949, num_dst_nodes=23770, num_edges=88332), Block(num_src_nodes=23770, num_dst_nodes=6881, num_edges=24032)]

The example minibatch consists of four elements.

The first element is an ID tensor for the input nodes, i.e., nodes whose input features are needed on the first GNN layer for this minibatch.

The second and third elements are the positive graph and the negative graph for this minibatch. The concept of positive and negative graphs was introduced in the full-graph link prediction tutorial. In minibatch training, the positive graph and the negative graph contain only the nodes necessary for computing the pairwise scores of the positive and negative examples in the current minibatch.

The last element is a list of MFGs storing the computation dependencies for each GNN layer. The MFGs are used to compute the GNN outputs of the nodes involved in the positive and negative graphs.
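The sizes printed above follow directly from the loader settings. The quick plain-Python check below uses the figures from this tutorial's outputs: the number of minibatches per epoch matches the 2278 steps shown in the training progress bar, and each negative graph holds 5 negatives per positive edge. Note also that the positive graph, the negative graph, and the destination side of the last MFG all have 6881 nodes, so the final GNN layer produces exactly one row per node the score predictor needs.

```python
num_edges = 2332486   # graph.number_of_edges() after adding reverse edges
batch_size = 1024     # positive edges per minibatch
neg_per_pos = 5       # Uniform(5) negative sampler

# drop_last=False keeps the final partial batch, hence the ceiling division.
num_batches = (num_edges + batch_size - 1) // batch_size
neg_edges_per_batch = batch_size * neg_per_pos

print(num_batches)          # 2278
print(neg_edges_per_batch)  # 5120
```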

Defining Model for Node Representation

The model is almost identical to the one in the node classification tutorial. The only difference is that, since you are doing link prediction, the output dimension is the size of the node representation (128 here) rather than the number of classes in the dataset.

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h


model = Model(num_features, 128).to(device)
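To make the (x, h_dst) convention concrete: in an MFG the destination nodes are also the first num_dst_nodes() source nodes, which is why the model slices x[: mfgs[0].num_dst_nodes()]. The NumPy sketch below imitates one mean-aggregator GraphSAGE layer on such a block, computing h_dst @ W_self + mean(neighbor features) @ W_neigh. It assumes no bias, normalization, or activation and uses hypothetical names; it illustrates the idea and is not DGL's SAGEConv implementation.

```python
import numpy as np


def sage_mean_layer(x_src, num_dst, edges, w_self, w_neigh):
    """Hypothetical mean-aggregator GraphSAGE layer on one MFG-like block.

    x_src: (num_src, in_feats) features of all source nodes.
    num_dst: destination nodes are the first num_dst source nodes.
    edges: list of (src_local_id, dst_local_id) pairs.
    """
    h_dst = x_src[:num_dst]  # same slicing as in Model.forward
    agg = np.zeros((num_dst, x_src.shape[1]))
    deg = np.zeros(num_dst)
    for s, d in edges:
        agg[d] += x_src[s]
        deg[d] += 1
    # Mean over in-neighbors; nodes without in-edges get a zero neighbor term.
    h_neigh = agg / np.maximum(deg, 1)[:, None]
    return h_dst @ w_self + h_neigh @ w_neigh


x_src = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]])
edges = [(2, 0), (3, 0), (0, 1)]  # dst 0 <- src 2, 3; dst 1 <- src 0
w = np.eye(2)
out = sage_mean_layer(x_src, num_dst=2, edges=edges, w_self=w, w_neigh=w)
print(out)
```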

Defining the Score Predictor for Edges

After computing the node representations needed for the minibatch, the last step is to score both the existing edges (the positive graph) and the non-existent edges (the negative graph) in the sampled minibatch.

The following score predictor, copied from the link prediction tutorial, takes a dot product between the incident nodes’ representations.

import dgl.function as fn


class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata["score"][:, 0]
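For a graph whose edges are given as source and destination ID arrays, fn.u_dot_v("h", "h", "score") amounts to a row-wise dot product between the source and destination rows of h, stored with a trailing dimension of size 1 (hence the [:, 0]). A NumPy equivalent with hypothetical array names:

```python
import numpy as np


def dot_scores(h, src, dst):
    """Row-wise dot product h[src[i]] . h[dst[i]] for every edge i, shape (E,)."""
    return np.sum(h[src] * h[dst], axis=1)


h = np.array([[1.0, 2.0], [3.0, 4.0], [0.5, -1.0]])  # node representations
src = np.array([0, 1, 2])
dst = np.array([1, 2, 0])
print(dot_scores(h, src, dst))
```

Because the positive and negative graphs share the minibatch's node set, the same outputs tensor can be scored against both.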

Evaluating Performance with Unsupervised Learning (Optional)

There are various ways to evaluate the performance of link prediction. This tutorial follows the practice of the GraphSAGE paper: it first trains a GNN via link prediction and obtains an embedding for each node. It then trains a downstream classifier on top of these embeddings and computes the accuracy as an assessment of the embedding quality.

To obtain the representations of all the nodes, this tutorial uses neighbor sampling as introduced in the node classification tutorial.

Note

If you would like to obtain node representations without neighbor sampling during inference, please refer to the DGL user guide.

def inference(model, graph, node_features):
    # node_features is unused: the MFGs read features from graph.ndata["feat"].
    with torch.no_grad():
        nodes = torch.arange(graph.number_of_nodes())

        sampler = dgl.dataloading.NeighborSampler([4, 4])
        dataloader = dgl.dataloading.DataLoader(
            graph,
            nodes,
            sampler,
            batch_size=1024,
            shuffle=False,  # keep node order so result rows align with node IDs
            drop_last=False,
            num_workers=4,
            device=device,
        )

        result = []
        for input_nodes, output_nodes, mfgs in dataloader:
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]
            # Move each batch to CPU so embeddings do not accumulate on the GPU.
            result.append(model(mfgs, inputs).cpu())

        return torch.cat(result)


import sklearn.metrics


def evaluate(emb, label, train_nids, valid_nids, test_nids):
    classifier = nn.Linear(emb.shape[1], num_classes).to(device)
    opt = torch.optim.LBFGS(classifier.parameters())

    def compute_loss():
        pred = classifier(emb[train_nids].to(device))
        loss = F.cross_entropy(pred, label[train_nids].to(device))
        return loss

    def closure():
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        return loss

    prev_loss = float("inf")
    for i in range(1000):
        opt.step(closure)
        with torch.no_grad():
            loss = compute_loss().item()
            if np.abs(loss - prev_loss) < 1e-4:
                print("Converges at iteration", i)
                break
            else:
                prev_loss = loss

    with torch.no_grad():
        pred = classifier(emb.to(device)).cpu()
        valid_acc = sklearn.metrics.accuracy_score(
            label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1)
        )
        test_acc = sklearn.metrics.accuracy_score(
            label[test_nids].numpy(), pred[test_nids].numpy().argmax(1)
        )
    return valid_acc, test_acc
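The accuracy reported by evaluate is the fraction of nodes whose highest-scoring class equals the label. The NumPy sketch below shows that computation, which gives the same value as sklearn.metrics.accuracy_score(labels, logits.argmax(1)) here; the function name and arrays are made up for illustration.

```python
import numpy as np


def argmax_accuracy(logits, labels):
    """Fraction of rows whose argmax class matches the integer label."""
    return float(np.mean(np.argmax(logits, axis=1) == labels))


logits = np.array([[0.1, 2.0, -1.0],
                   [1.5, 0.2, 0.3],
                   [0.0, 0.1, 0.9],
                   [2.0, 2.5, 0.0]])
labels = np.array([1, 0, 1, 1])
print(argmax_accuracy(logits, labels))  # 0.75
```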

Defining Training Loop

The following initializes the model and defines the optimizer.

model = Model(node_features.shape[1], 128).to(device)
predictor = DotPredictor().to(device)
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))

The following training loop trains the model on link prediction, periodically evaluates the resulting embeddings, and saves the model that performs best on the validation set:

import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = "model.pt"
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, mfgs) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata["feat"]

            outputs = model(mfgs, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            score = torch.cat([pos_score, neg_score])
            label = torch.cat(
                [torch.ones_like(pos_score), torch.zeros_like(neg_score)]
            )
            loss = F.binary_cross_entropy_with_logits(score, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tq.set_postfix({"loss": "%.03f" % loss.item()}, refresh=False)

            if (step + 1) % 500 == 0:
                model.eval()
                emb = inference(model, graph, node_features)
                valid_acc, test_acc = evaluate(
                    emb, node_labels, train_nids, valid_nids, test_nids
                )
                print(
                    "Epoch {} Validation Accuracy {} Test Accuracy {}".format(
                        epoch, valid_acc, test_acc
                    )
                )
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

                # Note that this tutorial does not train the model to the end; it stops after the first evaluation.
                break

Out:

  0%|          | 0/2278 [00:00<?, ?it/s]
  0%|          | 1/2278 [00:00<09:20,  4.06it/s, loss=27.404]
  0%|          | 5/2278 [00:00<05:57,  6.35it/s, loss=7.032]
  0%|          | 10/2278 [00:01<05:49,  6.49it/s, loss=1.756]
  1%|          | 20/2278 [00:03<05:30,  6.82it/s, loss=0.845]
  2%|2         | 50/2278 [00:07<05:35,  6.63it/s, loss=0.681]
  4%|4         | 100/2278 [00:14<05:13,  6.95it/s, loss=0.667]
  9%|8         | 200/2278 [00:29<04:59,  6.93it/s, loss=0.663]
 13%|#3        | 300/2278 [00:44<04:44,  6.95it/s, loss=0.661]
 14%|#4        | 324/2278 [00:47<04:55,  6.61it/s, loss=0.661]
 14%|#4        | 325/2278 [00:47<05:01,  6.47it/s, loss=0.660]
 14%|#4        | 326/2278 [00:47<04:52,  6.67it/s, loss=0.657]
 14%|#4        | 327/2278 [00:48<04:42,  6.91it/s, loss=0.659]
 14%|#4        | 328/2278 [00:48<04:49,  6.74it/s, loss=0.667]
 14%|#4        | 329/2278 [00:48<04:47,  6.78it/s, loss=0.658]
 14%|#4        | 330/2278 [00:48<04:43,  6.87it/s, loss=0.657]
 15%|#4        | 331/2278 [00:48<04:46,  6.80it/s, loss=0.659]
 15%|#4        | 332/2278 [00:48<04:41,  6.92it/s, loss=0.657]
 15%|#4        | 333/2278 [00:48<04:32,  7.13it/s, loss=0.658]
 15%|#4        | 334/2278 [00:49<04:39,  6.96it/s, loss=0.654]
 15%|#4        | 335/2278 [00:49<04:37,  6.99it/s, loss=0.653]
 15%|#4        | 336/2278 [00:49<04:48,  6.73it/s, loss=0.663]
 15%|#4        | 337/2278 [00:49<04:36,  7.01it/s, loss=0.654]
 15%|#4        | 338/2278 [00:49<04:32,  7.13it/s, loss=0.659]
 15%|#4        | 339/2278 [00:49<04:26,  7.28it/s, loss=0.659]
 15%|#4        | 340/2278 [00:49<04:40,  6.90it/s, loss=0.661]
 15%|#4        | 341/2278 [00:50<04:39,  6.94it/s, loss=0.658]
 15%|#5        | 342/2278 [00:50<04:44,  6.81it/s, loss=0.660]
 15%|#5        | 343/2278 [00:50<04:52,  6.62it/s, loss=0.652]
 15%|#5        | 344/2278 [00:50<04:56,  6.51it/s, loss=0.656]
 15%|#5        | 345/2278 [00:50<04:55,  6.54it/s, loss=0.661]
 15%|#5        | 346/2278 [00:50<04:48,  6.70it/s, loss=0.657]
 15%|#5        | 347/2278 [00:51<04:39,  6.90it/s, loss=0.659]
 15%|#5        | 348/2278 [00:51<04:35,  7.00it/s, loss=0.660]
 15%|#5        | 349/2278 [00:51<04:42,  6.83it/s, loss=0.659]
 15%|#5        | 350/2278 [00:51<04:47,  6.72it/s, loss=0.658]
 15%|#5        | 351/2278 [00:51<04:48,  6.67it/s, loss=0.658]
 15%|#5        | 352/2278 [00:51<04:42,  6.81it/s, loss=0.660]
 15%|#5        | 353/2278 [00:51<04:42,  6.80it/s, loss=0.657]
 16%|#5        | 354/2278 [00:52<04:48,  6.67it/s, loss=0.660]
 16%|#5        | 355/2278 [00:52<04:50,  6.62it/s, loss=0.654]
 16%|#5        | 356/2278 [00:52<04:45,  6.74it/s, loss=0.660]
 16%|#5        | 357/2278 [00:52<05:02,  6.34it/s, loss=0.658]
 16%|#5        | 358/2278 [00:52<04:58,  6.44it/s, loss=0.656]
 16%|#5        | 359/2278 [00:52<04:43,  6.78it/s, loss=0.649]
 16%|#5        | 360/2278 [00:52<04:33,  7.02it/s, loss=0.656]
 16%|#5        | 361/2278 [00:53<04:26,  7.20it/s, loss=0.650]
 16%|#5        | 362/2278 [00:53<04:33,  7.00it/s, loss=0.656]
 16%|#5        | 363/2278 [00:53<04:54,  6.50it/s, loss=0.657]
 16%|#5        | 364/2278 [00:53<04:58,  6.41it/s, loss=0.656]
 16%|#6        | 365/2278 [00:53<04:43,  6.75it/s, loss=0.662]
 16%|#6        | 366/2278 [00:53<04:39,  6.85it/s, loss=0.655]
 16%|#6        | 367/2278 [00:54<04:40,  6.82it/s, loss=0.659]
 16%|#6        | 368/2278 [00:54<04:44,  6.72it/s, loss=0.661]
 16%|#6        | 369/2278 [00:54<04:32,  7.00it/s, loss=0.650]
 16%|#6        | 370/2278 [00:54<04:29,  7.08it/s, loss=0.656]
 16%|#6        | 371/2278 [00:54<04:39,  6.83it/s, loss=0.656]
 16%|#6        | 372/2278 [00:54<04:37,  6.88it/s, loss=0.649]
 16%|#6        | 373/2278 [00:54<04:34,  6.94it/s, loss=0.660]
 16%|#6        | 374/2278 [00:55<04:28,  7.08it/s, loss=0.656]
 16%|#6        | 375/2278 [00:55<04:24,  7.18it/s, loss=0.655]
 17%|#6        | 376/2278 [00:55<04:27,  7.11it/s, loss=0.652]
 17%|#6        | 377/2278 [00:55<04:31,  6.99it/s, loss=0.658]
 17%|#6        | 378/2278 [00:55<04:32,  6.97it/s, loss=0.658]
 17%|#6        | 379/2278 [00:55<04:31,  7.00it/s, loss=0.659]
 17%|#6        | 380/2278 [00:55<04:22,  7.22it/s, loss=0.657]
 17%|#6        | 381/2278 [00:55<04:26,  7.12it/s, loss=0.654]
 17%|#6        | 382/2278 [00:56<04:39,  6.79it/s, loss=0.657]
 17%|#6        | 383/2278 [00:56<04:34,  6.90it/s, loss=0.658]
 17%|#6        | 384/2278 [00:56<04:26,  7.11it/s, loss=0.656]
 17%|#6        | 385/2278 [00:56<04:25,  7.12it/s, loss=0.661]
 17%|#6        | 386/2278 [00:56<04:39,  6.78it/s, loss=0.657]
 17%|#6        | 387/2278 [00:56<04:35,  6.86it/s, loss=0.656]
 17%|#7        | 388/2278 [00:57<04:33,  6.91it/s, loss=0.663]
 17%|#7        | 389/2278 [00:57<04:29,  7.02it/s, loss=0.654]
 17%|#7        | 390/2278 [00:57<04:27,  7.07it/s, loss=0.656]
 17%|#7        | 391/2278 [00:57<04:19,  7.27it/s, loss=0.657]
 17%|#7        | 392/2278 [00:57<04:14,  7.40it/s, loss=0.656]
 17%|#7        | 393/2278 [00:57<04:26,  7.06it/s, loss=0.659]
 17%|#7        | 394/2278 [00:57<04:31,  6.93it/s, loss=0.659]
 17%|#7        | 395/2278 [00:58<04:31,  6.94it/s, loss=0.659]
 17%|#7        | 396/2278 [00:58<04:23,  7.14it/s, loss=0.652]
 17%|#7        | 397/2278 [00:58<04:18,  7.26it/s, loss=0.655]
 17%|#7        | 398/2278 [00:58<04:21,  7.20it/s, loss=0.657]
 18%|#7        | 399/2278 [00:58<04:17,  7.30it/s, loss=0.651]
 18%|#7        | 400/2278 [00:58<04:26,  7.06it/s, loss=0.652]
 18%|#7        | 401/2278 [00:58<04:28,  6.98it/s, loss=0.657]
 18%|#7        | 402/2278 [00:58<04:20,  7.21it/s, loss=0.656]
 18%|#7        | 403/2278 [00:59<04:26,  7.05it/s, loss=0.656]
 18%|#7        | 404/2278 [00:59<04:30,  6.93it/s, loss=0.643]
 18%|#7        | 405/2278 [00:59<04:21,  7.16it/s, loss=0.656]
 18%|#7        | 406/2278 [00:59<04:21,  7.15it/s, loss=0.654]
 18%|#7        | 407/2278 [00:59<04:29,  6.93it/s, loss=0.659]
 18%|#7        | 408/2278 [00:59<04:24,  7.07it/s, loss=0.652]
 18%|#7        | 409/2278 [00:59<04:32,  6.86it/s, loss=0.653]
 18%|#7        | 410/2278 [01:00<04:30,  6.90it/s, loss=0.656]
 18%|#8        | 411/2278 [01:00<04:29,  6.92it/s, loss=0.653]
 18%|#8        | 412/2278 [01:00<04:27,  6.98it/s, loss=0.656]
 18%|#8        | 413/2278 [01:00<04:31,  6.88it/s, loss=0.657]
 18%|#8        | 414/2278 [01:00<04:23,  7.07it/s, loss=0.654]
 18%|#8        | 415/2278 [01:00<04:23,  7.06it/s, loss=0.654]
 18%|#8        | 416/2278 [01:00<04:27,  6.95it/s, loss=0.654]
 18%|#8        | 417/2278 [01:01<04:25,  7.01it/s, loss=0.653]
 18%|#8        | 418/2278 [01:01<04:24,  7.03it/s, loss=0.657]
 18%|#8        | 419/2278 [01:01<04:22,  7.08it/s, loss=0.653]
 18%|#8        | 420/2278 [01:01<04:22,  7.08it/s, loss=0.653]
 18%|#8        | 421/2278 [01:01<04:22,  7.07it/s, loss=0.655]
 19%|#8        | 422/2278 [01:01<04:34,  6.75it/s, loss=0.655]
 19%|#8        | 423/2278 [01:01<04:32,  6.81it/s, loss=0.654]
 19%|#8        | 424/2278 [01:02<04:35,  6.74it/s, loss=0.654]
 19%|#8        | 425/2278 [01:02<04:41,  6.58it/s, loss=0.648]
 19%|#8        | 426/2278 [01:02<04:35,  6.71it/s, loss=0.651]
 19%|#8        | 427/2278 [01:02<04:30,  6.83it/s, loss=0.652]
 19%|#8        | 428/2278 [01:02<04:29,  6.87it/s, loss=0.656]
 19%|#8        | 429/2278 [01:02<04:27,  6.92it/s, loss=0.650]
 19%|#8        | 430/2278 [01:03<04:34,  6.74it/s, loss=0.653]
 19%|#8        | 431/2278 [01:03<04:28,  6.87it/s, loss=0.655]
 19%|#8        | 432/2278 [01:03<04:27,  6.90it/s, loss=0.654]
 19%|#9        | 433/2278 [01:03<04:40,  6.57it/s, loss=0.648]
 19%|#9        | 434/2278 [01:03<04:37,  6.64it/s, loss=0.652]
 19%|#9        | 435/2278 [01:03<04:37,  6.65it/s, loss=0.653]
 19%|#9        | 436/2278 [01:03<04:31,  6.78it/s, loss=0.652]
 19%|#9        | 437/2278 [01:04<04:20,  7.06it/s, loss=0.656]
 19%|#9        | 438/2278 [01:04<04:19,  7.09it/s, loss=0.650]
 19%|#9        | 439/2278 [01:04<04:18,  7.11it/s, loss=0.647]
 19%|#9        | 440/2278 [01:04<04:16,  7.17it/s, loss=0.652]
 19%|#9        | 441/2278 [01:04<04:11,  7.29it/s, loss=0.657]
 19%|#9        | 442/2278 [01:04<04:13,  7.23it/s, loss=0.658]
 19%|#9        | 443/2278 [01:04<04:15,  7.18it/s, loss=0.655]
 19%|#9        | 444/2278 [01:05<04:21,  7.00it/s, loss=0.655]
 20%|#9        | 445/2278 [01:05<04:30,  6.77it/s, loss=0.653]
 20%|#9        | 446/2278 [01:05<04:33,  6.71it/s, loss=0.648]
 20%|#9        | 447/2278 [01:05<04:22,  6.98it/s, loss=0.653]
 20%|#9        | 448/2278 [01:05<04:20,  7.02it/s, loss=0.660]
 20%|#9        | 449/2278 [01:05<04:13,  7.22it/s, loss=0.648]
 20%|#9        | 450/2278 [01:05<04:12,  7.23it/s, loss=0.652]
 20%|#9        | 451/2278 [01:06<04:22,  6.95it/s, loss=0.649]
 20%|#9        | 452/2278 [01:06<04:20,  7.01it/s, loss=0.655]
 20%|#9        | 453/2278 [01:06<04:15,  7.15it/s, loss=0.654]
 20%|#9        | 454/2278 [01:06<04:21,  6.98it/s, loss=0.659]
 20%|#9        | 455/2278 [01:06<04:20,  7.00it/s, loss=0.654]
 20%|##        | 456/2278 [01:06<04:33,  6.67it/s, loss=0.648]
 20%|##        | 457/2278 [01:06<04:21,  6.96it/s, loss=0.651]
 20%|##        | 458/2278 [01:07<04:19,  7.02it/s, loss=0.654]
 20%|##        | 459/2278 [01:07<04:21,  6.96it/s, loss=0.655]
 20%|##        | 460/2278 [01:07<04:24,  6.87it/s, loss=0.654]
 20%|##        | 461/2278 [01:07<04:20,  6.98it/s, loss=0.659]
 20%|##        | 462/2278 [01:07<04:26,  6.81it/s, loss=0.652]
 20%|##        | 463/2278 [01:07<04:38,  6.53it/s, loss=0.649]
 20%|##        | 464/2278 [01:07<04:40,  6.48it/s, loss=0.657]
 20%|##        | 465/2278 [01:08<04:33,  6.63it/s, loss=0.658]
 20%|##        | 466/2278 [01:08<04:35,  6.58it/s, loss=0.653]
 21%|##        | 467/2278 [01:08<04:22,  6.90it/s, loss=0.651]
 21%|##        | 468/2278 [01:08<04:36,  6.55it/s, loss=0.649]
 21%|##        | 469/2278 [01:08<04:29,  6.71it/s, loss=0.652]
 21%|##        | 470/2278 [01:08<04:17,  7.01it/s, loss=0.650]
 21%|##        | 471/2278 [01:08<04:10,  7.22it/s, loss=0.646]
 21%|##        | 472/2278 [01:09<04:14,  7.09it/s, loss=0.651]
 21%|##        | 473/2278 [01:09<04:20,  6.92it/s, loss=0.653]
 21%|##        | 474/2278 [01:09<04:12,  7.13it/s, loss=0.647]
 21%|##        | 475/2278 [01:09<04:21,  6.90it/s, loss=0.659]
 21%|##        | 476/2278 [01:09<04:25,  6.78it/s, loss=0.653]
 21%|##        | 477/2278 [01:09<04:20,  6.91it/s, loss=0.650]
 21%|##        | 478/2278 [01:09<04:27,  6.74it/s, loss=0.655]
 21%|##1       | 479/2278 [01:10<04:28,  6.69it/s, loss=0.656]
 21%|##1       | 480/2278 [01:10<04:17,  6.97it/s, loss=0.648]
 21%|##1       | 481/2278 [01:10<04:16,  7.01it/s, loss=0.653]
 21%|##1       | 482/2278 [01:10<04:16,  7.01it/s, loss=0.654]
 21%|##1       | 483/2278 [01:10<04:15,  7.02it/s, loss=0.655]
 21%|##1       | 484/2278 [01:10<04:21,  6.86it/s, loss=0.653]
 21%|##1       | 485/2278 [01:10<04:19,  6.92it/s, loss=0.652]
 21%|##1       | 486/2278 [01:11<04:27,  6.70it/s, loss=0.658]
 21%|##1       | 487/2278 [01:11<04:28,  6.67it/s, loss=0.660]
 21%|##1       | 488/2278 [01:11<04:23,  6.78it/s, loss=0.654]
 21%|##1       | 489/2278 [01:11<04:27,  6.69it/s, loss=0.653]
 22%|##1       | 490/2278 [01:11<04:23,  6.78it/s, loss=0.653]
 22%|##1       | 491/2278 [01:11<04:20,  6.87it/s, loss=0.654]
 22%|##1       | 492/2278 [01:12<04:17,  6.93it/s, loss=0.652]
 22%|##1       | 493/2278 [01:12<04:23,  6.79it/s, loss=0.653]
 22%|##1       | 494/2278 [01:12<04:19,  6.87it/s, loss=0.652]
 22%|##1       | 495/2278 [01:12<04:28,  6.64it/s, loss=0.653]
 22%|##1       | 496/2278 [01:12<04:18,  6.90it/s, loss=0.648]
 22%|##1       | 497/2278 [01:12<04:14,  6.99it/s, loss=0.656]
 22%|##1       | 498/2278 [01:12<04:14,  6.99it/s, loss=0.662]
 22%|##1       | 499/2278 [01:13<04:14,  7.00it/s, loss=0.653]
Converges at iteration 10
Epoch 0 Validation Accuracy 0.07657975099835565 Test Accuracy 0.059091002613007426

 22%|##1       | 499/2278 [01:31<05:27,  5.43it/s, loss=0.651]

Evaluating Performance with Link Prediction (Optional)

In practice, it is more common to evaluate a link prediction model by checking whether it can predict new edges. Evaluation metrics for this include AUC (the area under the ROC curve) and ranking metrics from information retrieval such as Hits@K and mean reciprocal rank (MRR). Ultimately, all of them require the model to assign a scalar score to each node pair in a set of candidate pairs.
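As a rough illustration of the information-retrieval side, the following minimal sketch computes two ranking metrics, mean reciprocal rank (MRR) and Hits@K, in plain Python. It assumes that each positive pair is ranked against its own list of negative-pair scores; rank_of_positive, mrr and hits_at_k are hypothetical helpers, not DGL or OGB APIs, and counting ties against the positive is one convention among several that evaluators use.

```python
from typing import Sequence


def rank_of_positive(pos_score: float, neg_scores: Sequence[float]) -> int:
    """1-based rank of a positive score among its negatives.

    Ties are counted against the positive (a pessimistic convention)."""
    return 1 + sum(1 for s in neg_scores if s >= pos_score)


def mrr(pos_scores: Sequence[float], neg_scores: Sequence[Sequence[float]]) -> float:
    """Mean reciprocal rank; neg_scores[i] holds the negative scores for query i."""
    ranks = [rank_of_positive(p, n) for p, n in zip(pos_scores, neg_scores)]
    return sum(1.0 / r for r in ranks) / len(ranks)


def hits_at_k(
    pos_scores: Sequence[float], neg_scores: Sequence[Sequence[float]], k: int
) -> float:
    """Fraction of queries whose positive pair is ranked within the top k."""
    ranks = [rank_of_positive(p, n) for p, n in zip(pos_scores, neg_scores)]
    return sum(1 for r in ranks if r <= k) / len(ranks)


# Toy scores: three queries, each positive ranked against three negatives.
pos = [0.9, 0.4, 0.1]
neg = [[0.2, 0.3, 0.1], [0.5, 0.3, 0.2], [0.6, 0.7, 0.8]]
print(mrr(pos, neg), hits_at_k(pos, neg, 1), hits_at_k(pos, neg, 3))
```

In this toy case the positives land at ranks 1, 2 and 4, so MRR = (1 + 1/2 + 1/4) / 3 ≈ 0.583, Hits@1 = 1/3 and Hits@3 = 2/3.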

Assume that you have the following test set with labels: test_pos_src and test_pos_dst hold ground-truth node pairs that are connected by an edge (positive pairs), and test_neg_src and test_neg_dst hold ground-truth node pairs that are not connected (negative pairs).

# Positive pairs
# These are randomly generated as an example.  You will need to
# replace them with your own ground truth.
n_test_pos = 1000
test_pos_src, test_pos_dst = (
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
    torch.randint(0, graph.num_nodes(), (n_test_pos,)),
)
# Negative pairs.  Likewise, you will need to replace them with your
# own ground truth.
test_neg_src = test_pos_src
test_neg_dst = torch.randint(0, graph.num_nodes(), (n_test_pos,))
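The example negatives above are drawn uniformly at random, so a few of them may coincide with real edges. One common remedy is rejection sampling: redraw any destination that would form an existing edge (or a self-pair). The sketch below assumes the known edges fit in a Python set of (src, dst) tuples; for a DGLGraph such a set could be built from graph.edges(), for example src, dst = graph.edges() followed by set(zip(src.tolist(), dst.tolist())). sample_true_negatives is a hypothetical helper, not part of DGL.

```python
import random
from typing import List, Set, Tuple


def sample_true_negatives(
    num_nodes: int,
    edges: Set[Tuple[int, int]],
    src: List[int],
    seed: int = 0,
    max_tries: int = 100,
) -> List[int]:
    """Hypothetical helper: for each source u, draw a destination v that is
    neither u itself nor an existing edge (u, v), by rejection sampling."""
    rng = random.Random(seed)
    dst: List[int] = []
    for u in src:
        for _ in range(max_tries):
            v = rng.randrange(num_nodes)
            # Reject self-pairs and pairs that are real edges, then redraw.
            if v != u and (u, v) not in edges:
                dst.append(v)
                break
        else:
            # for/else: runs only if the inner loop never hit `break`.
            raise RuntimeError(f"no non-edge found for node {u} after {max_tries} draws")
    return dst


# Toy usage: a 5-node graph stored as a set of directed (src, dst) edges.
# For an undirected graph, put both (u, v) and (v, u) in the set.
edges = {(0, 1), (0, 2), (0, 3), (1, 2)}
src = [0, 0, 1, 3]
neg_dst = sample_true_negatives(5, edges, src, seed=42)
print(list(zip(src, neg_dst)))
```

Rejection sampling stays cheap as long as the graph is sparse, because most random pairs are non-edges; max_tries only guards against nodes that are connected to almost every other node.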

First, compute the representations of all nodes with the inference function defined earlier:

node_reprs = inference(model, graph, node_features)

Since the predictor is a dot product, you can score the positive and negative test pairs directly and then compute metrics such as AUC:

import sklearn.metrics

# Look up the learned representations of both endpoints of every test pair.
h_pos_src = node_reprs[test_pos_src]
h_pos_dst = node_reprs[test_pos_dst]
h_neg_src = node_reprs[test_neg_src]
h_neg_dst = node_reprs[test_neg_dst]
# Row-wise dot product h_u^T h_v for each pair (no sigmoid is applied).
score_pos = (h_pos_src * h_pos_dst).sum(1)
score_neg = (h_neg_src * h_neg_dst).sum(1)
# Concatenate all scores; label positive pairs 1 and negative pairs 0.
test_preds = torch.cat([score_pos, score_neg]).cpu().numpy()
test_labels = (
    torch.cat([torch.ones_like(score_pos), torch.zeros_like(score_neg)])
    .cpu()
    .numpy()
)

auc = sklearn.metrics.roc_auc_score(test_labels, test_preds)
print("Link Prediction AUC:", auc)

Out:

Link Prediction AUC: 0.48600099999999996
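The AUC printed above is close to 0.5, the value an uninformative ranking gets on average, which is expected here because the example test pairs are random rather than real held-out edges. To make explicit what roc_auc_score measures for binary labels, the sketch below computes ROC AUC directly as the fraction of (positive, negative) pairs in which the positive scores higher, counting ties as one half. It also illustrates why the raw dot products can be passed in without the sigmoid from σ(h_u^T h_v): the sigmoid is strictly increasing, so it preserves the ordering of the scores (barring floating-point saturation at extreme values), and AUC depends only on that ordering. pairwise_auc is a hypothetical helper whose O(P·N) loop is meant for illustration, not for large test sets.

```python
import math
import random
from typing import Sequence


def pairwise_auc(pos_scores: Sequence[float], neg_scores: Sequence[float]) -> float:
    """ROC AUC as the fraction of (positive, negative) pairs ordered correctly.

    A tie counts as half a correct ordering. O(P * N); for illustration only."""
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


# Toy raw dot-product scores for positive and negative pairs.
toy_pos = [2.0, 0.5, -1.0]
toy_neg = [0.0, -2.0, 1.0]
auc_raw = pairwise_auc(toy_pos, toy_neg)
auc_prob = pairwise_auc([sigmoid(s) for s in toy_pos], [sigmoid(s) for s in toy_neg])
print(auc_raw, auc_prob)  # identical: the sigmoid does not change the ordering

# Scores that carry no information about the labels give an AUC near 0.5.
rng = random.Random(0)
noise_pos = [rng.gauss(0.0, 1.0) for _ in range(500)]
noise_neg = [rng.gauss(0.0, 1.0) for _ in range(500)]
print(pairwise_auc(noise_pos, noise_neg))
```

For the toy scores, 6 of the 9 positive/negative pairs are ordered correctly, so both calls give 6/9 ≈ 0.667.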

Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE for link prediction with neighbor sampling.

# Thumbnail credits: Link Prediction with Neo4j, Mark Needham

Total running time of the script: 1 minute 36.670 seconds


© Copyright 2018, DGL Team. Revision d9da4205.
