# Source code for dgl.nn.pytorch.conv.hgtconv

"""Heterogeneous Graph Transformer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import torch
import torch.nn as nn

from .... import function as fn
from ..linear import TypedLinear
from ..softmax import edge_softmax

class HGTConv(nn.Module):
    r"""Heterogeneous graph transformer convolution from `Heterogeneous Graph Transformer
    <https://arxiv.org/abs/2003.01332>`__

Given a graph :math:G(V, E) and input node features :math:H^{(l-1)},
it computes the new node features as follows:

    Compute a multi-head attention score for each edge :math:`(s, e, t)` in the graph:

    .. math::

      Attention(s, e, t) = \text{Softmax}\left(||_{i\in[1,h]}ATT-head^i(s, e, t)\right) \\
      ATT-head^i(s, e, t) = \left(K^i(s)W^{ATT}_{\phi(e)}Q^i(t)^T\right)\cdot
        \frac{\mu_{(\tau(s),\phi(e),\tau(t))}}{\sqrt{d}} \\
      K^i(s) = \text{K-Linear}^i_{\tau(s)}(H^{(l-1)}[s]) \\
      Q^i(t) = \text{Q-Linear}^i_{\tau(t)}(H^{(l-1)}[t]) \\

Compute the message to send on each edge :math:(s, e, t):

.. math::

Message(s, e, t) = ||_{i\in[1, h]} MSG-head^i(s, e, t) \\
MSG-head^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s])W^{MSG}_{\phi(e)} \\

Send messages to target nodes :math:t and aggregate:

.. math::

\tilde{H}^{(l)}[t] = \sum_{\forall s\in \mathcal{N}(t)}\left( Attention(s,e,t)
\cdot Message(s,e,t)\right)

Compute new node features:

    .. math::

      H^{(l)}[t]=\text{A-Linear}_{\tau(t)}(\sigma(\tilde{H}^{(l)}[t])) + H^{(l-1)}[t]

    Parameters
    ----------
    in_size : int
        Input node feature size.
    head_size : int
        Output head size. The output node feature size is ``head_size * num_heads``.
    num_heads : int
        Number of heads. The output node feature size is ``head_size * num_heads``.
    num_ntypes : int
        Number of node types.
    num_etypes : int
        Number of edge types.
    dropout : optional, float
        Dropout rate.
    use_norm : optional, bool
        If true, apply a layer norm on the output node feature.

Examples
--------
"""
def __init__(self,
in_size,
num_ntypes,
num_etypes,
dropout=0.2,
use_norm=False):
super().__init__()
self.in_size = in_size
self.use_norm = use_norm

self.relation_pri = nn.ParameterList([nn.Parameter(torch.ones(num_etypes))
self.skip = nn.Parameter(torch.ones(num_ntypes))
self.drop = nn.Dropout(dropout)
if use_norm:
nn.init.xavier_uniform_(self.residual_w)

[docs]    def forward(self, g, x, ntype, etype, *, presorted=False):
"""Forward computation.

Parameters
----------
g : DGLGraph
The input graph.
x : torch.Tensor
A 2D tensor of node features. Shape: :math:(|V|, D_{in}).
ntype : torch.Tensor
An 1D integer tensor of node types. Shape: :math:(|V|,).
etype : torch.Tensor
An 1D integer tensor of edge types. Shape: :math:(|E|,).
presorted : bool, optional
Whether *both* the nodes and the edges of the input graph have been sorted by
their types. Forward on pre-sorted graph may be faster. Graphs created by
:func:~dgl.to_homogeneous automatically satisfy the condition.
Also see :func:~dgl.reorder_graph for manually reordering the nodes and edges.

Returns
-------
torch.Tensor
New node features. Shape: :math:(|V|, D_{head} * N_{head}).
"""
self.presorted = presorted
with g.local_scope():
g.srcdata['k'] = k
g.dstdata['q'] = q
g.srcdata['v'] = v
g.edata['etype'] = etype
g.apply_edges(self.message)
g.edata['m'] = g.edata['m'] * edge_softmax(g, g.edata['a']).unsqueeze(-1)
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'h'))
# target-specific aggregation
h = self.drop(self.linear_a(h, ntype, presorted))
alpha = torch.sigmoid(self.skip[ntype]).unsqueeze(-1)
if x.shape != h.shape:
h = h * alpha + (x @ self.residual_w) * (1 - alpha)
else:
h = h * alpha + x * (1 - alpha)
if self.use_norm:
h = self.norm(h)
return h

def message(self, edges):
"""Message function."""
a, m = [], []
etype = edges.data['etype']
k = torch.unbind(edges.src['k'], dim=1)
q = torch.unbind(edges.dst['q'], dim=1)
v = torch.unbind(edges.src['v'], dim=1)