FeatMask

class dgl.transforms.FeatMask(p=0.5, node_feat_names=None, edge_feat_names=None)

Bases: BaseTransform

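Randomly mask columns of the node and edge feature tensors, as described in Graph Contrastive Learning with Augmentations. In the examples below, a masked column is set to zero in every row of the tensor, so all nodes (or all edges) lose the same feature dimensions.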

Parameters:
  • p (float, optional) – Probability of masking each column of a feature tensor. Default: 0.5.

  • node_feat_names (list[str], optional) – The names of the node feature tensors to be masked. Default: None, in which case no node feature tensor is masked.

  • edge_feat_names (list[str], optional) – The names of the edge feature tensors to be masked. Default: None, in which case no edge feature tensor is masked.
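
To make the meaning of "masking a column" concrete, here is a minimal, dependency-free sketch of the idea. It is an illustration only, not DGL's implementation: the helper name mask_columns, its rng argument, and the use of nested Python lists in place of tensors are all assumptions made for this example. It assumes that one independent draw with probability p is made per column and that a masked column is zeroed in every row, which matches the outputs shown in the examples on this page.

```python
import random


def mask_columns(rows, p=0.5, rng=None):
    """Zero out each column of a 2-D list with probability p.

    Hypothetical helper for illustration; not part of DGL.
    Returns the masked copy and the per-column mask that was drawn.
    """
    if rng is None:
        rng = random.Random()
    if not rows:
        return [], []
    num_cols = len(rows[0])
    # One draw per column: the same decision applies to every row.
    masked = [rng.random() < p for _ in range(num_cols)]
    out = [
        [0.0 if masked[j] else value for j, value in enumerate(row)]
        for row in rows
    ]
    return out, masked


# Five "nodes" with ten features each, all ones.
feats = [[1.0] * 10 for _ in range(5)]
masked_feats, masked = mask_columns(feats, p=0.5, rng=random.Random(0))
for row in masked_feats:
    print(row)
```

Because the decision is made per column rather than per entry, every printed row is identical. With p = 0.5 and 10 columns, 5 columns are masked on average, but the actual count varies from call to call.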

Example

The following examples use the PyTorch backend. The masked columns are sampled randomly, so the outputs will differ between runs.

>>> import dgl
>>> import torch
>>> from dgl import FeatMask

Case 1: Mask node and edge feature tensors of a homogeneous graph.

>>> transform = FeatMask(node_feat_names=['h'], edge_feat_names=['w'])
>>> g = dgl.rand_graph(5, 10)
>>> g.ndata['h'] = torch.ones((g.num_nodes(), 10))
>>> g.edata['w'] = torch.ones((g.num_edges(), 10))
>>> g = transform(g)
>>> print(g.ndata['h'])
tensor([[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.]])
>>> print(g.edata['w'])
tensor([[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.]])

Case 2: Mask node and edge feature tensors of a heterogeneous graph, reusing the transform created in Case 1.

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
... })
>>> g.ndata['h'] = {'game': torch.ones(2, 5), 'player': torch.ones(3, 5)}
>>> g.edata['w'] = {('user', 'follows', 'user'): torch.ones(2, 5)}
>>> print(g.ndata['h']['game'])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
>>> print(g.edata['w'][('user', 'follows', 'user')])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
>>> g = transform(g)
>>> print(g.ndata['h']['game'])
tensor([[1., 1., 0., 1., 0.],
        [1., 1., 0., 1., 0.]])
>>> print(g.edata['w'][('user', 'follows', 'user')])
tensor([[0., 1., 0., 1., 0.],
        [0., 1., 0., 1., 0.]])