##
# Copyright 2019-2021 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Modules for transform"""
# pylint: disable= no-member, arguments-differ, invalid-name, missing-function-docstring
from scipy.linalg import expm
from .. import convert
from .. import backend as F
from .. import function as fn
from ..base import DGLError
from . import functional
try:
import torch
from torch.distributions import Bernoulli
except ImportError:
pass
__all__ = [
'BaseTransform',
'RowFeatNormalizer',
'FeatMask',
'RandomWalkPE',
'LaplacianPE',
'AddSelfLoop',
'RemoveSelfLoop',
'AddReverse',
'ToSimple',
'LineGraph',
'KHopGraph',
'AddMetaPaths',
'Compose',
'GCNNorm',
'PPR',
'HeatKernel',
'GDC',
'NodeShuffle',
'DropNode',
'DropEdge',
'AddEdge',
'SIGNDiffusion'
]
def update_graph_structure(g, data_dict, copy_edata=True):
    r"""Rebuild a graph around a new edge structure while keeping its features.

    A heterograph is constructed from ``data_dict`` using the node types, node
    counts, ID type and device of ``g``. Every node feature of ``g`` is then
    attached to the new graph, and edge features are attached on request.

    Parameters
    ----------
    g : DGLGraph
        The graph to update.
    data_dict : graph data
        The dictionary data for constructing a heterogeneous graph.
    copy_edata : bool
        If True, it will copy the edge features to the updated graph.

    Returns
    -------
    DGLGraph
        The updated graph.
    """
    node_counts = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    rebuilt = convert.heterograph(
        data_dict, num_nodes_dict=node_counts, idtype=g.idtype, device=g.device)

    # Node features are always carried over.
    for ntype in g.ntypes:
        for name, value in g.nodes[ntype].data.items():
            rebuilt.nodes[ntype].data[name] = value

    if not copy_edata:
        return rebuilt

    # Edge features are carried over per canonical edge type of the input graph.
    # NOTE(review): presumably requires each edge type to keep its edge count in
    # ``data_dict`` -- confirm against callers.
    for c_etype in g.canonical_etypes:
        for name, value in g.edges[c_etype].data.items():
            rebuilt.edges[c_etype].data[name] = value
    return rebuilt
class RowFeatNormalizer(BaseTransform):
    r"""Row-normalizes the features given in ``node_feat_names`` and ``edge_feat_names``.

    The row normalization formula is:

    .. math::
      x = \frac{x}{\sum_i x_i}

    where :math:`x` denotes a row of the feature tensor.

    Row sums smaller than 1 are clamped to 1 before the division, so a row whose
    sum is below 1 (including an all-zero row) is returned unscaled.

    Parameters
    ----------
    subtract_min: bool
        If True, the minimum value of whole feature tensor will be subtracted before
        normalization. Default: False.
        Subtraction makes all values non-negative, so that after normalization each
        row whose sum is at least 1 sums to 1.
    node_feat_names : list[str], optional
        The names of the node feature tensors to be row-normalized. Default: `None`, which will
        not normalize any node feature tensor.
    edge_feat_names : list[str], optional
        The names of the edge feature tensors to be row-normalized. Default: `None`, which will
        not normalize any edge feature tensor.

    Example
    -------

    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch
    >>> from dgl import RowFeatNormalizer

    Case1: Row normalize features of a homogeneous graph.

    >>> transform = RowFeatNormalizer(subtract_min=True,
    ...                               node_feat_names=['h'], edge_feat_names=['w'])
    >>> g = dgl.rand_graph(5, 20)
    >>> g.ndata['h'] = torch.randn((g.num_nodes(), 5))
    >>> g.edata['w'] = torch.randn((g.num_edges(), 5))
    >>> g = transform(g)
    >>> print(g.ndata['h'].sum(1))
    tensor([1., 1., 1., 1., 1.])
    >>> print(g.edata['w'].sum(1))
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1.])

    Case2: Row normalize features of a heterogeneous graph.

    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
    ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
    ... })
    >>> g.ndata['h'] = {'game': torch.randn(2, 5), 'player': torch.randn(3, 5)}
    >>> g.edata['w'] = {
    ...     ('user', 'follows', 'user'): torch.randn(2, 5),
    ...     ('player', 'plays', 'game'): torch.randn(2, 5)
    ... }
    >>> g = transform(g)
    >>> print(g.ndata['h']['game'].sum(1), g.ndata['h']['player'].sum(1))
    tensor([1., 1.]) tensor([1., 1., 1.])
    >>> print(g.edata['w'][('user', 'follows', 'user')].sum(1),
    ...     g.edata['w'][('player', 'plays', 'game')].sum(1))
    tensor([1., 1.]) tensor([1., 1.])
    """

    def __init__(self, subtract_min=False, node_feat_names=None, edge_feat_names=None):
        self.node_feat_names = [] if node_feat_names is None else node_feat_names
        self.edge_feat_names = [] if edge_feat_names is None else edge_feat_names
        self.subtract_min = subtract_min

    def row_normalize(self, feat):
        r"""

        Description
        -----------
        Row-normalize the given feature.

        The input tensor is never modified; a new tensor is returned.

        Parameters
        ----------
        feat : Tensor
            The feature to be normalized.

        Returns
        -------
        Tensor
            The normalized feature.
        """
        if self.subtract_min:
            feat = feat - feat.min()
        # ``sum`` yields a fresh tensor, so clamping it in place is safe. The clamp
        # guards against division by zero for all-zero rows.
        row_sum = feat.sum(dim=-1, keepdim=True).clamp_(min=1.)
        # Divide out of place: an in-place ``div_`` here would overwrite the tensor
        # the caller stored in the graph (and any other holder of that tensor)
        # whenever ``subtract_min`` is False.
        return feat / row_sum

    def __call__(self, g):
        for node_feat_name in self.node_feat_names:
            # A single tensor is returned when the graph has one node type,
            # otherwise a dict keyed by node type.
            if isinstance(g.ndata[node_feat_name], torch.Tensor):
                g.ndata[node_feat_name] = self.row_normalize(g.ndata[node_feat_name])
            else:
                for ntype in g.ndata[node_feat_name].keys():
                    g.nodes[ntype].data[node_feat_name] = \
                        self.row_normalize(g.nodes[ntype].data[node_feat_name])

        for edge_feat_name in self.edge_feat_names:
            # Same single-tensor / per-type-dict distinction for edge features.
            if isinstance(g.edata[edge_feat_name], torch.Tensor):
                g.edata[edge_feat_name] = self.row_normalize(g.edata[edge_feat_name])
            else:
                for etype in g.edata[edge_feat_name].keys():
                    g.edges[etype].data[edge_feat_name] = \
                        self.row_normalize(g.edges[etype].data[edge_feat_name])

        return g
