RowFeatNormalizer

class dgl.transforms.RowFeatNormalizer(subtract_min=False, node_feat_names=None, edge_feat_names=None)

Bases: dgl.transforms.module.BaseTransform

Row-normalizes the features given in node_feat_names and edge_feat_names.

The row normalization formula is:

$x = \frac{x}{\sum_i x_i}$

where $x$ denotes a row of the feature tensor and $x_i$ is its $i$-th element.
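As a minimal illustration of the formula (not DGL's implementation), the sketch below applies it to a feature matrix stored as a plain Python list of rows. The helper name `row_normalize` is hypothetical, and the sketch assumes that no row sums to zero.

```python
def row_normalize(rows):
    """Apply x = x / sum_i x_i to each row of a 2-D list.

    Hypothetical helper for illustration; assumes no row sums to zero.
    """
    normalized = []
    for row in rows:
        total = sum(row)  # sum_i x_i for this row
        normalized.append([x_i / total for x_i in row])
    return normalized


feat = [[1.0, 3.0, 4.0],
        [2.0, 2.0, 0.0]]
print(row_normalize(feat))  # [[0.125, 0.375, 0.5], [0.5, 0.5, 0.0]]
```

Each output row now sums to 1, and the relative proportions within a row are preserved.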

Parameters
• subtract_min (bool) – If True, the minimum value of the whole feature tensor is subtracted before normalization, which makes all values non-negative. If all values are non-negative, the sum of each row of the feature tensor will be 1 after normalization. Default: False.

• node_feat_names (list[str], optional) – The names of the node feature tensors to be row-normalized. Default: None, which will not normalize any node feature tensor.

• edge_feat_names (list[str], optional) – The names of the edge feature tensors to be row-normalized. Default: None, which will not normalize any edge feature tensor.
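The effect of `subtract_min` can be sketched the same way. In this hypothetical helper, a single minimum is taken over the whole tensor, as the parameter description states, and subtracted from every entry before each row is divided by its sum. The helper name and the plain-list representation are assumptions of the sketch, not DGL's code.

```python
def row_normalize(rows, subtract_min=False):
    """Row-normalize a 2-D list, optionally shifting it by its global minimum.

    Hypothetical helper mirroring the documented `subtract_min` behavior;
    assumes no row sums to zero after the optional shift.
    """
    if subtract_min:
        # One minimum over the whole tensor, not a separate minimum per row.
        global_min = min(min(row) for row in rows)
        rows = [[x - global_min for x in row] for row in rows]
    normalized = []
    for row in rows:
        total = sum(row)
        normalized.append([x / total for x in row])
    return normalized


feat = [[-1.0, 1.0, 2.0],
        [-2.0, 0.0, 0.0]]
print(row_normalize(feat, subtract_min=True))  # [[0.125, 0.375, 0.5], [0.0, 0.5, 0.5]]
print(row_normalize([feat[0]]))                # [[-0.5, 0.5, 1.0]]
```

Without the shift, the first row becomes `[-0.5, 0.5, 1.0]`: it still sums to 1 but keeps a negative entry, whereas the shifted output has every entry in [0, 1].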

Example

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch
>>> from dgl import RowFeatNormalizer


Case 1: Row-normalize the features of a homogeneous graph.

>>> transform = RowFeatNormalizer(subtract_min=True,
...                               node_feat_names=['h'], edge_feat_names=['w'])
>>> g = dgl.rand_graph(5, 20)
>>> g.ndata['h'] = torch.randn((g.num_nodes(), 5))
>>> g.edata['w'] = torch.randn((g.num_edges(), 5))
>>> g = transform(g)
>>> print(g.ndata['h'].sum(1))
tensor([1., 1., 1., 1., 1.])
>>> print(g.edata['w'].sum(1))
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])
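The steps of Case 1 can also be reproduced without DGL to check the result. This sketch is a stand-in under stated assumptions: plain dicts replace `g.ndata` and `g.edata`, `random.gauss` replaces `torch.randn`, and the hypothetical helper applies the documented `subtract_min` behavior to each named feature. Because floating-point row sums equal 1 only up to rounding, it compares them with `math.isclose` rather than relying on exact printed output.

```python
import math
import random


def row_normalize(rows, subtract_min=False):
    """Hypothetical helper sketching the documented behavior on a 2-D list."""
    if subtract_min:
        # One minimum over the whole tensor, subtracted from every entry.
        global_min = min(min(row) for row in rows)
        rows = [[x - global_min for x in row] for row in rows]
    normalized = []
    for row in rows:
        total = sum(row)  # assumed nonzero
        normalized.append([x / total for x in row])
    return normalized


random.seed(0)
num_nodes, num_edges, dim = 5, 20, 5
# Plain dicts standing in for g.ndata and g.edata (an assumption of this sketch).
ndata = {'h': [[random.gauss(0, 1) for _ in range(dim)] for _ in range(num_nodes)]}
edata = {'w': [[random.gauss(0, 1) for _ in range(dim)] for _ in range(num_edges)]}

# Mirror node_feat_names=['h'] and edge_feat_names=['w'] with subtract_min=True.
for name in ['h']:
    ndata[name] = row_normalize(ndata[name], subtract_min=True)
for name in ['w']:
    edata[name] = row_normalize(edata[name], subtract_min=True)

# Row sums equal 1 only up to floating-point rounding, so use a tolerance.
print(all(math.isclose(sum(row), 1.0) for row in ndata['h']))  # True
print(all(math.isclose(sum(row), 1.0) for row in edata['w']))  # True
```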


Case 2: Row-normalize the features of a heterogeneous graph.

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
... })
>>> g.ndata['h'] = {'game': torch.randn(2, 5), 'player': torch.randn(3, 5)}
>>> g.edata['w'] = {
...     ('user', 'follows', 'user'): torch.randn(2, 5),
...     ('player', 'plays', 'game'): torch.randn(2, 5)
... }
>>> g = transform(g)
>>> print(g.ndata['h']['game'].sum(1), g.ndata['h']['player'].sum(1))
tensor([1., 1.]) tensor([1., 1., 1.])
>>> print(g.edata['w'][('user', 'follows', 'user')].sum(1),
...     g.edata['w'][('player', 'plays', 'game')].sum(1))
tensor([1., 1.]) tensor([1., 1.])
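In Case 2 the features are stored per node type and per edge type. The sketch below shows one way such a per-type loop could work, assuming that each per-type tensor is shifted by its own minimum and normalized on its own, and that types without the feature (such as 'user', which has no 'h' above) are left untouched. Plain dicts stand in for the graph's per-type data, and the helper is hypothetical rather than DGL's code; edge types would be handled by the same loop.

```python
def row_normalize(rows, subtract_min=False):
    """Hypothetical helper sketching the documented per-tensor behavior."""
    if subtract_min:
        # Minimum of this tensor only, not of all types combined.
        global_min = min(min(row) for row in rows)
        rows = [[x - global_min for x in row] for row in rows]
    normalized = []
    for row in rows:
        total = sum(row)  # assumed nonzero
        normalized.append([x / total for x in row])
    return normalized


# Plain dict standing in for per-type node features of a heterograph:
# 'game' and 'player' carry a feature named 'h', 'user' does not.
node_feats = {
    'game': {'h': [[-1.0, 1.0, 2.0], [-2.0, 0.0, 0.0]]},
    'player': {'h': [[1.0, 3.0, 3.0], [2.0, 1.0, 2.0]]},
    'user': {},
}

for ntype, feats in node_feats.items():
    if 'h' in feats:  # types without the feature are skipped
        feats['h'] = row_normalize(feats['h'], subtract_min=True)

print(node_feats['game']['h'])    # [[0.125, 0.375, 0.5], [0.0, 0.5, 0.5]]
print(node_feats['player']['h'])  # [[0.0, 0.5, 0.5], [0.5, 0.0, 0.5]]
print(node_feats['user'])         # {}
```

Here 'game' is shifted by its own minimum (-2.0) and 'player' by its own minimum (1.0); under the per-type assumption, the two types never share a shift.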