HeteroEmbedding

class dgl.nn.pytorch.HeteroEmbedding(num_embeddings, embedding_dim)

Bases: torch.nn.modules.module.Module

Create a heterogeneous embedding table.

Internally, it holds one torch.nn.Embedding per key. Each has its own dictionary size (number of rows), and all of them share the same embedding_dim.

Parameters
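  • num_embeddings (dict[key, int]) – Dictionary size for each key, i.e., the number of embedding vectors (rows) for that key. A key can be a string (e.g., a node type) or a tuple of strings (e.g., a canonical edge type).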

  • embedding_dim (int) – Size of each embedding vector.
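The layout described above can be sketched as one torch.nn.Embedding per key held in a torch.nn.ModuleDict. This is a minimal illustration under that assumption, not DGL's implementation: the class name SimpleHeteroEmbedding and its raw_keys attribute are hypothetical. Because nn.ModuleDict only accepts string keys, the sketch stores tuple keys such as ('user', 'follows', 'user') under str(key) and maps them back when returning results.

```python
import torch
import torch.nn as nn


class SimpleHeteroEmbedding(nn.Module):
    """Hypothetical sketch of a per-key embedding table (not DGL's source)."""

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        # nn.ModuleDict requires string keys, so tuple keys are stored
        # under str(key); raw_keys remembers the original key.
        self.embeds = nn.ModuleDict()
        self.raw_keys = {}
        for key, num_rows in num_embeddings.items():
            # One table per key: its own row count, shared embedding_dim.
            self.embeds[str(key)] = nn.Embedding(num_rows, embedding_dim)
            self.raw_keys[str(key)] = key

    @property
    def weight(self):
        # Map each original key back to its full weight matrix.
        return {self.raw_keys[k]: emb.weight for k, emb in self.embeds.items()}

    def forward(self, input_ids):
        # Look up each key's IDs in that key's own embedding table.
        return {key: self.embeds[str(key)](ids) for key, ids in input_ids.items()}


layer = SimpleHeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
out = layer({'user': torch.LongTensor([0]),
             ('user', 'follows', 'user'): torch.LongTensor([0, 2])})
print(out['user'].shape)
print(out[('user', 'follows', 'user')].shape)
```

Keeping a separate nn.Embedding per key lets each key have a different number of rows while all outputs share the same embedding_dim.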

Examples

>>> import dgl
>>> import torch
>>> from dgl.nn import HeteroEmbedding
>>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
>>> # Get the heterogeneous embedding table
>>> embeds = layer.weight
>>> print(embeds['user'].shape)
torch.Size([2, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([3, 4])
>>> # Get the embeddings for a subset
>>> input_ids = {'user': torch.LongTensor([0]),
...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
>>> embeds = layer(input_ids)
>>> print(embeds['user'].shape)
torch.Size([1, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([2, 4])

forward(input_ids)

Look up the embeddings of the given row IDs, separately for each key.

Parameters

  • input_ids (dict[key, Tensor]) – The row IDs whose embeddings to retrieve, given as a mapping from each key to a tensor of row IDs into that key's embedding table.

Returns

The retrieved embeddings, keyed by the same keys as input_ids.

Return type

dict[key, Tensor]
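For each key, the lookup behaves like a plain embedding lookup: each returned row is the row of that key's weight matrix at the given ID, so a 1-D ID tensor of length n yields an (n, embedding_dim) result, as in the shapes shown in the example above. The sketch below checks this equivalence using torch.nn.functional.embedding on a hypothetical dict of weight matrices standing in for layer.weight. It does not call DGL.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical weight tables standing in for HeteroEmbedding.weight:
# 2 rows for 'user', 3 rows for the canonical edge type, embedding_dim = 4.
weights = {
    'user': torch.randn(2, 4),
    ('user', 'follows', 'user'): torch.randn(3, 4),
}
input_ids = {
    'user': torch.LongTensor([0]),
    ('user', 'follows', 'user'): torch.LongTensor([0, 2]),
}

# Per-key lookup: gather the requested rows from that key's own table.
looked_up = {key: F.embedding(ids, weights[key]) for key, ids in input_ids.items()}

for key, ids in input_ids.items():
    # Identical to indexing the weight matrix directly with the IDs.
    assert torch.equal(looked_up[key], weights[key][ids])
    print(key, tuple(looked_up[key].shape))
```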