class FeatMask(BaseTransform):
    r"""Randomly mask columns of the node and edge feature tensors, as described in `Graph
    Contrastive Learning with Augmentations <https://arxiv.org/abs/2010.13902>`__.

    A separate column mask is drawn for every feature tensor that is processed, and
    the selected columns are zeroed in place.

    Parameters
    ----------
    p : float, optional
        Probability of masking a column of a feature tensor. Default: `0.5`.
    node_feat_names : list[str], optional
        The names of the node feature tensors to be masked. Default: `None`, which will
        not mask any node feature tensor.
    edge_feat_names : list[str], optional
        The names of the edge features to be masked. Default: `None`, which will not mask
        any edge feature tensor.

    Example
    -------

    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch
    >>> from dgl import FeatMask

    Case1 : Mask node and edge feature tensors of a homogeneous graph.

    >>> transform = FeatMask(node_feat_names=['h'], edge_feat_names=['w'])
    >>> g = dgl.rand_graph(5, 10)
    >>> g.ndata['h'] = torch.ones((g.num_nodes(), 10))
    >>> g.edata['w'] = torch.ones((g.num_edges(), 10))
    >>> g = transform(g)
    >>> print(g.ndata['h'])
    tensor([[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
            [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.]])
    >>> print(g.edata['w'])
    tensor([[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
            [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.]])

    Case2 : Mask node and edge feature tensors of a heterogeneous graph.

    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
    ...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
    ... })
    >>> g.ndata['h'] = {'game': torch.ones(2, 5), 'player': torch.ones(3, 5)}
    >>> g.edata['w'] = {('user', 'follows', 'user'): torch.ones(2, 5)}
    >>> print(g.ndata['h']['game'])
    tensor([[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]])
    >>> print(g.edata['w'][('user', 'follows', 'user')])
    tensor([[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]])
    >>> g = transform(g)
    >>> print(g.ndata['h']['game'])
    tensor([[1., 1., 0., 1., 0.],
            [1., 1., 0., 1., 0.]])
    >>> print(g.edata['w'][('user', 'follows', 'user')])
    tensor([[0., 1., 0., 1., 0.],
            [0., 1., 0., 1., 0.]])
    """

    def __init__(self, p=0.5, node_feat_names=None, edge_feat_names=None):
        self.p = p
        self.node_feat_names = [] if node_feat_names is None else node_feat_names
        self.edge_feat_names = [] if edge_feat_names is None else edge_feat_names
        # One Bernoulli(p) draw per feature column decides whether it is masked.
        self.dist = Bernoulli(p)

    def _zero_random_columns(self, feat, device):
        r"""Draw a fresh column mask for ``feat`` and zero the chosen columns in place."""
        num_columns = feat.shape[-1]
        drop = self.dist.sample(torch.Size([num_columns, ]))
        feat[:, drop.bool().to(device)] = 0

    def __call__(self, g):
        # Nothing can be masked with probability zero.
        if self.p == 0:
            return g

        for name in self.node_feat_names:
            # Single node type -> tensor; several node types -> dict keyed by type.
            if isinstance(g.ndata[name], torch.Tensor):
                self._zero_random_columns(g.ndata[name], g.device)
                continue
            for ntype in g.ndata[name].keys():
                self._zero_random_columns(g.ndata[name][ntype], g.device)

        for name in self.edge_feat_names:
            # Same tensor / per-type-dict distinction for edge features.
            if isinstance(g.edata[name], torch.Tensor):
                self._zero_random_columns(g.edata[name], g.device)
                continue
            for etype in g.edata[name].keys():
                self._zero_random_columns(g.edata[name][etype], g.device)

        return g
class RandomWalkPE(BaseTransform):
    r"""Random Walk Positional Encoding, as introduced in
    `Graph Neural Networks with Learnable Structural and Positional Representations
    <https://arxiv.org/abs/2110.07875>`__

    This module only works for homogeneous graphs.

    Parameters
    ----------
    k : int
        Number of random walk steps. The paper found the best value to be 16 and 20
        for two experiments.
    feat_name : str, optional
        Name to store the computed positional encodings in ndata.
    eweight_name : str, optional
        Name to retrieve the edge weights. Default: None, not using the edge weights.

    Example
    -------

    >>> import dgl
    >>> from dgl import RandomWalkPE

    >>> transform = RandomWalkPE(k=2)
    >>> g = dgl.graph(([0, 1, 1], [1, 1, 0]))
    >>> g = transform(g)
    >>> print(g.ndata['PE'])
    tensor([[0.0000, 0.5000],
            [0.5000, 0.7500]])
    """

    def __init__(self, k, feat_name='PE', eweight_name=None):
        self.k = k
        self.feat_name = feat_name
        self.eweight_name = eweight_name

    def __call__(self, g):
        # Compute the encoding, then place it on the graph's device before storing.
        encoding = functional.random_walk_pe(
            g, k=self.k, eweight_name=self.eweight_name)
        encoding = F.copy_to(encoding, g.device)
        g.ndata[self.feat_name] = encoding
        return g
class LaplacianPE(BaseTransform):
    r"""Laplacian Positional Encoding, as introduced in
    `Benchmarking Graph Neural Networks
    <https://arxiv.org/abs/2003.00982>`__

    This module only works for homogeneous bidirected graphs.

    Parameters
    ----------
    k : int
        Number of smallest non-trivial eigenvectors to use for positional encoding
        (smaller than the number of nodes).
    feat_name : str, optional
        Name to store the computed positional encodings in ndata.

    Example
    -------

    >>> import dgl
    >>> from dgl import LaplacianPE

    >>> transform = LaplacianPE(k=3)
    >>> g = dgl.rand_graph(5, 10)
    >>> g = transform(g)
    >>> print(g.ndata['PE'])
    tensor([[ 0.0000, -0.3646,  0.3646],
            [ 0.0000,  0.2825, -0.2825],
            [ 1.0000, -0.6315,  0.6315],
            [ 0.0000,  0.3739, -0.3739],
            [ 0.0000, -0.1663,  0.1663]])
    """

    def __init__(self, k, feat_name='PE'):
        self.k = k
        self.feat_name = feat_name

    def __call__(self, g):
        # Compute the encoding, then place it on the graph's device before storing.
        encoding = functional.laplacian_pe(g, k=self.k)
        encoding = F.copy_to(encoding, g.device)
        g.ndata[self.feat_name] = encoding
        return g
