DeepWalk

class dgl.nn.pytorch.DeepWalk(g, emb_dim=128, walk_length=40, window_size=5, neg_weight=1, negative_size=5, fast_neg=True, sparse=True)[source]

Bases: torch.nn.modules.module.Module

DeepWalk module from DeepWalk: Online Learning of Social Representations

For a graph, it learns node representations from scratch by maximizing the similarity of node pairs that appear near each other in random walks (positive node pairs) and minimizing the similarity of randomly sampled node pairs (negative node pairs).
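Concretely, the parameters below describe a skip-gram style objective with negative sampling over walk co-occurrences. A minimal sketch of the loss for one positive pair and its negative samples follows; this is illustrative only, not the module's exact implementation, and all tensors here are stand-ins:

>>> import torch
>>> import torch.nn.functional as F
>>> emb_dim = 128
>>> u = torch.randn(emb_dim)         # embedding of the center node
>>> v_pos = torch.randn(emb_dim)     # embedding of a nearby (positive) node
>>> v_neg = torch.randn(5, emb_dim)  # embeddings of 5 negative-sample nodes
>>> pos_loss = -F.logsigmoid(u @ v_pos)           # pull the positive pair together
>>> neg_loss = -F.logsigmoid(-(v_neg @ u)).sum()  # push negative pairs apart
>>> loss = pos_loss + neg_loss

In the module, the negative term is additionally weighted by neg_weight (see below).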

Parameters
  • g (DGLGraph) – Graph for learning node embeddings

  • emb_dim (int, optional) – Size of each embedding vector. Default: 128

  • walk_length (int, optional) – Number of nodes in a random walk sequence. Default: 40

  • window_size (int, optional) – In a random walk w, a node w[j] is considered close to a node w[i] if i - window_size <= j <= i + window_size. Default: 5. See the sketch after this list for how these positive pairs are formed.

  • neg_weight (float, optional) – Weight of the loss term for negative samples in the total loss. Default: 1.0

  • negative_size (int, optional) – Number of negative samples to use for each positive sample. Default: 5

  • fast_neg (bool, optional) – If True, negative node pairs are sampled from within the current batch of random walks rather than from the whole graph, trading sampling fidelity for speed. Default: True

  • sparse (bool, optional) – If True, gradients with respect to the learnable weights will be sparse. Default: True
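For example, with window_size=2, node w[2] in a length-5 walk w forms positive pairs with w[0], w[1], w[3], and w[4]. A sketch of extracting positive pairs from one walk (illustrative only, not the module's internal code):

>>> walk = [10, 42, 7, 19, 3]  # one random-walk sequence of node IDs
>>> window_size = 2
>>> pos_pairs = [(walk[i], walk[j])
...              for i in range(len(walk))
...              for j in range(max(0, i - window_size),
...                             min(len(walk), i + window_size + 1))
...              if i != j]
>>> (10, 42) in pos_pairs
True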

node_embed

Embedding table of the nodes

Type

nn.Embedding
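Since node_embed is a standard nn.Embedding, learned vectors can be looked up directly by node ID; assuming a trained model as constructed in the Examples below:

>>> emb = model.node_embed(torch.LongTensor([0, 1]))  # shape (2, emb_dim)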

Examples

>>> import torch
>>> from dgl.data import CoraGraphDataset
>>> from dgl.nn import DeepWalk
>>> from torch.optim import SparseAdam
>>> from torch.utils.data import DataLoader
>>> from sklearn.linear_model import LogisticRegression
>>> dataset = CoraGraphDataset()
>>> g = dataset[0]
>>> model = DeepWalk(g)
>>> dataloader = DataLoader(torch.arange(g.num_nodes()), batch_size=128,
...                         shuffle=True, collate_fn=model.sample)
>>> optimizer = SparseAdam(model.parameters(), lr=0.01)  # sparse gradients need a sparse-aware optimizer
>>> num_epochs = 5
>>> for epoch in range(num_epochs):
...     for batch_walk in dataloader:
...         loss = model(batch_walk)
...         optimizer.zero_grad()
...         loss.backward()
...         optimizer.step()
>>> train_mask = g.ndata['train_mask']
>>> test_mask = g.ndata['test_mask']
>>> X = model.node_embed.weight.detach()  # frozen node embeddings as features
>>> y = g.ndata['label']
>>> clf = LogisticRegression().fit(X[train_mask].numpy(), y[train_mask].numpy())
>>> clf.score(X[test_mask].numpy(), y[test_mask].numpy())

forward(batch_walk)[source]

Compute the loss for a batch of random walks

Parameters

batch_walk (torch.Tensor) – Random walks in the form of node ID sequences. The Tensor is of shape (batch_size, walk_length).

Returns

Loss value

Return type

torch.Tensor
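As a quick usage sketch, walks can be drawn with model.sample (the collate function in the Examples above) and passed to forward. This assumes model.sample accepts a 1-D collection of seed node IDs, as its use as a collate function suggests:

>>> seeds = torch.arange(128)         # seed node IDs, one walk per seed
>>> batch_walk = model.sample(seeds)  # node-ID tensor of shape (128, walk_length)
>>> loss = model(batch_walk)          # scalar loss for this batch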

reset_parameters()[source]

Reinitialize learnable parameters
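For example, to retrain from a fresh random initialization without rebuilding the module:

>>> model.reset_parameters()  # re-initializes the embedding table in place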