# Source code for dgl.nn.pytorch.conv.hgtconv

"""Heterogeneous Graph Transformer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import torch
import torch.nn as nn

from .... import function as fn
from ..linear import TypedLinear
from ..softmax import edge_softmax

class HGTConv(nn.Module):
    r"""Heterogeneous graph transformer convolution from `Heterogeneous Graph Transformer
    <https://arxiv.org/abs/2003.01332>`__

    Given a graph :math:`G(V, E)` and input node features :math:`H^{(l-1)}`,
    it computes the new node features as follows:

    Compute a multi-head attention score for each edge :math:`(s, e, t)` in the graph:

    .. math::

      Attention(s, e, t) = \text{Softmax}\left(||_{i\in[1,h]}ATT-head^i(s, e, t)\right) \\
      ATT-head^i(s, e, t) = \left(K^i(s)W^{ATT}_{\phi(e)}Q^i(t)^{\top}\right)\cdot
        \frac{\mu_{(\tau(s),\phi(e),\tau(t))}}{\sqrt{d}} \\
      K^i(s) = \text{K-Linear}^i_{\tau(s)}(H^{(l-1)}[s]) \\
      Q^i(t) = \text{Q-Linear}^i_{\tau(t)}(H^{(l-1)}[t]) \\

    Compute the message to send on each edge :math:`(s, e, t)`:

    .. math::

      Message(s, e, t) = ||_{i\in[1, h]} MSG-head^i(s, e, t) \\
      MSG-head^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s])W^{MSG}_{\phi(e)} \\

    Send messages to target nodes :math:`t` and aggregate:

    .. math::

      \tilde{H}^{(l)}[t] = \sum_{\forall s\in \mathcal{N}(t)}\left( Attention(s,e,t)
      \cdot Message(s,e,t)\right)

    Compute new node features with a learnable, per-node-type gated residual:

    .. math::

      H^{(l)}[t] = \alpha_{\tau(t)} \cdot \text{A-Linear}_{\tau(t)}(\tilde{H}^{(l)}[t])
        + (1 - \alpha_{\tau(t)}) \cdot H^{(l-1)}[t], \qquad
      \alpha_{\tau(t)} = \text{sigmoid}(\text{skip}_{\tau(t)})

    Dropout is applied to the A-Linear output. When ``in_size`` differs from
    ``head_size * num_heads``, :math:`H^{(l-1)}[t]` is first multiplied by a learnable
    projection matrix so that the residual has the output size.

    Parameters
    ----------
    in_size : int
        Input node feature size.
    head_size : int
        Output head size. The output node feature size is ``head_size * num_heads``.
    num_heads : int
        Number of heads. The output node feature size is ``head_size * num_heads``.
    num_ntypes : int
        Number of node types.
    num_etypes : int
        Number of edge types.
    dropout : float, optional
        Dropout rate applied to the A-Linear output. Default: ``0.2``.
    use_norm : bool, optional
        If True, apply a layer norm on the output node feature. Default: ``False``.
    """
    def __init__(self,
                 in_size,
                 head_size,
                 num_heads,
                 num_ntypes,
                 num_etypes,
                 dropout=0.2,
                 use_norm=False):
        super().__init__()
        self.in_size = in_size
        self.head_size = head_size
        self.num_heads = num_heads
        self.sqrt_d = math.sqrt(head_size)
        self.use_norm = use_norm
        # Set by ``forward`` and read by ``message``; initialised here so that
        # ``message`` never hits a missing attribute.
        self.presorted = False

        out_size = head_size * num_heads
        # Node-type-specific projections (K/Q/M-Linear and A-Linear in the docstring).
        self.linear_k = TypedLinear(in_size, out_size, num_ntypes)
        self.linear_q = TypedLinear(in_size, out_size, num_ntypes)
        self.linear_v = TypedLinear(in_size, out_size, num_ntypes)
        self.linear_a = TypedLinear(out_size, out_size, num_ntypes)

        # Per-head, edge-type-specific prior (mu), attention (W^ATT) and message (W^MSG).
        self.relation_pri = nn.ParameterList(
            [nn.Parameter(torch.ones(num_etypes)) for _ in range(num_heads)])
        self.relation_att = nn.ModuleList(
            [TypedLinear(head_size, head_size, num_etypes) for _ in range(num_heads)])
        self.relation_msg = nn.ModuleList(
            [TypedLinear(head_size, head_size, num_etypes) for _ in range(num_heads)])
        # Per-node-type residual gate logits; sigmoid(skip) weights the new features.
        self.skip = nn.Parameter(torch.ones(num_ntypes))
        self.drop = nn.Dropout(dropout)
        if use_norm:
            self.norm = nn.LayerNorm(out_size)
        if in_size != out_size:
            # Projects the input onto the output size for the residual connection.
            self.residual_w = nn.Parameter(torch.empty(in_size, out_size))
            nn.init.xavier_uniform_(self.residual_w)

    def forward(self, g, x, ntype, etype, *, presorted=False):
        """Forward computation.

        Parameters
        ----------
        g : DGLGraph
            The input graph. It may also be a block (message flow graph). In that
            case ``x`` and ``ntype`` describe its source nodes, and the destination
            nodes are taken to be the first ``g.num_dst_nodes()`` source nodes.
        x : torch.Tensor
            A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.
        ntype : torch.Tensor
            A 1D integer tensor of node types. Shape: :math:`(|V|,)`.
        etype : torch.Tensor
            A 1D integer tensor of edge types. Shape: :math:`(|E|,)`.
        presorted : bool, optional
            Whether *both* the nodes and the edges of the input graph have been sorted by
            their types. Forward on a pre-sorted graph may be faster. Graphs created by
            :func:`~dgl.to_homogeneous` automatically satisfy the condition.
            Also see :func:`~dgl.reorder_graph` for manually reordering the nodes and edges.

        Returns
        -------
        torch.Tensor
            New node features of the destination nodes (all nodes for a non-block
            graph). Shape: :math:`(|V_{dst}|, D_{head} * N_{head})`.
        """
        # ``message`` only receives the edge batch from ``apply_edges``, so the flag
        # is handed over through the module instance.
        self.presorted = presorted
        if g.is_block:
            # Destination nodes of a block are a prefix of its source nodes.
            num_dst = g.num_dst_nodes()
            x_src, x_dst = x, x[:num_dst]
            srcntype, dstntype = ntype, ntype[:num_dst]
        else:
            x_src = x_dst = x
            srcntype = dstntype = ntype

        with g.local_scope():
            k = self.linear_k(x_src, srcntype, presorted).view(-1, self.num_heads, self.head_size)
            q = self.linear_q(x_dst, dstntype, presorted).view(-1, self.num_heads, self.head_size)
            v = self.linear_v(x_src, srcntype, presorted).view(-1, self.num_heads, self.head_size)
            g.srcdata['k'] = k
            g.dstdata['q'] = q
            g.srcdata['v'] = v
            g.edata['etype'] = etype
            g.apply_edges(self.message)
            # Normalise the per-head scores across each node's incoming edges.
            g.edata['m'] = g.edata['m'] * edge_softmax(g, g.edata['a']).unsqueeze(-1)
            g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'h'))
            h = g.dstdata['h'].view(-1, self.num_heads * self.head_size)
            # target-specific aggregation
            h = self.drop(self.linear_a(h, dstntype, presorted))
            alpha = torch.sigmoid(self.skip[dstntype]).unsqueeze(-1)
            if x_dst.shape != h.shape:
                # Only reachable when in_size != head_size * num_heads, i.e. when
                # ``residual_w`` was created in ``__init__``.
                h = h * alpha + (x_dst @ self.residual_w) * (1 - alpha)
            else:
                h = h * alpha + x_dst * (1 - alpha)
            if self.use_norm:
                h = self.norm(h)
            return h

    def message(self, edges):
        """Compute per-head attention logits and messages for a batch of edges.

        Parameters
        ----------
        edges : EdgeBatch
            Edge batch supplied by ``g.apply_edges``. It reads ``edges.src['k']``,
            ``edges.src['v']`` and ``edges.dst['q']``, each viewed in ``forward`` as
            ``(E, num_heads, head_size)``, plus the edge types ``edges.data['etype']``.

        Returns
        -------
        dict
            ``'a'``: unnormalised attention scores, shape ``(E, num_heads)``.
            ``'m'``: messages, shape ``(E, num_heads, head_size)``.
        """
        etype = edges.data['etype']
        scores, messages = [], []
        per_head = zip(torch.unbind(edges.src['k'], dim=1),
                       torch.unbind(edges.dst['q'], dim=1),
                       torch.unbind(edges.src['v'], dim=1),
                       self.relation_att, self.relation_msg, self.relation_pri)
        for k_i, q_i, v_i, att, msg, pri in per_head:
            kw = att(k_i, etype, self.presorted)                              # (E, head_size)
            scores.append((kw * q_i).sum(-1) * pri[etype] / self.sqrt_d)      # (E,)
            messages.append(msg(v_i, etype, self.presorted))                  # (E, head_size)
        return {'a': torch.stack(scores, dim=1), 'm': torch.stack(messages, dim=1)}