class AddSelfLoop(BaseTransform):
    r"""Add self-loops for each node in the graph and return a new graph.

    For heterogeneous graphs, self-loops are added only for edge types with same
    source and destination node types.

    Parameters
    ----------
    allow_duplicate : bool, optional
        If False, it will first remove self-loops to prevent duplicate self-loops.
    new_etypes : bool, optional
        If True, it will add an edge type 'self' per node type, which holds self-loops.

    Example
    -------

    >>> import dgl
    >>> from dgl import AddSelfLoop

    Case1: Add self-loops for a homogeneous graph

    >>> transform = AddSelfLoop()
    >>> g = dgl.graph(([1, 1], [1, 2]))
    >>> new_g = transform(g)
    >>> print(new_g.edges())
    (tensor([1, 0, 1, 2]), tensor([2, 0, 1, 2]))

    Case2: Add self-loops for a heterogeneous graph

    >>> g = dgl.heterograph({
    ...     ('user', 'plays', 'game'): ([0], [1]),
    ...     ('user', 'follows', 'user'): ([1], [2])
    ... })
    >>> new_g = transform(g)
    >>> print(new_g.edges(etype='plays'))
    (tensor([0]), tensor([1]))
    >>> print(new_g.edges(etype='follows'))
    (tensor([1, 0, 1, 2]), tensor([2, 0, 1, 2]))

    Case3: Add self-etypes for a heterogeneous graph

    >>> transform = AddSelfLoop(new_etypes=True)
    >>> new_g = transform(g)
    >>> print(new_g.edges(etype='follows'))
    (tensor([1, 0, 1, 2]), tensor([2, 0, 1, 2]))
    >>> print(new_g.edges(etype=('game', 'self', 'game')))
    (tensor([0, 1]), tensor([0, 1]))
    """

    def __init__(self, allow_duplicate=False, new_etypes=False):
        self.allow_duplicate = allow_duplicate
        self.new_etypes = new_etypes

    def transform_etype(self, c_etype, g):
        r"""

        Description
        -----------
        Transform the graph corresponding to a canonical edge type.

        Edge types whose source and destination node types differ are left
        untouched.

        Parameters
        ----------
        c_etype : tuple of str
            A canonical edge type.
        g : DGLGraph
            The graph.

        Returns
        -------
        DGLGraph
            The transformed graph.
        """
        src_type, _, dst_type = c_etype
        if src_type == dst_type:
            if not self.allow_duplicate:
                # Drop existing self-loops first so each node ends up with one.
                g = functional.remove_self_loop(g, etype=c_etype)
            g = functional.add_self_loop(g, etype=c_etype)
        return g

    def __call__(self, g):
        for c_etype in g.canonical_etypes:
            g = self.transform_etype(c_etype, g)

        if not self.new_etypes:
            return g

        data_dict = {}
        # One ('ntype', 'self', 'ntype') relation per node type, linking every
        # node to itself.
        for ntype in g.ntypes:
            nids = F.arange(0, g.num_nodes(ntype), g.idtype, g.device)
            data_dict[(ntype, 'self', ntype)] = (nids, nids)
        # Existing relations are carried over after the 'self' ones are inserted.
        data_dict.update(
            (c_etype, g.edges(etype=c_etype)) for c_etype in g.canonical_etypes)
        return update_graph_structure(g, data_dict)
class RemoveSelfLoop(BaseTransform):
    r"""Remove self-loops for each node in the graph and return a new graph.

    For heterogeneous graphs, this operation only applies to edge types with same
    source and destination node types.

    Example
    -------

    >>> import dgl
    >>> from dgl import RemoveSelfLoop

    Case1: Remove self-loops for a homogeneous graph

    >>> transform = RemoveSelfLoop()
    >>> g = dgl.graph(([1, 1], [1, 2]))
    >>> new_g = transform(g)
    >>> print(new_g.edges())
    (tensor([1]), tensor([2]))

    Case2: Remove self-loops for a heterogeneous graph

    >>> g = dgl.heterograph({
    ...     ('user', 'plays', 'game'): ([0, 1], [1, 1]),
    ...     ('user', 'follows', 'user'): ([1, 2], [2, 2])
    ... })
    >>> new_g = transform(g)
    >>> print(new_g.edges(etype='plays'))
    (tensor([0, 1]), tensor([1, 1]))
    >>> print(new_g.edges(etype='follows'))
    (tensor([1]), tensor([2]))
    """

    def transform_etype(self, c_etype, g):
        r"""Transform the graph corresponding to a canonical edge type.

        Edge types whose source and destination node types differ are returned
        unchanged.

        Parameters
        ----------
        c_etype : tuple of str
            A canonical edge type.
        g : DGLGraph
            The graph.

        Returns
        -------
        DGLGraph
            The transformed graph.
        """
        src_type, _, dst_type = c_etype
        if src_type != dst_type:
            return g
        return functional.remove_self_loop(g, etype=c_etype)

    def __call__(self, g):
        result = g
        for c_etype in g.canonical_etypes:
            result = self.transform_etype(c_etype, result)
        return result
class AddReverse(BaseTransform):
    r"""Add a reverse edge :math:`(i,j)` for each edge :math:`(j,i)` in the input graph and
    return a new graph.

    For a heterogeneous graph, it adds a "reverse" edge type for each edge type
    to hold the reverse edges. For example, for a canonical edge type ('A', 'r', 'B'),
    it adds a canonical edge type ('B', 'rev_r', 'A').

    Parameters
    ----------
    copy_edata : bool, optional
        If True, the features of the reverse edges will be identical to the original ones.
        Otherwise, reverse edges appended to an existing edge type get zero features
        and newly created reverse edge types get no features.
    sym_new_etype : bool, optional
        If False, it will not add a reverse edge type if the source and destination node type
        in a canonical edge type are identical. Instead, it will directly add edges to the
        original edge type.

    Example
    -------

    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch
    >>> from dgl import AddReverse

    Case1: Add reverse edges for a homogeneous graph

    >>> transform = AddReverse()
    >>> g = dgl.graph(([0], [1]))
    >>> g.edata['w'] = torch.ones(1, 2)
    >>> new_g = transform(g)
    >>> print(new_g.edges())
    (tensor([0, 1]), tensor([1, 0]))
    >>> print(new_g.edata['w'])
    tensor([[1., 1.],
            [0., 0.]])

    Case2: Add reverse edges for a homogeneous graph and copy edata

    >>> transform = AddReverse(copy_edata=True)
    >>> new_g = transform(g)
    >>> print(new_g.edata['w'])
    tensor([[1., 1.],
            [1., 1.]])

    Case3: Add reverse edges for a heterogeneous graph

    >>> g = dgl.heterograph({
    ...     ('user', 'plays', 'game'): ([0, 1], [1, 1]),
    ...     ('user', 'follows', 'user'): ([1, 2], [2, 2])
    ... })
    >>> new_g = transform(g)
    >>> print(new_g.canonical_etypes)
    [('game', 'rev_plays', 'user'), ('user', 'follows', 'user'), ('user', 'plays', 'game')]
    >>> print(new_g.edges(etype='rev_plays'))
    (tensor([1, 1]), tensor([0, 1]))
    >>> print(new_g.edges(etype='follows'))
    (tensor([1, 2, 2, 2]), tensor([2, 2, 1, 2]))
    """

    def __init__(self, copy_edata=False, sym_new_etype=False):
        self.copy_edata = copy_edata
        self.sym_new_etype = sym_new_etype

    def transform_symmetric_etype(self, c_etype, g, data_dict):
        r"""Transform the graph corresponding to a symmetric canonical edge type.

        Depending on ``sym_new_etype``, the reverse edges either go to a separate
        'rev_*' edge type or are appended after the original edges of ``c_etype``.

        Parameters
        ----------
        c_etype : tuple of str
            A canonical edge type.
        g : DGLGraph
            The graph.
        data_dict : dict
            The edge data to update.
        """
        if self.sym_new_etype:
            self.transform_asymmetric_etype(c_etype, g, data_dict)
            return
        src, dst = g.edges(etype=c_etype)
        # Original edges first, reversed copies second.
        data_dict[c_etype] = (F.cat([src, dst], dim=0), F.cat([dst, src], dim=0))

    def transform_asymmetric_etype(self, c_etype, g, data_dict):
        r"""Transform the graph corresponding to an asymmetric canonical edge type.

        The original edges are kept under ``c_etype`` and their reverses are stored
        under a new canonical edge type named ``'rev_' + etype``.

        Parameters
        ----------
        c_etype : tuple of str
            A canonical edge type.
        g : DGLGraph
            The graph.
        data_dict : dict
            The edge data to update.
        """
        utype, etype, vtype = c_etype
        src, dst = g.edges(etype=c_etype)
        data_dict[c_etype] = (src, dst)
        data_dict[(vtype, 'rev_{}'.format(etype), utype)] = (dst, src)

    def transform_etype(self, c_etype, g, data_dict):
        r"""Transform the graph corresponding to a canonical edge type.

        Parameters
        ----------
        c_etype : tuple of str
            A canonical edge type.
        g : DGLGraph
            The graph.
        data_dict : dict
            The edge data to update.
        """
        utype, _, vtype = c_etype
        handler = (self.transform_symmetric_etype if utype == vtype
                   else self.transform_asymmetric_etype)
        handler(c_etype, g, data_dict)

    def __call__(self, g):
        data_dict = {}
        for c_etype in g.canonical_etypes:
            self.transform_etype(c_etype, g, data_dict)
        # Edge features are handled below because edge counts/types change.
        new_g = update_graph_structure(g, data_dict, copy_edata=False)

        for c_etype in g.canonical_etypes:
            utype, etype, vtype = c_etype
            old_edata = g.edges[c_etype].data
            appended_in_place = utype == vtype and not self.sym_new_etype

            if appended_in_place:
                # Reverse edges were concatenated after the originals, so the
                # features are extended by either a copy or zeros of equal shape.
                for key, feat in old_edata.items():
                    if self.copy_edata:
                        rev_feat = feat
                    else:
                        rev_feat = F.zeros(F.shape(feat), F.dtype(feat), F.context(feat))
                    new_g.edges[c_etype].data[key] = F.cat([feat, rev_feat], dim=0)
                continue

            # Reverse edges live in their own edge type.
            rev_c_etype = (vtype, 'rev_{}'.format(etype), utype)
            for key, feat in old_edata.items():
                new_g.edges[c_etype].data[key] = feat
                if self.copy_edata:
                    new_g.edges[rev_c_etype].data[key] = feat

        return new_g
