8์ฅ: Mixed Precision ํ์ต๏
DGL์ mixed precision ํ์ต์ ์ํด์ PyTorchโs automatic mixed precision package ์ ํธํ๋๋ค. ๋ฐ๋ผ์, ํ์ต ์๊ฐ ๋ฐ GPU ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ ์ฝํ ์ ์๋ค.
Half precision์ ์ฌ์ฉํ ๋ฉ์์ง ์ ๋ฌ๏
fp16์ ์ง์ํ๋ DGL์ UDF(User Defined Function)์ด๋ ๋นํธ์ธ ํจ์(์, dgl.function.sum
,
dgl.function.copy_u
)๋ฅผ ์ฌ์ฉํด์ float16
ํผ์ณ์ ๋ํ ๋ฉ์์ง ์ ๋ฌ์ ํ์ฉํ๋ค.
๋ค์ ์์ ๋ DGL ๋ฉ์์ง ์ ๋ฌ API๋ฅผ half-precision ํผ์ณ๋ค์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋ค.
>>> import torch
>>> import dgl
>>> import dgl.function as fn
>>> g = dgl.rand_graph(30, 100).to(0) # Create a graph on GPU w/ 30 nodes and 100 edges.
>>> g.ndata['h'] = torch.rand(30, 16).to(0).half() # Create fp16 node features.
>>> g.edata['w'] = torch.rand(100, 1).to(0).half() # Create fp16 edge features.
>>> # Use DGL's built-in functions for message passing on fp16 features.
>>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))
>>> g.ndata['x'][0]
tensor([0.3391, 0.2208, 0.7163, 0.6655, 0.7031, 0.5854, 0.9404, 0.7720, 0.6562,
0.4028, 0.6943, 0.5908, 0.9307, 0.5962, 0.7827, 0.5034],
device='cuda:0', dtype=torch.float16)
>>> g.apply_edges(fn.u_dot_v('h', 'x', 'hx'))
>>> g.edata['hx'][0]
tensor([5.4570], device='cuda:0', dtype=torch.float16)
>>> # Use UDF(User Defined Functions) for message passing on fp16 features.
>>> def message(edges):
... return {'m': edges.src['h'] * edges.data['w']}
...
>>> def reduce(nodes):
... return {'y': torch.sum(nodes.mailbox['m'], 1)}
...
>>> def dot(edges):
... return {'hy': (edges.src['h'] * edges.dst['y']).sum(-1, keepdims=True)}
...
>>> g.update_all(message, reduce)
>>> g.ndata['y'][0]
tensor([0.3394, 0.2209, 0.7168, 0.6655, 0.7026, 0.5854, 0.9404, 0.7720, 0.6562,
0.4028, 0.6943, 0.5908, 0.9307, 0.5967, 0.7827, 0.5039],
device='cuda:0', dtype=torch.float16)
>>> g.apply_edges(dot)
>>> g.edata['hy'][0]
tensor([5.4609], device='cuda:0', dtype=torch.float16)
End-to-End Mixed Precision ํ์ต๏
DGL์ PyTorch์ AMP package๋ฅผ ์ฌ์ฉํด์ mixed precision ํ์ต์ ๊ตฌํํ๊ณ ์์ด์, ์ฌ์ฉ ๋ฐฉ๋ฒ์ PyTorch์ ๊ฒ ๊ณผ ๋์ผํ๋ค.
GNN ๋ชจ๋ธ์ forward ํจ์ค(loss ๊ณ์ฐ ํฌํจ)๋ฅผ torch.cuda.amp.autocast()
๋ก ๋ํํ๋ฉด PyTorch๋ ๊ฐ op ๋ฐ ํ
์์ ๋ํด์ ์ ์ ํ ๋ฐ์ดํฐ ํ์
์ ์๋์ผ๋ก ์ ํํ๋ค. Half precision ํ
์๋ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ ์ด๊ณ , half precision ํ
์์ ๋ํ ๋๋ถ๋ถ ์ฐ์ฐ๋ค์ GPU tensorcore๋ค์ ํ์ฉํ๊ธฐ ๋๋ฌธ์ ๋ ๋น ๋ฅด๋ค.
float16
ํฌ๋ฉง์ ์์ graident๋ค์ ์ธ๋ํ๋ก์ฐ(underflow) ๋ฌธ์ ๋ฅผ ๊ฐ๋๋ฐ (0์ด ๋๋ฒ๋ฆผ), PyTorch๋ ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์ GradScaler
๋ชจ๋์ ์ ๊ณตํ๋ค. GradScaler
๋ loss ๊ฐ์ factor๋ฅผ ๊ณฑํ๊ณ , ์ด scaled loss์ backward pass๋ฅผ ์ํํ๋ค. ๊ทธ๋ฆฌ๊ณ ํ๋ผ๋ฉํฐ๋ค์ ์
๋ฐ์ดํธํ๋ optimizer๋ฅผ ์ํํ๊ธฐ ์ ์ unscale ํ๋ค.
๋ค์์ 3-๋ ์ด์ด GAT๋ฅผ Reddit ๋ฐ์ดํฐ์
(1140์ต๊ฐ์ ์์ง๋ฅผ ๊ฐ๋)์ ํ์ต์ ํ๋ ์คํฌ๋ฆฝํธ์ด๋ค. use_fp16
๊ฐ ํ์ฑํ/๋นํ์ฑํ๋์์ ๋์ ์ฝ๋ ์ฐจ์ด๋ฅผ ์ดํด๋ณด์.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import dgl
from dgl.data import RedditDataset
from dgl.nn import GATConv
use_fp16 = True
class GAT(nn.Module):
def __init__(self,
in_feats,
n_hidden,
n_classes,
heads):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(GATConv(in_feats, n_hidden, heads[0], activation=F.elu))
self.layers.append(GATConv(n_hidden * heads[0], n_hidden, heads[1], activation=F.elu))
self.layers.append(GATConv(n_hidden * heads[1], n_classes, heads[2], activation=F.elu))
def forward(self, g, h):
for l, layer in enumerate(self.layers):
h = layer(g, h)
if l != len(self.layers) - 1:
h = h.flatten(1)
else:
h = h.mean(1)
return h
# Data loading
data = RedditDataset()
device = torch.device(0)
g = data[0]
g = dgl.add_self_loop(g)
g = g.int().to(device)
train_mask = g.ndata['train_mask']
features = g.ndata['feat']
labels = g.ndata['label']
in_feats = features.shape[1]
n_hidden = 256
n_classes = data.num_classes
n_edges = g.num_edges()
heads = [1, 1, 1]
model = GAT(in_feats, n_hidden, n_classes, heads)
model = model.to(device)
# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
# Create gradient scaler
scaler = GradScaler()
for epoch in range(100):
model.train()
optimizer.zero_grad()
# Wrap forward pass with autocast
with autocast(enabled=use_fp16):
logits = model(g, features)
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
if use_fp16:
# Backprop w/ gradient scaling
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
print('Epoch {} | Loss {}'.format(epoch, loss.item()))
NVIDIA V100 (16GB) ํ๊ฐ๋ฅผ ๊ฐ๋ ์ปดํจํฐ์์, ์ด ๋ชจ๋ธ์ fp16์ ์ฌ์ฉํ์ง ์๊ณ ํ์ตํ ๋๋ 15.2GB GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ฌ์ฉ๋๋๋ฐ, fp16์ ํ์ฑํํ๋ฉด, ํ์ต์ 12.8G GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ฌ์ฉ๋๋ฉฐ, ๋ ๊ฒฝ์ฐ loss๊ฐ ๋น์ทํ ๊ฐ์ผ๋ก ์๋ ดํ๋ค. ๋ง์ฝ head์ ๊ฐฏ์๋ฅผ [2, 2, 2]
๋ก ๋ฐ๊พธ๋ฉด, fp16๋ฅผ ์ฌ์ฉํ์ง ์๋ ํ์ต์ GPU OOM(out-of-memory) ์ด์๊ฐ ์๊ธธ ๊ฒ์ด์ง๋ง, fp16๋ฅผ ์ฌ์ฉํ ํ์ต์ 15.7G GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ฉด์ ์ํ๋๋ค.
DGL์ half-precision ์ง์์ ๊ณ์ ํฅ์ํ๊ณ ์๊ณ , ์ฐ์ฐ ์ปค๋์ ์ฑ๋ฅ์ ์์ง ์ต์ ์ ์๋๋ค. ์์ผ๋ก์ ์ ๋ฐ์ดํธ๋ฅผ ๊ณ์ ์ง์ผ๋ณด์.