Bases: `BaseTransform`

Randomly mask columns of the node and edge feature tensors, as described in Graph Contrastive Learning with Augmentations.

Parameters:
• p (float, optional) β Probability of masking a column of a feature tensor. Default: 0.5.

• node_feat_names (list[str], optional) β The names of the node feature tensors to be masked. Default: None, which will not mask any node feature tensor.

• edge_feat_names (list[str], optional) β The names of the edge features to be masked. Default: None, which will not mask any edge feature tensor.

Example

The following example uses PyTorch backend.

```>>> import dgl
>>> import torch
```

Case1 : Mask node and edge feature tensors of a homogeneous graph.

```>>> transform = FeatMask(node_feat_names=['h'], edge_feat_names=['w'])
>>> g = dgl.rand_graph(5, 10)
>>> g.ndata['h'] = torch.ones((g.num_nodes(), 10))
>>> g.edata['w'] = torch.ones((g.num_edges(), 10))
```
```>>> g = transform(g)
>>> print(g.ndata['h'])
tensor([[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.]])
>>> print(g.edata['w'])
tensor([[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.]])
```

Case2 : Mask node and edge feature tensors of a heterogeneous graph.

```>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
... })
>>> g.ndata['h'] = {'game': torch.ones(2, 5), 'player': torch.ones(3, 5)}
>>> g.edata['w'] = {('user', 'follows', 'user'): torch.ones(2, 5)}
>>> print(g.ndata['h']['game'])
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
>>> print(g.edata['w'][('user', 'follows', 'user')])
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
>>> g = transform(g)
>>> print(g.ndata['h']['game'])
tensor([[1., 1., 0., 1., 0.],
[1., 1., 0., 1., 0.]])
>>> print(g.edata['w'][('user', 'follows', 'user')])
tensor([[0., 1., 0., 1., 0.],
[0., 1., 0., 1., 0.]])
```