class ToSimple(BaseTransform):
    r"""Convert a graph to a simple graph without parallel edges and return a new graph.

    Parameters
    ----------
    return_counts : str, optional
        The edge feature name to hold the edge count in the original graph.
    aggregator : str, optional
        The way to coalesce features of duplicate edges.

        * ``'arbitrary'``: select arbitrarily from one of the duplicate edges
        * ``'sum'``: take the sum over the duplicate edges
        * ``'mean'``: take the mean over the duplicate edges

    Example
    -------

    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch
    >>> from dgl import ToSimple

    Case1: Convert a homogeneous graph to a simple graph

    >>> transform = ToSimple()
    >>> g = dgl.graph(([0, 1, 1], [1, 2, 2]))
    >>> g.edata['w'] = torch.tensor([[0.1], [0.2], [0.3]])
    >>> sg = transform(g)
    >>> print(sg.edges())
    (tensor([0, 1]), tensor([1, 2]))
    >>> print(sg.edata['count'])
    tensor([1, 2])
    >>> print(sg.edata['w'])
    tensor([[0.1000], [0.2000]])

    Case2: Convert a heterogeneous graph to a simple graph

    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]),
    ...     ('user', 'plays', 'game'): ([0, 1, 0], [1, 1, 1])
    ... })
    >>> sg = transform(g)
    >>> print(sg.edges(etype='follows'))
    (tensor([0, 1]), tensor([1, 2]))
    >>> print(sg.edges(etype='plays'))
    (tensor([0, 1]), tensor([1, 1]))
    """

    def __init__(self, return_counts='count', aggregator='arbitrary'):
        self.return_counts = return_counts
        self.aggregator = aggregator

    def __call__(self, g):
        # Edge features are always requested (``copy_edata=True``).
        options = dict(
            return_counts=self.return_counts,
            copy_edata=True,
            aggregator=self.aggregator,
        )
        return functional.to_simple(g, **options)
class LineGraph(BaseTransform):
    r"""Return the line graph of the input graph.

    The line graph :math:`L(G)` of a given graph :math:`G` is a graph where
    the nodes in :math:`L(G)` correspond to the edges in :math:`G`. For a pair
    of edges :math:`(u, v)` and :math:`(v, w)` in :math:`G`, there will be an
    edge from the node corresponding to :math:`(u, v)` to the node corresponding to
    :math:`(v, w)` in :math:`L(G)`.

    This module only works for homogeneous graphs.

    Parameters
    ----------
    backtracking : bool, optional
        If False, there will be no edge from the line graph node corresponding to
        :math:`(u, v)` to the line graph node corresponding to :math:`(v, u)`
        (see Case2 below).

    Example
    -------

    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch
    >>> from dgl import LineGraph

    Case1: Backtracking is True

    >>> transform = LineGraph()
    >>> g = dgl.graph(([0, 1, 1], [1, 0, 2]))
    >>> g.ndata['h'] = torch.tensor([[0.], [1.], [2.]])
    >>> g.edata['w'] = torch.tensor([[0.], [0.1], [0.2]])
    >>> new_g = transform(g)
    >>> print(new_g)
    Graph(num_nodes=3, num_edges=3,
          ndata_schemes={'w': Scheme(shape=(1,), dtype=torch.float32)}
          edata_schemes={})
    >>> print(new_g.edges())
    (tensor([0, 0, 1]), tensor([1, 2, 0]))

    Case2: Backtracking is False

    >>> transform = LineGraph(backtracking=False)
    >>> new_g = transform(g)
    >>> print(new_g.edges())
    (tensor([0]), tensor([2]))
    """

    def __init__(self, backtracking=True):
        self.backtracking = backtracking

    def __call__(self, g):
        # ``shared=True`` is always passed; in the example above the edge features
        # of the input show up as node features of the line graph.
        options = dict(backtracking=self.backtracking, shared=True)
        return functional.line_graph(g, **options)
class KHopGraph(BaseTransform):
    r"""Return the graph whose edges connect the :math:`k`-hop neighbors of the original graph.

    This module only works for homogeneous graphs.

    Parameters
    ----------
    k : int
        The number of hops.

    Example
    -------

    >>> import dgl
    >>> from dgl import KHopGraph

    >>> transform = KHopGraph(2)
    >>> g = dgl.graph(([0, 1], [1, 2]))
    >>> new_g = transform(g)
    >>> print(new_g.edges())
    (tensor([0]), tensor([2]))
    """

    def __init__(self, k):
        self.k = k

    def __call__(self, g):
        num_hops = self.k
        return functional.khop_graph(g, num_hops)
