"""Predictor for edges in homogeneous graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import torch
import torch.nn as nn
import torch.nn.functional as F

class EdgePredictor(nn.Module):
    r"""Predictor/score function for pairs of node representations

    Given a pair of node representations, :math:`h_i` and :math:`h_j`, it
    combines them with

    **dot product**

    .. math::

        h_i^{T} h_j

    or **cosine similarity**

    .. math::

        \frac{h_i^{T} h_j}{{\| h_i \|}_2 \cdot {\| h_j \|}_2}

    or **elementwise product**

    .. math::

        h_i \odot h_j

    or **concatenation**

    .. math::

        h_i \Vert h_j

    Optionally, it passes the combined results to a linear layer for the
    final prediction.

    Parameters
    ----------
    op : str
        The operation to apply. It can be 'dot', 'cos', 'ele', or 'cat',
        corresponding to the equations above in order.
    in_feats : int, optional
        The input feature size of :math:`h_i` and :math:`h_j`. It is required
        only if a linear layer is to be applied.
    out_feats : int, optional
        The output feature size. It is required only if a linear layer is to
        be applied.
    bias : bool, optional
        Whether to use bias for the linear layer if it applies.

    Notes
    -----
    The linear layer is created only when **both** ``in_feats`` and
    ``out_feats`` are given. If either one is ``None``, the combined result is
    returned without any projection.

    Examples
    --------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import EdgePredictor
    >>> num_nodes = 2
    >>> num_edges = 3
    >>> in_feats = 4
    >>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)
    >>> h = th.randn(num_nodes, in_feats)
    >>> src, dst = g.edges()
    >>> h_src = h[src]
    >>> h_dst = h[dst]

    Case1: dot product

    >>> predictor = EdgePredictor('dot')
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 1])
    >>> predictor = EdgePredictor('dot', in_feats, out_feats=3)
    >>> predictor.reset_parameters()
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 3])

    Case2: cosine similarity

    >>> predictor = EdgePredictor('cos')
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 1])
    >>> predictor = EdgePredictor('cos', in_feats, out_feats=3)
    >>> predictor.reset_parameters()
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 3])

    Case3: elementwise product

    >>> predictor = EdgePredictor('ele')
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 4])
    >>> predictor = EdgePredictor('ele', in_feats, out_feats=3)
    >>> predictor.reset_parameters()
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 3])

    Case4: concatenation

    >>> predictor = EdgePredictor('cat')
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 8])
    >>> predictor = EdgePredictor('cat', in_feats, out_feats=3)
    >>> predictor.reset_parameters()
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 3])
    """

    def __init__(self, op, in_feats=None, out_feats=None, bias=False):
        super().__init__()

        assert op in [
            "dot",
            "cos",
            "ele",
            "cat",
        ], "Expect op to be in ['dot', 'cos', 'ele', 'cat'], got {}".format(op)
        self.op = op

        if in_feats is not None and out_feats is not None:
            # The width fed to the linear layer depends on the combine op:
            # 'dot'/'cos' reduce each pair to a single scalar, 'cat' doubles
            # the feature width, and 'ele' keeps it unchanged.
            if op in ["dot", "cos"]:
                in_feats = 1
            elif op == "cat":
                in_feats = 2 * in_feats
            self.linear = nn.Linear(in_feats, out_feats, bias=bias)
        else:
            self.linear = None

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        This is a no-op when no linear layer was created.
        """
        if self.linear is not None:
            self.linear.reset_parameters()

    def forward(self, h_src, h_dst):
        r"""

        Description
        -----------
        Predict for pairs of node representations.

        Parameters
        ----------
        h_src : torch.Tensor
            Source node features. The tensor is typically of shape
            :math:`(E, D_{in})`, where :math:`E` is the number of edges/node
            pairs, and :math:`D_{in}` is the input feature size. More
            generally, the last dimension is the feature dimension and any
            leading dimensions are treated as batch dimensions.
        h_dst : torch.Tensor
            Destination node features, with the same layout as ``h_src``.
            For 'dot', 'cos' and 'ele' the leading dimensions only need to be
            broadcastable with those of ``h_src``. For 'cat' they must match.

        Returns
        -------
        torch.Tensor
            The output features. The last dimension has size 1 for 'dot' and
            'cos', :math:`D_{in}` for 'ele', and :math:`2 D_{in}` for 'cat'.
            It has size ``out_feats`` instead when a linear layer is applied.
        """
        if self.op == "dot":
            # Reduce over the feature axis only, so inputs with extra leading
            # (batch) dimensions or broadcastable shapes work too. For
            # (E, D) inputs this equals the per-row inner product.
            h = (h_src * h_dst).sum(dim=-1, keepdim=True)
        elif self.op == "cos":
            # Pin the similarity to the feature axis. The default dim=1 is
            # only correct for 2-D inputs.
            h = F.cosine_similarity(h_src, h_dst, dim=-1).unsqueeze(-1)
        elif self.op == "ele":
            h = h_src * h_dst
        else:
            h = torch.cat([h_src, h_dst], dim=-1)

        if self.linear is not None:
            h = self.linear(h)

        return h