class Compose(BaseTransform):
    r"""Create a transform composed of multiple transforms in sequence.

    Parameters
    ----------
    transforms : list of Callable
        A list of transform objects to apply in order. A transform object should inherit
        :class:`~dgl.BaseTransform` and implement :func:`~dgl.BaseTransform.__call__`.

    Example
    -------

    >>> import dgl
    >>> from dgl import transforms as T

    >>> g = dgl.graph(([0, 0], [1, 1]))
    >>> transform = T.Compose([T.ToSimple(), T.AddReverse()])
    >>> new_g = transform(g)
    >>> print(new_g.edges())
    (tensor([0, 1]), tensor([1, 0]))
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, g):
        # Feed the output of each transform into the next one.
        current = g
        for step in self.transforms:
            current = step(current)
        return current

    def __repr__(self):
        # One transform per line, wrapped as ``ClassName([ ... ])``.
        entries = []
        for transform in self.transforms:
            entries.append(' ' + str(transform))
        listing = ',\n'.join(entries)
        return self.__class__.__name__ + '([\n' + listing + '\n])'
class GCNNorm(BaseTransform):
    r"""Apply symmetric adjacency normalization to an input graph and save the result edge
    weights, as described in `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__.

    For a heterogeneous graph, this only applies to symmetric canonical edge types, whose source
    and destination node types are identical.

    Parameters
    ----------
    eweight_name : str, optional
        :attr:`edata` name to retrieve and store edge weights. The edge weights are optional.

    Example
    -------

    >>> import dgl
    >>> import torch
    >>> from dgl import GCNNorm

    >>> transform = GCNNorm()
    >>> g = dgl.graph(([0, 1, 2], [0, 0, 1]))

    Case1: Transform an unweighted graph

    >>> g = transform(g)
    >>> print(g.edata['w'])
    tensor([0.5000, 0.7071, 0.0000])

    Case2: Transform a weighted graph

    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3])
    >>> g = transform(g)
    >>> print(g.edata['w'])
    tensor([0.3333, 0.6667, 0.0000])
    """

    def __init__(self, eweight_name='w'):
        self.eweight_name = eweight_name

    def calc_etype(self, c_etype, g):
        r"""

        Description
        -----------
        Get edge weights for an edge type.

        Each edge receives the product of the inverse square-root degrees of its
        two endpoints, multiplied by its existing weight when the graph stores one
        under ``eweight_name``. Infinite inverse degrees are replaced by zero. The
        computation runs in a local scope of ``g``.
        """
        ntype = c_etype[0]
        weight_key = self.eweight_name
        with g.local_scope():
            if weight_key in g.edges[c_etype].data:
                # Weighted case: the degree is the sum of incoming edge weights.
                g.update_all(fn.copy_e(weight_key, 'm'), fn.sum('m', 'deg'), etype=c_etype)
                degree = g.nodes[ntype].data['deg']

                def edge_weight(edges):
                    return {'w': edges.src['w'] * edges.data[weight_key] * edges.dst['w']}
            else:
                # Unweighted case: the degree is the in-degree, cast to float.
                degree = F.astype(g.in_degrees(etype=c_etype), F.float32)

                def edge_weight(edges):
                    return {'w': edges.src['w'] * edges.dst['w']}

            deg_inv_sqrt = 1. / F.sqrt(degree)
            g.nodes[ntype].data['w'] = F.replace_inf_with_zero(deg_inv_sqrt)
            g.apply_edges(edge_weight, etype=c_etype)
            return g.edges[c_etype].data['w']

    def __call__(self, g):
        # Compute the weights of every symmetric edge type first, then store them.
        normalized = {
            c_etype: self.calc_etype(c_etype, g)
            for c_etype in g.canonical_etypes
            if c_etype[0] == c_etype[2]
        }
        for c_etype, eweight in normalized.items():
            g.edges[c_etype].data[self.eweight_name] = eweight
        return g
class PPR(BaseTransform):
    r"""Apply personalized PageRank (PPR) to an input graph for diffusion, as introduced in
    `The pagerank citation ranking: Bringing order to the web
    <http://ilpubs.stanford.edu:8090/422/>`__.

    A sparsification will be applied to the weighted adjacency matrix after diffusion.
    Specifically, edges whose weight is below a threshold will be dropped.

    This module only works for homogeneous graphs.

    Parameters
    ----------
    alpha : float, optional
        Restart probability, which commonly lies in :math:`[0.05, 0.2]`.
    eweight_name : str, optional
        :attr:`edata` name to retrieve and store edge weights. If it does
        not exist in an input graph, this module initializes a weight of 1
        for all edges. The edge weights should be a tensor of shape :math:`(E)`,
        where E is the number of edges.
    eps : float, optional
        The threshold to preserve edges in sparsification after diffusion. Edges of a
        weight smaller than eps will be dropped.
    avg_degree : int, optional
        The desired average node degree of the result graph. This is the other way to
        control the sparsity of the result graph and will only be effective if
        :attr:`eps` is not given.

    Example
    -------

    >>> import dgl
    >>> import torch
    >>> from dgl import PPR

    >>> transform = PPR(avg_degree=2)
    >>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))
    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
    >>> new_g = transform(g)
    >>> print(new_g.edata['w'])
    tensor([0.1500, 0.1500, 0.1500, 0.0255, 0.0163, 0.1500, 0.0638, 0.0383, 0.1500,
            0.0510, 0.0217, 0.1500])
    """

    def __init__(self, alpha=0.15, eweight_name='w', eps=None, avg_degree=5):
        self.alpha = alpha
        self.eweight_name = eweight_name
        self.eps = eps
        self.avg_degree = avg_degree

    def get_eps(self, num_nodes, mat):
        r"""Get the threshold for graph sparsification.

        Returns ``eps`` when it was given. Otherwise the threshold is derived from
        ``avg_degree``: negative infinity if ``avg_degree`` exceeds ``num_nodes``,
        else the ``avg_degree * num_nodes``-th largest entry of ``mat``.
        """
        if self.eps is not None:
            return self.eps
        if self.avg_degree > num_nodes:
            return float('-inf')
        ranked = torch.sort(mat.flatten(), descending=True).values
        return ranked[self.avg_degree * num_nodes - 1]

    def __call__(self, g):
        device = g.device
        num_nodes = g.num_nodes()

        # Step1: PPR diffusion
        # (α - 1) A, with unit weights when the graph stores none
        unit_weight = F.ones((g.num_edges(),), F.float32, device)
        eweight = (self.alpha - 1) * g.edata.get(self.eweight_name, unit_weight)
        src, dst = g.edges()
        src = F.astype(src, F.int64)
        dst = F.astype(dst, F.int64)
        mat = F.zeros((num_nodes, num_nodes), F.float32, device)
        # Rows index destinations, columns index sources.
        mat[dst, src] = eweight
        # I_n + (α - 1) A
        nids = F.astype(g.nodes(), F.int64)
        mat[nids, nids] = mat[nids, nids] + 1
        # α (I_n + (α - 1) A)^-1
        diff_mat = self.alpha * F.inverse(mat)

        # Step2: sparsification
        eps = self.get_eps(num_nodes, diff_mat)
        dst, src = (diff_mat >= eps).nonzero(as_tuple=False).t()
        new_g = update_graph_structure(
            g, {g.canonical_etypes[0]: (src, dst)}, copy_edata=False)
        new_g.edata[self.eweight_name] = diff_mat[dst, src]
        return new_g
def is_bidirected(g):
    """Return whether the graph is a bidirected graph.

    The graph is treated as bidirected when the edge list sorted by
    (source, destination) matches the reversed edge list sorted by
    (destination, source), i.e. every edge :math:`(u, v)` is paired with an edge
    :math:`(v, u)`. Only edge endpoints are compared; edge weights are not read.
    """
    src, dst = g.edges()
    num_nodes = g.num_nodes()

    # Ordering with src as the major key and dst as the minor key.
    by_src = F.argsort(src * num_nodes + dst, dim=0, descending=False)
    # Ordering with dst as the major key and src as the minor key.
    by_dst = F.argsort(dst * num_nodes + src, dim=0, descending=False)

    fwd_src, fwd_dst = src[by_src], dst[by_src]
    bwd_src, bwd_dst = src[by_dst], dst[by_dst]

    return F.allclose(fwd_src, bwd_dst) and F.allclose(bwd_src, fwd_dst)
# pylint: disable=C0103
[docs]class HeatKernel(BaseTransform):
r"""Apply heat kernel to an input graph for diffusion, as introduced in
`Diffusion kernels on graphs and other discrete structures
<https://www.ml.cmu.edu/research/dap-papers/kondor-diffusion-kernels.pdf>`__.
A sparsification will be applied to the weighted adjacency matrix after diffusion.
Specifically, edges whose weight is below a threshold will be dropped.
This module only works for homogeneous graphs.
Parameters
----------
t : float, optional
Diffusion time, which commonly lies in :math:`[2, 10]`.
eweight_name : str, optional
:attr:`edata` name to retrieve and store edge weights. If it does
not exist in an input graph, this module initializes a weight of 1
for all edges. The edge weights should be a tensor of shape :math:`(E)`,
where E is the number of edges.
eps : float, optional
The threshold to preserve edges in sparsification after diffusion. Edges of a
weight smaller than eps will be dropped.
avg_degree : int, optional
The desired average node degree of the result graph. This is the other way to
control the sparsity of the result graph and will only be effective if
:attr:`eps` is not given.
Example
-------
>>> import dgl
>>> import torch
>>> from dgl import HeatKernel
>>> transform = HeatKernel(avg_degree=2)
>>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))
>>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
>>> new_g = transform(g)
>>> print(new_g.edata['w'])
tensor([0.1353, 0.1353, 0.1353, 0.0541, 0.0406, 0.1353, 0.1353, 0.0812, 0.1353,
0.1083, 0.0541, 0.1353])
"""
def __init__(self, t=2., eweight_name='w', eps=None, avg_degree=5):
self.t = t
self.eweight_name = eweight_name
self.eps = eps
self.avg_degree = avg_degree
def get_eps(self, num_nodes, mat):
r"""Get the threshold for graph sparsification.
"""
if self.eps is None:
# Infer from self.avg_degree
if self.avg_degree > num_nodes:
return float('-inf')
sorted_weights = torch.sort(mat.flatten(), descending=True).values
return sorted_weights[self.avg_degree * num_nodes - 1]
else:
return self.eps
def __call__(self, g):
    dev = g.device
    n = g.num_nodes()

    # Step 1: heat kernel diffusion.
    # Edge weights fall back to all ones when the graph does not carry them.
    weights = g.edata.get(self.eweight_name, F.ones(
        (g.num_edges(),), F.float32, dev))
    u, v = g.edges()
    u = F.astype(u, F.int64)
    v = F.astype(v, F.int64)
    # Dense t * A, with the weight of edge u -> v stored at [v, u].
    mat = F.zeros((n, n), F.float32, dev)
    mat[v, u] = self.t * weights
    # Shift the diagonal by -t to obtain t * (A - I_n).
    diag = F.astype(g.nodes(), F.int64)
    mat[diag, diag] = mat[diag, diag] - self.t

    if is_bidirected(g):
        # Eigendecomposition path; eigh only reads the upper triangle here.
        eigvals, eigvecs = torch.linalg.eigh(mat, UPLO='U')
        diff_mat = eigvecs @ torch.diag(eigvals.exp()) @ eigvecs.t()
    else:
        # General path: SciPy's matrix exponential on the CPU, then move the
        # result back to the graph's device.
        diff_mat = torch.Tensor(expm(mat.cpu().numpy())).to(dev)

    # Step 2: sparsification. Keep entries at or above the threshold and
    # rebuild the graph from them; row index is destination, column is source.
    threshold = self.get_eps(n, diff_mat)
    kept_dst, kept_src = (diff_mat >= threshold).nonzero(as_tuple=False).t()
    new_g = update_graph_structure(
        g, {g.canonical_etypes[0]: (kept_src, kept_dst)}, copy_edata=False)
    new_g.edata[self.eweight_name] = diff_mat[kept_dst, kept_src]
    return new_g
class GDC(BaseTransform):
    r"""Apply graph diffusion convolution (GDC) to an input graph, as introduced in
    `Diffusion Improves Graph Learning <https://www.in.tum.de/daml/gdc/>`__.

    A sparsification will be applied to the weighted adjacency matrix after diffusion.
    Specifically, edges whose weight is below a threshold will be dropped.

    This module only works for homogeneous graphs.

    Parameters
    ----------
    coefs : list[float]
        List of coefficients. :math:`\theta_k` for each power of the adjacency matrix.
        It must contain at least one element (:math:`\theta_0`).
    eweight_name : str, optional
        :attr:`edata` name to retrieve and store edge weights. If it does
        not exist in an input graph, this module initializes a weight of 1
        for all edges. The edge weights should be a tensor of shape :math:`(E)`,
        where E is the number of edges.
    eps : float, optional
        The threshold to preserve edges in sparsification after diffusion. Edges with a
        weight smaller than eps will be dropped.
    avg_degree : int, optional
        The desired average node degree of the result graph. This is the other way to
        control the sparsity of the result graph and will only be effective if
        :attr:`eps` is not given.

    Example
    -------

    >>> import dgl
    >>> import torch
    >>> from dgl import GDC

    >>> transform = GDC([0.3, 0.2, 0.1], avg_degree=2)
    >>> g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]))
    >>> g.edata['w'] = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
    >>> new_g = transform(g)
    >>> print(new_g.edata['w'])
    tensor([0.3000, 0.3000, 0.0200, 0.3000, 0.0400, 0.3000, 0.1000, 0.0600, 0.3000,
            0.0800, 0.0200, 0.3000])
    """

    def __init__(self, coefs, eweight_name='w', eps=None, avg_degree=5):
        self.coefs = coefs
        self.eweight_name = eweight_name
        self.eps = eps
        self.avg_degree = avg_degree

    def get_eps(self, num_nodes, mat):
        r"""Get the threshold for graph sparsification.

        Returns :attr:`eps` if it was given. Otherwise returns the weight of
        the ``avg_degree * num_nodes``-th largest entry of :attr:`mat`,
        ``-inf`` (keep everything) when :attr:`avg_degree` exceeds
        :attr:`num_nodes`, or ``+inf`` (keep nothing) when no edge is
        requested at all.
        """
        if self.eps is not None:
            return self.eps

        # Infer from self.avg_degree
        if self.avg_degree > num_nodes:
            return float('-inf')
        num_kept = self.avg_degree * num_nodes
        if num_kept <= 0:
            # Without this guard the index below would be negative and wrap
            # around to the smallest weight (keeping every edge), or fail on
            # an empty matrix.
            return float('inf')
        sorted_weights = torch.sort(mat.flatten(), descending=True).values
        return sorted_weights[num_kept - 1]

    def __call__(self, g):
        device = g.device
        num_nodes = g.num_nodes()

        # Step1: diffusion
        # A, stored densely with the weight of edge src -> dst at [dst, src].
        # Edge weights fall back to all ones when the graph has none.
        eweight = g.edata.get(self.eweight_name, F.ones(
            (g.num_edges(),), F.float32, device))
        adj = F.zeros((num_nodes, num_nodes), F.float32, device)
        src, dst = g.edges()
        src, dst = F.astype(src, F.int64), F.astype(dst, F.int64)
        adj[dst, src] = eweight

        # theta_0 I_n
        mat = torch.eye(num_nodes, device=device)
        diff_mat = self.coefs[0] * mat
        # add theta_k A^k; ``mat`` holds the running power of ``adj``.
        for coef in self.coefs[1:]:
            mat = mat @ adj
            diff_mat += coef * mat

        # Step2: sparsification
        eps = self.get_eps(num_nodes, diff_mat)
        dst, src = (diff_mat >= eps).nonzero(as_tuple=False).t()
        data_dict = {g.canonical_etypes[0]: (src, dst)}
        new_g = update_graph_structure(g, data_dict, copy_edata=False)
        new_g.edata[self.eweight_name] = diff_mat[dst, src]

        return new_g
class NodeShuffle(BaseTransform):
    r"""Randomly shuffle the nodes.

    For every node type, one random permutation of the node IDs is drawn and
    applied to all node features of that type. The transform operates on a
    clone of the input graph, so the features stored on the input are not
    reassigned.

    Example
    -------

    >>> import dgl
    >>> import torch
    >>> from dgl import NodeShuffle

    >>> transform = NodeShuffle()
    >>> g = dgl.graph(([0, 1], [1, 2]))
    >>> g.ndata['h1'] = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    >>> g.ndata['h2'] = torch.tensor([[7., 8.], [9., 10.], [11., 12.]])
    >>> g = transform(g)
    >>> print(g.ndata['h1'])
    tensor([[5., 6.],
            [3., 4.],
            [1., 2.]])
    >>> print(g.ndata['h2'])
    tensor([[11., 12.],
            [ 9., 10.],
            [ 7.,  8.]])
    """

    def __call__(self, g):
        # Work on a copy so the caller's graph keeps its original features.
        g = g.clone()
        for ntype in g.ntypes:
            nids = F.astype(g.nodes(ntype), F.int64)
            # The same permutation is reused for every feature of this type,
            # which keeps the rows of different features aligned.
            perm = F.rand_shuffle(nids)
            for key, feat in g.nodes[ntype].data.items():
                g.nodes[ntype].data[key] = feat[perm]
        return g
# pylint: disable=C0103
class DropNode(BaseTransform):
    r"""Randomly drop nodes, as described in
    `Graph Contrastive Learning with Augmentations <https://arxiv.org/abs/2010.13902>`__.

    The transform operates on a clone of the input graph, so no node is
    removed from the input itself.

    Parameters
    ----------
    p : float, optional
        Probability of a node to be dropped.

    Example
    -------

    >>> import dgl
    >>> import torch
    >>> from dgl import DropNode

    >>> transform = DropNode()
    >>> g = dgl.rand_graph(5, 20)
    >>> g.ndata['h'] = torch.arange(g.num_nodes())
    >>> g.edata['h'] = torch.arange(g.num_edges())
    >>> new_g = transform(g)
    >>> print(new_g)
    Graph(num_nodes=3, num_edges=7,
          ndata_schemes={'h': Scheme(shape=(), dtype=torch.int64)}
          edata_schemes={'h': Scheme(shape=(), dtype=torch.int64)})
    >>> print(new_g.ndata['h'])
    tensor([0, 1, 2])
    >>> print(new_g.edata['h'])
    tensor([0, 6, 14, 5, 17, 3, 11])
    """

    def __init__(self, p=0.5):
        self.p = p
        # One Bernoulli(p) draw per node decides whether it is dropped.
        self.dist = Bernoulli(p)

    def __call__(self, g):
        # Work on a copy so the caller's graph keeps all of its nodes.
        g = g.clone()

        # Fast path
        if self.p == 0:
            return g

        for ntype in g.ntypes:
            samples = self.dist.sample(torch.Size([g.num_nodes(ntype)]))
            # A sample of 1 marks the node for removal.
            nids_to_remove = g.nodes(ntype)[samples.bool().to(g.device)]
            g.remove_nodes(nids_to_remove, ntype=ntype)
        return g
# pylint: disable=C0103
class DropEdge(BaseTransform):
    r"""Randomly drop edges, as described in
    `DropEdge: Towards Deep Graph Convolutional Networks on Node Classification
    <https://arxiv.org/abs/1907.10903>`__ and `Graph Contrastive Learning with Augmentations
    <https://arxiv.org/abs/2010.13902>`__.

    The transform operates on a clone of the input graph, so no edge is
    removed from the input itself.

    Parameters
    ----------
    p : float, optional
        Probability of an edge to be dropped.

    Example
    -------

    >>> import dgl
    >>> import torch
    >>> from dgl import DropEdge

    >>> transform = DropEdge()
    >>> g = dgl.rand_graph(5, 20)
    >>> g.edata['h'] = torch.arange(g.num_edges())
    >>> new_g = transform(g)
    >>> print(new_g)
    Graph(num_nodes=5, num_edges=12,
          ndata_schemes={}
          edata_schemes={'h': Scheme(shape=(), dtype=torch.int64)})
    >>> print(new_g.edata['h'])
    tensor([0, 1, 3, 7, 8, 10, 11, 12, 13, 15, 18, 19])
    """

    def __init__(self, p=0.5):
        self.p = p
        # One Bernoulli(p) draw per edge decides whether it is dropped.
        self.dist = Bernoulli(p)

    def __call__(self, g):
        # Work on a copy so the caller's graph keeps all of its edges.
        g = g.clone()

        # Fast path
        if self.p == 0:
            return g

        for c_etype in g.canonical_etypes:
            samples = self.dist.sample(torch.Size([g.num_edges(c_etype)]))
            # A sample of 1 marks the edge for removal.
            eids_to_remove = g.edges(form='eid', etype=c_etype)[samples.bool().to(g.device)]
            g.remove_edges(eids_to_remove, etype=c_etype)
        return g
class AddEdge(BaseTransform):
    r"""Randomly add edges, as described in `Graph Contrastive Learning with Augmentations
    <https://arxiv.org/abs/2010.13902>`__.

    The transform operates on a clone of the input graph, so no edge is added
    to the input itself.

    Parameters
    ----------
    ratio : float, optional
        Number of edges to add divided by the number of existing edges.

    Example
    -------

    >>> import dgl
    >>> from dgl import AddEdge

    >>> transform = AddEdge()
    >>> g = dgl.rand_graph(5, 20)
    >>> new_g = transform(g)
    >>> print(new_g.num_edges())
    24
    """

    def __init__(self, ratio=0.2):
        self.ratio = ratio

    def __call__(self, g):
        # Work on a copy so the caller's graph does not gain edges.
        g = g.clone()

        # Fast path
        if self.ratio == 0.:
            return g

        device = g.device
        idtype = g.idtype
        for c_etype in g.canonical_etypes:
            utype, _, vtype = c_etype
            # Truncated toward zero, so small edge sets may add nothing.
            num_edges_to_add = int(g.num_edges(c_etype) * self.ratio)
            if num_edges_to_add == 0:
                # Nothing to add for this edge type; this also avoids asking
                # the backend for random IDs from a possibly empty node range.
                continue
            # Endpoints are drawn independently and uniformly over node IDs.
            src = F.randint([num_edges_to_add], idtype, device, low=0, high=g.num_nodes(utype))
            dst = F.randint([num_edges_to_add], idtype, device, low=0, high=g.num_nodes(vtype))
            g.add_edges(src, dst, etype=c_etype)
        return g
class SIGNDiffusion(BaseTransform):
    r"""The diffusion operator from `SIGN: Scalable Inception Graph Neural Networks
    <https://arxiv.org/abs/2004.11198>`__

    It performs node feature diffusion with :math:`TX, \cdots, T^{k}X`, where :math:`T`
    is a diffusion matrix and :math:`X` is the input node features.

    Specifically, this module provides four options for :math:`T`.

    **raw**: raw adjacency matrix :math:`A`

    **rw**: random walk (row-normalized) adjacency matrix :math:`D^{-1}A`, where
    :math:`D` is the degree matrix.

    **gcn**: symmetrically normalized adjacency matrix used by
    `GCN <https://arxiv.org/abs/1609.02907>`__, :math:`D^{-1/2}AD^{-1/2}`

    **ppr**: approximate personalized PageRank used by
    `APPNP <https://arxiv.org/abs/1810.05997>`__

    .. math::
        H^{0} &= X

        H^{l+1} &= (1-\alpha)\left(D^{-1/2}AD^{-1/2} H^{l}\right) + \alpha X

    This module only works for homogeneous graphs.

    Parameters
    ----------
    k : int
        The maximum number of times for node feature diffusion.
    in_feat_name : str, optional
        :attr:`g.ndata[{in_feat_name}]` should store the input node features. Default: 'feat'
    out_feat_name : str, optional
        :attr:`g.ndata[{out_feat_name}_i]` will store the result of diffusing
        input node features for i times. Default: 'out_feat'
    eweight_name : str, optional
        Name to retrieve edge weights from :attr:`g.edata`. Default: None,
        treating the graph as unweighted.
    diffuse_op : str, optional
        The diffusion operator to use, which can be 'raw', 'rw', 'gcn', or 'ppr'.
        Default: 'raw'
    alpha : float, optional
        Restart probability if :attr:`diffuse_op` is :attr:`'ppr'`,
        which commonly lies in :math:`[0.05, 0.2]`. Default: 0.2

    Example
    -------

    >>> import dgl
    >>> import torch
    >>> from dgl import SIGNDiffusion

    >>> transform = SIGNDiffusion(k=2, eweight_name='w')
    >>> num_nodes = 5
    >>> num_edges = 20
    >>> g = dgl.rand_graph(num_nodes, num_edges)
    >>> g.ndata['feat'] = torch.randn(num_nodes, 10)
    >>> g.edata['w'] = torch.randn(num_edges)
    >>> transform(g)
    Graph(num_nodes=5, num_edges=20,
          ndata_schemes={'feat': Scheme(shape=(10,), dtype=torch.float32),
                         'out_feat_1': Scheme(shape=(10,), dtype=torch.float32),
                         'out_feat_2': Scheme(shape=(10,), dtype=torch.float32)}
          edata_schemes={'w': Scheme(shape=(), dtype=torch.float32)})
    """

    def __init__(self,
                 k,
                 in_feat_name='feat',
                 out_feat_name='out_feat',
                 eweight_name=None,
                 diffuse_op='raw',
                 alpha=0.2):
        self.k = k
        self.in_feat_name = in_feat_name
        self.out_feat_name = out_feat_name
        self.eweight_name = eweight_name
        self.diffuse_op = diffuse_op
        self.alpha = alpha

        # Bind the method implementing the requested operator once, up front.
        if diffuse_op == 'raw':
            self.diffuse = self.raw
        elif diffuse_op == 'rw':
            self.diffuse = self.rw
        elif diffuse_op == 'gcn':
            self.diffuse = self.gcn
        elif diffuse_op == 'ppr':
            self.diffuse = self.ppr
        else:
            # Adjacent literals instead of a backslash continuation inside the
            # string, so source indentation can never leak into the message.
            raise DGLError("Expect diffuse_op to be from ['raw', 'rw', 'gcn', 'ppr'], "
                           "got {}".format(diffuse_op))

    def __call__(self, g):
        """Diffuse the input features and store the k results in ``g.ndata``.

        The results are written to ``g.ndata[out_feat_name + '_' + str(i)]``
        for ``i`` in ``1..k`` on the input graph itself, which is also returned.
        """
        feat_list = self.diffuse(g)

        for i in range(1, self.k + 1):
            g.ndata[self.out_feat_name + '_' + str(i)] = feat_list[i - 1]
        return g

    def _has_eweight(self, g):
        """Return whether edge weights were requested and exist in ``g.edata``."""
        return (self.eweight_name is not None) and self.eweight_name in g.edata

    def _gcn_normalize(self, g):
        """Run GCNNorm on ``g`` and return the edata name holding the result.

        When no edge weight name was configured, the name 'w' is used and any
        existing 'w' edge data is popped first. Callers invoke this inside
        ``g.local_scope()``.
        """
        if self.eweight_name is None:
            eweight_name = 'w'
            if eweight_name in g.edata:
                g.edata.pop(eweight_name)
        else:
            eweight_name = self.eweight_name

        transform = GCNNorm(eweight_name=eweight_name)
        transform(g)
        return eweight_name

    def raw(self, g):
        """Return the list of features obtained from k rounds of (weighted) sum aggregation."""
        feat_list = []
        with g.local_scope():
            if self._has_eweight(g):
                message_func = fn.u_mul_e(self.in_feat_name, self.eweight_name, 'm')
            else:
                message_func = fn.copy_u(self.in_feat_name, 'm')
            for _ in range(self.k):
                # Each round overwrites the input feature in the local scope,
                # so the next round diffuses the previous result.
                g.update_all(message_func, fn.sum('m', self.in_feat_name))
                feat_list.append(g.ndata[self.in_feat_name])
        return feat_list

    def rw(self, g):
        """Return the list of features obtained from k rounds of row-normalized aggregation."""
        use_eweight = self._has_eweight(g)
        feat_list = []
        with g.local_scope():
            g.ndata['h'] = g.ndata[self.in_feat_name]
            if use_eweight:
                message_func = fn.u_mul_e('h', self.eweight_name, 'm')
                reduce_func = fn.sum('m', 'h')
                # Compute the diagonal entries of D from the weighted A
                g.update_all(fn.copy_e(self.eweight_name, 'm'), fn.sum('m', 'z'))
                # The normalizer does not change across rounds; reshape it once.
                # NOTE(review): a node whose incoming weights sum to 0 divides
                # by zero below -- confirm whether that is intended.
                norm = F.reshape(g.ndata['z'], (g.num_nodes(), 1))
            else:
                message_func = fn.copy_u('h', 'm')
                reduce_func = fn.mean('m', 'h')

            for _ in range(self.k):
                g.update_all(message_func, reduce_func)
                if use_eweight:
                    g.ndata['h'] = g.ndata['h'] / norm
                feat_list.append(g.ndata['h'])
        return feat_list

    def gcn(self, g):
        """Return the list of features obtained from k rounds of GCN-normalized aggregation."""
        feat_list = []
        with g.local_scope():
            eweight_name = self._gcn_normalize(g)

            for _ in range(self.k):
                g.update_all(fn.u_mul_e(self.in_feat_name, eweight_name, 'm'),
                             fn.sum('m', self.in_feat_name))
                feat_list.append(g.ndata[self.in_feat_name])
        return feat_list

    def ppr(self, g):
        """Return the list of features obtained from k rounds of the PPR update with restart."""
        feat_list = []
        with g.local_scope():
            eweight_name = self._gcn_normalize(g)

            # Keep the original features for the restart term alpha * X.
            in_feat = g.ndata[self.in_feat_name]
            for _ in range(self.k):
                g.update_all(fn.u_mul_e(self.in_feat_name, eweight_name, 'm'),
                             fn.sum('m', self.in_feat_name))
                g.ndata[self.in_feat_name] = (1 - self.alpha) * g.ndata[self.in_feat_name] +\
                    self.alpha * in_feat
                feat_list.append(g.ndata[self.in_feat_name])
        return feat_list