dgl.distributed

DGL distributed module contains classes and functions to support distributed Graph Neural Network training and inference on a cluster of machines.

This includes a few submodules:

  • distributed data structures including distributed graph, distributed tensor and distributed embeddings.

  • distributed sampling.

  • distributed workload split at runtime.

  • graph partition.

Initialization

initialize(ip_config[, max_queue_size, ...])

Initialize DGL's distributed module

Distributed Graph

class dgl.distributed.DistGraph(graph_name, gpb=None, part_config=None)

The class for accessing a distributed graph.

This class provides a subset of DGLGraph APIs for accessing partitioned graph data in distributed GNN training and inference. Thus, its main use case is to work with distributed sampling APIs to generate mini-batches and perform forward and backward computation on the mini-batches.

The class can run in two modes: the standalone mode and the distributed mode.

  • When a user runs the training script normally, DistGraph runs in the standalone mode. In this mode, the input data must be constructed by partition_graph() with only one partition, and users have to provide part_config so that DistGraph can load the input graph. This mode is used for testing and debugging purposes.

  • When a user runs the training script with the distributed launch script, DistGraph will be set into the distributed mode. This is used for actual distributed training. All data of partitions are loaded by the DistGraph servers, which are created by DGL’s launch script. DistGraph connects with the servers to access the partitioned graph data.

Currently, the DistGraph servers and clients run on the same set of machines in the distributed mode. DistGraph uses shared memory to access the partition data on the local machine, which gives the best performance for distributed training.

Users may want to run DistGraph servers and clients on separate sets of machines. In this case, a user may want to disable shared memory by passing disable_shared_mem=True when creating DistGraphServer. When shared memory is disabled, a user has to pass a partition book (the gpb argument).

Parameters:
  • graph_name (str) – The name of the graph. This name has to be the same as the one used for partitioning a graph in dgl.distributed.partition.partition_graph().

  • gpb (GraphPartitionBook, optional) – The partition book object. Normally, users do not need to provide the partition book. This argument is necessary only when users want to run server processes and trainer processes on different machines.

  • part_config (str, optional) – The path of the partition configuration file generated by dgl.distributed.partition.partition_graph(). It is used in the standalone mode.

Examples

The example shows the creation of DistGraph in the standalone mode.

>>> dgl.distributed.partition_graph(g, 'graph_name', 1, num_hops=1, part_method='metis',
...                                 out_path='output/')
>>> g = dgl.distributed.DistGraph('graph_name', part_config='output/graph_name.json')

The example shows the creation of DistGraph in the distributed mode.

>>> g = dgl.distributed.DistGraph('graph-name')

The code below shows the mini-batch training using DistGraph.

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> # `nodes` (training node IDs) and `model` are assumed to be defined elsewhere.
>>> def sample(seeds):
...     seeds = th.LongTensor(np.asarray(seeds))
...     frontier = dgl.distributed.sample_neighbors(g, seeds, 10)
...     return dgl.to_block(frontier, seeds)
>>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
...                                             collate_fn=sample, shuffle=True)
>>> for block in dataloader:
...     feat = g.ndata['features'][block.srcdata[dgl.NID]]
...     labels = g.ndata['labels'][block.dstdata[dgl.NID]]
...     pred = model(block, feat)

Note

DGL’s distributed training by default runs server processes and trainer processes on the same set of machines. If users need to run them on different sets of machines, it requires manually setting up servers and trainers. The setup is not fully tested yet.

barrier()

Barrier for all client nodes.

This API blocks the current process until all the clients invoke this API. Please use this API with caution: if any client never calls it, the other clients stay blocked.

property device

Get the device context of this graph.

Examples

The following example uses PyTorch backend.

>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1])
... })
>>> print(g.device)
device(type='cpu')
>>> g = g.to('cuda:0')
>>> print(g.device)
device(type='cuda', index=0)
Return type:

Device context object

property edata

Return the data view of all the edges.

Returns:

The data view in the distributed graph storage.

Return type:

EdgeDataView

edge_attr_schemes()

Return the edge feature schemes.

Each feature scheme is a named tuple that stores the shape and data type of the edge feature.

Returns:

The schemes of edge feature columns.

Return type:

dict of str to schemes

Examples

The following uses PyTorch backend.

>>> g.edge_attr_schemes()
{'h': Scheme(shape=(4,), dtype=torch.float32)}
property edges

Return an edge view

property etypes

Return the list of edge types of this graph.

Return type:

list of str

Examples

>>> g = DistGraph("test")
>>> g.etypes
['_E']
find_edges(edges, etype=None)

Given an edge ID array, return the source and destination node ID arrays s and d. s[i] and d[i] are the source and destination node IDs of edge edges[i].

Parameters:
  • edges (Int Tensor) –

    Each element is an ID. The tensor must have the same device type and ID data type as the graph’s.

  • etype (str or (str, str, str), optional) –

    The type names of the edges. The allowed type name formats are:

    • (str, str, str) for source node type, edge type and destination node type.

    • or one str edge type name if the name can uniquely identify a triplet format in the graph.

    Can be omitted if the graph has only one type of edges.

Returns:

  • tensor – The source node ID array.

  • tensor – The destination node ID array.
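Conceptually, find_edges is an index lookup into the edge endpoint arrays. The sketch below illustrates only that semantics on a small in-memory COO graph; the `src`/`dst` lists and the helper `find_edges_local` are hypothetical stand-ins for data that DistGraph fetches from the machines that own each edge.

```python
from typing import List, Sequence, Tuple


def find_edges_local(src: Sequence[int], dst: Sequence[int],
                     edges: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: return (s, d) with s[i], d[i] the endpoints of edges[i].

    `src[e]` and `dst[e]` hold the source and destination node of edge ID `e`.
    """
    s = [src[e] for e in edges]
    d = [dst[e] for e in edges]
    return s, d


# A toy graph with 4 edges: 0->1, 0->2, 1->1, 1->3
src = [0, 0, 1, 1]
dst = [1, 2, 1, 3]
s, d = find_edges_local(src, dst, [3, 0])
print(s, d)  # [1, 0] [3, 1]
```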

get_edge_partition_policy(etype)

Get the partition policy for an edge type.

When creating a new distributed tensor, we need to provide a partition policy that indicates how to distribute data of the distributed tensor in a cluster of machines. When we load a distributed graph in the cluster, we have pre-defined partition policies for each node type and each edge type. By providing the edge type, we can refer to the pre-defined partition policy for the edge type.

Parameters:

etype (str or (str, str, str)) – The edge type

Returns:

The partition policy for the edge type.

Return type:

PartitionPolicy

get_etype_id(etype)

Return the ID of the given edge type.

etype can also be None. If so, there should be only one edge type in the graph.

Parameters:

etype (str or tuple of str) – Edge type

Return type:

int
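The rules for an etype argument (None only when the graph has a single edge type, a plain string only when it names exactly one canonical triplet) can be sketched as follows. This is a hypothetical helper, not DGL's implementation; it assumes the canonical edge types are given as a list of (src_type, edge_type, dst_type) tuples whose position is the edge type ID.

```python
from typing import List, Optional, Tuple, Union

CanonicalEType = Tuple[str, str, str]


def resolve_etype_id(canonical_etypes: List[CanonicalEType],
                     etype: Optional[Union[str, CanonicalEType]]) -> int:
    """Hypothetical sketch: map an edge type spec to its integer ID."""
    if etype is None:
        # Omitting etype is only unambiguous if there is a single edge type.
        if len(canonical_etypes) != 1:
            raise ValueError("etype must be given when there are several edge types")
        return 0
    if isinstance(etype, tuple):
        # A canonical triplet must match exactly (list.index raises ValueError if absent).
        return canonical_etypes.index(etype)
    # A plain string must identify exactly one triplet.
    matches = [i for i, (_src, rel, _dst) in enumerate(canonical_etypes) if rel == etype]
    if len(matches) != 1:
        raise ValueError(f"edge type {etype!r} is missing or ambiguous")
    return matches[0]


cetypes = [("user", "follows", "user"),
           ("user", "plays", "game"),
           ("game", "rev-plays", "user")]
print(resolve_etype_id(cetypes, "plays"))                        # 1
print(resolve_etype_id(cetypes, ("game", "rev-plays", "user")))  # 2
```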

get_node_partition_policy(ntype)

Get the partition policy for a node type.

When creating a new distributed tensor, we need to provide a partition policy that indicates how to distribute data of the distributed tensor in a cluster of machines. When we load a distributed graph in the cluster, we have pre-defined partition policies for each node type and each edge type. By providing the node type, we can refer to the pre-defined partition policy for the node type.

Parameters:

ntype (str) – The node type

Returns:

The partition policy for the node type.

Return type:

PartitionPolicy

get_ntype_id(ntype)

Return the ID of the given node type.

ntype can also be None. If so, there should be only one node type in the graph.

Parameters:

ntype (str) – Node type

Return type:

int

get_partition_book()

Get the partition information.

Returns:

Object that stores all graph partition information.

Return type:

GraphPartitionBook

property idtype

The dtype of graph index

Returns:

th.int32/th.int64 or tf.int32/tf.int64 etc.

Return type:

backend dtype object

See also

long, int

in_degrees(v='__ALL__')

Return the in-degree(s) of the given nodes.

It computes the in-degree(s). It does not support heterogeneous graphs yet.

Parameters:

v (node IDs) –

The node IDs. The allowed formats are:

  • int: A single node.

  • Int Tensor: Each element is a node ID. The tensor must have the same device type and ID data type as the graph’s.

  • iterable[int]: Each element is a node ID.

If not given, return the in-degrees of all the nodes.

Returns:

The in-degree(s) of the node(s) in a Tensor. The i-th element is the in-degree of the i-th input node. If v is an int, return an int too.

Return type:

int or Tensor

Examples

The following example uses PyTorch backend.

>>> import dgl
>>> import torch

Query for all nodes.

>>> g.in_degrees()
tensor([0, 2, 1, 1])

Query for nodes 1 and 2.

>>> g.in_degrees(torch.tensor([1, 2]))
tensor([2, 1])

See also

out_degrees
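The input-format rules above (a single int returns an int; a tensor or iterable returns one degree per input node, in input order; no argument covers all nodes) can be illustrated with a plain-Python sketch. Lists stand in for tensors, the function name is hypothetical, and the 4-node edge list is made up so that its degrees match the example output above.

```python
from collections import Counter
from typing import Iterable, List, Optional, Union


def in_degrees_local(dst: List[int], num_nodes: int,
                     v: Optional[Union[int, Iterable[int]]] = None):
    """Hypothetical sketch of the in_degrees input/output rules on a COO edge list."""
    counts = Counter(dst)  # in-degree = number of edges pointing at a node
    if v is None:
        return [counts[n] for n in range(num_nodes)]
    if isinstance(v, int):
        return counts[v]  # a single node ID yields a single int
    return [counts[n] for n in v]  # i-th output is the degree of the i-th input


src = [0, 0, 1, 1]
dst = [1, 2, 1, 3]
print(in_degrees_local(dst, 4))          # [0, 2, 1, 1]
print(in_degrees_local(dst, 4, [1, 2]))  # [2, 1]
print(in_degrees_local(dst, 4, 1))       # 2
```

Passing `src` instead of `dst` gives the out-degrees, [2, 2, 0, 0] for this edge list.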

property local_partition

Return the local partition on the client

DistGraph provides a global view of the distributed graph. Internally, it may contain a partition of the graph if it is co-located with the server. When servers and clients run on separate sets of machines, this returns None.

Returns:

The local partition

Return type:

DGLGraph

property ndata

Return the data view of all the nodes.

Returns:

The data view in the distributed graph storage.

Return type:

NodeDataView

node_attr_schemes()

Return the node feature schemes.

Each feature scheme is a named tuple that stores the shape and data type of the node feature.

Returns:

The schemes of node feature columns.

Return type:

dict of str to schemes

Examples

The following uses PyTorch backend.

>>> g.node_attr_schemes()
{'h': Scheme(shape=(4,), dtype=torch.float32)}
property nodes

Return a node view

property ntypes

Return the list of node types of this graph.

Return type:

list of str

Examples

>>> g = DistGraph("test")
>>> g.ntypes
['_U']
num_edges(etype=None)

Return the total number of edges in the distributed graph.

Parameters:

etype (str or (str, str, str), optional) –

The type name of the edges. The allowed type name formats are:

  • (str, str, str) for source node type, edge type and destination node type.

  • or one str edge type name if the name can uniquely identify a triplet format in the graph.

If not provided, return the total number of edges regardless of the types in the graph.

Returns:

The number of edges

Return type:

int

Examples

>>> g = dgl.distributed.DistGraph('ogb-product')
>>> print(g.num_edges())
123718280
num_nodes(ntype=None)

Return the total number of nodes in the distributed graph.

Parameters:

ntype (str, optional) – The node type name. If given, it returns the number of nodes of the type. If not given (default), it returns the total number of nodes of all types.

Returns:

The number of nodes

Return type:

int

Examples

>>> g = dgl.distributed.DistGraph('ogb-product')
>>> print(g.num_nodes())
2449029
number_of_edges(etype=None)

Alias of num_edges()

number_of_nodes(ntype=None)

Alias of num_nodes()

out_degrees(u='__ALL__')

Return the out-degree(s) of the given nodes.

It computes the out-degree(s). It does not support heterogeneous graphs yet.

Parameters:

u (node IDs) –

The node IDs. The allowed formats are:

  • int: A single node.

  • Int Tensor: Each element is a node ID. The tensor must have the same device type and ID data type as the graph’s.

  • iterable[int]: Each element is a node ID.

If not given, return the out-degrees of all the nodes.

Returns:

The out-degree(s) of the node(s) in a Tensor. The i-th element is the out-degree of the i-th input node. If u is an int, return an int too.

Return type:

int or Tensor

Examples

The following example uses PyTorch backend.

>>> import dgl
>>> import torch

Query for all nodes.

>>> g.out_degrees()
tensor([2, 2, 0, 0])

Query for nodes 1 and 2.

>>> g.out_degrees(torch.tensor([1, 2]))
tensor([2, 0])

See also

in_degrees

rank()

The rank of the current DistGraph.

This returns a unique number to identify the DistGraph object among all of the client processes.

Returns:

The rank of the current DistGraph.

Return type:

int

Distributed Tensor

class dgl.distributed.DistTensor(shape, dtype, name=None, init_func=None, part_policy=None, persistent=False, is_gdata=True, attach=True)

Distributed tensor.

DistTensor refers to a distributed tensor sharded and stored in a cluster of machines. It has the same interface as a PyTorch Tensor for accessing its metadata (e.g., shape and data type). To access data in a distributed tensor, it supports slicing rows and writing data to rows. It does not support any operators of a deep learning framework, such as addition and multiplication.

Currently, distributed tensors are designed to store node data and edge data of a distributed graph. Therefore, their first dimensions have to be the number of nodes or edges in the graph. The tensors are sharded in the first dimension based on the partition policy of nodes or edges. When a distributed tensor is created, the partition policy is automatically determined based on the first dimension if the partition policy is not provided. If the first dimension matches the number of nodes of a node type, DistTensor will use the partition policy for this particular node type; if the first dimension matches the number of edges of an edge type, DistTensor will use the partition policy for this particular edge type. If DGL cannot determine the partition policy automatically (e.g., multiple node types or edge types have the same number of nodes or edges), users have to explicitly provide the partition policy.
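The automatic choice of partition policy described above amounts to matching the first dimension against the per-type node and edge counts. The helper below is a hypothetical sketch, not DGL's code; it borrows the policy-name format shown for PartitionPolicy (for example 'node~_N' and 'edge~_N:_E:_N'), and the type names and counts are made up.

```python
from typing import Dict, Tuple


def infer_policy(first_dim: int,
                 num_nodes: Dict[str, int],
                 num_edges: Dict[Tuple[str, str, str], int]) -> str:
    """Hypothetical sketch: pick a partition policy from a tensor's first dimension."""
    candidates = [f"node~{ntype}" for ntype, n in num_nodes.items() if n == first_dim]
    candidates += [f"edge~{':'.join(cetype)}"
                   for cetype, n in num_edges.items() if n == first_dim]
    if len(candidates) == 1:
        return candidates[0]
    if not candidates:
        raise ValueError("first dimension matches no node or edge type")
    # Several types share this size, so the user must pass part_policy explicitly.
    raise ValueError(f"ambiguous partition policy: {candidates}")


num_nodes = {"user": 1000, "game": 200}
num_edges = {("user", "plays", "game"): 5000, ("game", "rev-plays", "user"): 5000}
print(infer_policy(1000, num_nodes, num_edges))  # node~user
```

Calling it with 5000 raises an error, because both edge types have 5000 edges.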

A distributed tensor can be either named or anonymous. When a distributed tensor has a name, the tensor can be persistent if persistent=True. Normally, DGL destroys the distributed tensor in the system when the DistTensor object goes away. However, a persistent tensor lives in the system even if the DistTensor object disappears in the trainer process. The persistent tensor has the same life span as the DGL servers. DGL does not allow an anonymous tensor to be persistent.

When a DistTensor object is created, it may refer to an existing distributed tensor or create a new one. A distributed tensor is identified by the name passed to the constructor. If the name exists, DistTensor will reference the existing one. In this case, the shape and the data type must match the existing tensor. If the name doesn’t exist, a new tensor will be created in the kvstore.
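The naming and lifetime rules above can be modeled with a toy in-process registry. Everything in this sketch (`ToyKVStore`, `get_or_create`, `release`) is hypothetical and only mirrors the documented behavior: an existing name is reused if its shape and dtype match, anonymous tensors cannot be persistent, and non-persistent tensors are removed when their object goes away.

```python
from typing import Dict, Optional, Tuple


class ToyKVStore:
    """Hypothetical stand-in for the server-side store of named tensors."""

    def __init__(self) -> None:
        # name -> (shape, dtype, persistent)
        self.tensors: Dict[str, Tuple[tuple, str, bool]] = {}
        self._anon_counter = 0

    def get_or_create(self, shape: tuple, dtype: str,
                      name: Optional[str] = None, persistent: bool = False) -> str:
        if name is None:
            if persistent:
                raise ValueError("an anonymous tensor cannot be persistent")
            name = f"anonymous-{self._anon_counter}"
            self._anon_counter += 1
        if name in self.tensors:
            old_shape, old_dtype, _ = self.tensors[name]
            # Referencing an existing tensor requires identical metadata.
            if (old_shape, old_dtype) != (shape, dtype):
                raise ValueError(f"{name!r} exists with a different shape or dtype")
            return name
        self.tensors[name] = (shape, dtype, persistent)
        return name

    def release(self, name: str) -> None:
        """Called when the trainer-side object goes away."""
        _, _, persistent = self.tensors[name]
        if not persistent:
            del self.tensors[name]


store = ToyKVStore()
store.get_or_create((4, 2), "int32", name="h", persistent=True)
store.get_or_create((4, 2), "int32", name="h")  # reuses the existing tensor
store.release("h")
print("h" in store.tensors)  # True: the persistent tensor outlives its object
```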

When a distributed tensor is created, its values are initialized to zero. Users can define an initialization function to control how the values are initialized. The init function takes two arguments, the shape and the data type, and returns a tensor. Below is an example of an init function:

import torch

def init_func(shape, dtype):
    # Fill the tensor with ones instead of the default zeros.
    return torch.ones(shape, dtype=dtype)

Parameters:
  • shape (tuple) – The shape of the tensor. The first dimension has to be the number of nodes or the number of edges of a distributed graph.

  • dtype (dtype) – The dtype of the tensor. The data type has to be the one in the deep learning framework.

  • name (string, optional) – The name of the tensor. The name can uniquely identify the tensor in a system so that another DistTensor object can refer to the distributed tensor.

  • init_func (callable, optional) – The function to initialize data in the tensor. If the init function is not provided, the values of the tensor are initialized to zero.

  • part_policy (PartitionPolicy, optional) – The partition policy of the rows of the tensor to different machines in the cluster. Currently, it only supports node partition policy or edge partition policy. If it is not provided, the system determines the right partition policy automatically.

  • persistent (bool) – Whether the created tensor lives after the DistTensor object is destroyed.

  • is_gdata (bool) – Whether the created tensor is an ndata/edata tensor or not.

  • attach (bool) – Whether to attach group ID into name to be globally unique.

Examples

>>> init = lambda shape, dtype: th.ones(shape, dtype=dtype)
>>> arr = dgl.distributed.DistTensor((g.num_nodes(), 2), th.int32, init_func=init)
>>> print(arr[0:3])
tensor([[1, 1],
        [1, 1],
        [1, 1]], dtype=torch.int32)
>>> arr[0:3] = th.ones((3, 2), dtype=th.int32) * 2
>>> print(arr[0:3])
tensor([[2, 2],
        [2, 2],
        [2, 2]], dtype=torch.int32)

Note

The creation of DistTensor is a synchronized operation. When a trainer process tries to create a DistTensor object, the creation succeeds only when all trainer processes do the same.

property dtype

Return the data type of the distributed tensor.

Returns:

The data type of the tensor.

Return type:

dtype

property name

Return the name of the distributed tensor

Returns:

The name of the tensor.

Return type:

str

property part_policy

Return the partition policy

Returns:

The partition policy of the distributed tensor.

Return type:

PartitionPolicy

property shape

Return the shape of the distributed tensor.

Returns:

The shape of the distributed tensor.

Return type:

tuple

Distributed Node Embedding

class dgl.distributed.DistEmbedding(num_embeddings, embedding_dim, name=None, init_func=None, part_policy=None)

Distributed node embeddings.

DGL provides a distributed embedding to support models that require learnable embeddings. DGL’s distributed embeddings are mainly used for learning node embeddings of graph models. Because distributed embeddings are part of a model, they are updated by mini-batches. The distributed embeddings have to be updated by DGL’s optimizers instead of the optimizers provided by the deep learning frameworks (e.g., PyTorch and MXNet).

To support efficient training on a graph with many nodes, the embeddings support sparse updates. That is, only the embeddings involved in a mini-batch computation are updated. Please refer to Distributed Optimizers for available optimizers in DGL.

Distributed embeddings are sharded and stored in a cluster of machines in the same way as dgl.distributed.DistTensor, except that distributed embeddings are trainable. Because distributed embeddings are sharded in the same way as nodes and edges of a distributed graph, accessing them is usually much more efficient than accessing the sparse embeddings provided by the deep learning frameworks.

Parameters:
  • num_embeddings (int) – The number of embeddings. Currently, the number of embeddings has to be the same as the number of nodes or the number of edges.

  • embedding_dim (int) – The dimension size of embeddings.

  • name (str, optional) – The name of the embeddings. The name can uniquely identify embeddings in a system so that another DistEmbedding object can refer to the same embeddings.

  • init_func (callable, optional) – The function to create the initial data. If the init function is not provided, the values of the embeddings are initialized to zero.

  • part_policy (PartitionPolicy, optional) – The partition policy that assigns embeddings to different machines in the cluster. Currently, it only supports node partition policy or edge partition policy. If it is not provided, the system determines the right partition policy automatically.

Examples

>>> import dgl
>>> import torch as th
>>> def initializer(shape, dtype):
...     arr = th.zeros(shape, dtype=dtype)
...     arr.uniform_(-1, 1)
...     return arr
>>> emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10, init_func=initializer)
>>> optimizer = dgl.distributed.optim.SparseAdagrad([emb], lr=0.001)
>>> for block in dataloader:
...     # Look up the embeddings of the input nodes of the mini-batch.
...     feats = emb(block.srcdata[dgl.NID])
...     loss = th.sum(feats + 1)  # reduce to a scalar so that backward() works
...     loss.backward()
...     optimizer.step()

Note

When a DistEmbedding object is used in the forward computation, users have to invoke the optimizer's step() afterwards. Otherwise, memory will leak.

Distributed embedding optimizer

class dgl.distributed.optim.SparseAdagrad(params, lr, eps=1e-10)

Distributed node embedding optimizer using the Adagrad algorithm.

This optimizer implements a distributed sparse version of the Adagrad algorithm for optimizing dgl.distributed.DistEmbedding. Being sparse means it only updates the embeddings that receive gradients, which are usually a very small portion of all embeddings.

Adagrad maintains a \(G_{t,i,j}\) for every parameter in the embeddings, where \(G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2\) and \(g_{t,i,j}\) is the gradient of the dimension \(j\) of embedding \(i\) at step \(t\).

NOTE: The support of sparse Adagrad optimizer is experimental.
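To make the sparse update concrete, the sketch below applies the accumulator \(G_{t,i,j}\) defined above only to the embedding rows that received gradients in a mini-batch. The final update \(w \leftarrow w - lr \cdot g / (\sqrt{G} + \epsilon)\) is the textbook Adagrad rule and is an assumption here, as are the dictionary-based storage and the function name; DGL keeps these states sharded on its servers.

```python
import math
from typing import Dict, List


def sparse_adagrad_step(emb: Dict[int, List[float]],
                        state: Dict[int, List[float]],
                        grads: Dict[int, List[float]],
                        lr: float, eps: float = 1e-10) -> None:
    """Hypothetical sketch: textbook Adagrad applied only to rows present in `grads`."""
    for i, g_row in grads.items():  # rows untouched by the batch are skipped
        G_row = state.setdefault(i, [0.0] * len(g_row))
        for j, g in enumerate(g_row):
            G_row[j] += g * g  # G_{t,i,j} = G_{t-1,i,j} + g_{t,i,j}^2
            emb[i][j] -= lr * g / (math.sqrt(G_row[j]) + eps)


emb = {0: [1.0, 1.0], 1: [1.0, 1.0], 2: [1.0, 1.0]}
state: Dict[int, List[float]] = {}
sparse_adagrad_step(emb, state, {1: [2.0, -2.0]}, lr=0.1)
print([round(x, 6) for x in emb[1]])  # [0.9, 1.1]
print(emb[0])                         # [1.0, 1.0]: row 0 received no gradient
```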

Parameters:
  • params (list[dgl.distributed.DistEmbedding]) – The list of dgl.distributed.DistEmbedding.

  • lr (float) – The learning rate.

  • eps (float, Optional) – The term added to the denominator to improve numerical stability. Default: 1e-10

load(f)

Load the local state of the optimizer from the file on each rank.

NOTE: This needs to be called on all ranks.

Parameters:

f (Union[str, os.PathLike]) – The path of the file to load from.

See also

save

save(f)

Save the local state_dict to disk on each rank.

Saved dict contains 2 parts:

  • ‘params’: hyper parameters of the optimizer.

  • ‘emb_states’: partial optimizer states, each embedding contains 2 items:
    1. `ids`: global IDs of the nodes/edges stored in this rank.

    2. `states`: state data corresponding to `ids`.

NOTE: This needs to be called on all ranks.
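The per-rank file layout described above can be pictured as a nested dictionary. In this sketch only the key names 'params', 'emb_states', 'ids' and 'states' come from the description; the nesting under an embedding name, the per-rank file name, the helper names and the use of pickle as the serializer are all assumptions.

```python
import os
import pickle
import tempfile


def save_rank_state(path: str, params: dict, emb_states: dict) -> None:
    """Hypothetical sketch: write one rank's optimizer state to its own file."""
    with open(path, "wb") as f:
        pickle.dump({"params": params, "emb_states": emb_states}, f)


def load_rank_state(path: str) -> dict:
    """Hypothetical sketch: read back the state written by save_rank_state."""
    with open(path, "rb") as f:
        return pickle.load(f)


rank = 0  # each rank saves and loads only the rows it stores
state = {
    "params": {"lr": 0.01, "eps": 1e-10},
    "emb_states": {
        # One entry per embedding: global IDs held by this rank and their states.
        "node_emb": {"ids": [0, 3, 7], "states": [[0.1, 0.2], [0.0, 0.5], [1.0, 1.0]]},
    },
}
path = os.path.join(tempfile.gettempdir(), f"optim_state.rank{rank}.pkl")
save_rank_state(path, state["params"], state["emb_states"])
restored = load_rank_state(path)
print(restored["emb_states"]["node_emb"]["ids"])  # [0, 3, 7]
```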

Parameters:

f (Union[str, os.PathLike]) – The path of the file to save to.

See also

load

step()

The step function.

The step function is invoked at the end of every batch to push the gradients of the embeddings involved in a mini-batch to DGL’s servers and update the embeddings.

class dgl.distributed.optim.SparseAdam(params, lr, betas=(0.9, 0.999), eps=1e-08)

Distributed node embedding optimizer using the Adam algorithm.

This optimizer implements a distributed sparse version of the Adam algorithm for optimizing dgl.distributed.DistEmbedding. Being sparse means it only updates the embeddings that receive gradients, which are usually a very small portion of all embeddings.

Adam maintains \(Gm_{t,i,j}\) and \(Gp_{t,i,j}\) for every parameter in the embeddings, where \(Gm_{t,i,j} = \beta_1 Gm_{t-1,i,j} + (1-\beta_1) g_{t,i,j}\), \(Gp_{t,i,j} = \beta_2 Gp_{t-1,i,j} + (1-\beta_2) g_{t,i,j}^2\), and \(g_{t,i,j}\) is the gradient of the dimension \(j\) of embedding \(i\) at step \(t\). The dimension \(j\) of embedding \(i\) is then decreased by \(lr \cdot \frac{Gm_{t,i,j} / (1-\beta_1^t)}{\sqrt{Gp_{t,i,j} / (1-\beta_2^t)} + \epsilon}\), where \((\beta_1, \beta_2)\) are the betas and \(\epsilon\) is eps.

NOTE: The support of sparse Adam optimizer is experimental.
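The Adam recurrences above, including the bias corrections \(1-\beta_1^t\) and \(1-\beta_2^t\), can be sketched for a single embedding coordinate. This is the textbook Adam update in plain Python for illustration only; the function name and passing the step count t explicitly are assumptions, not DGL's implementation.

```python
import math
from typing import Tuple


def adam_coordinate_step(w: float, g: float, gm: float, gp: float, t: int,
                         lr: float, beta1: float = 0.9, beta2: float = 0.999,
                         eps: float = 1e-8) -> Tuple[float, float, float]:
    """Hypothetical sketch: one Adam step for a single (i, j) coordinate."""
    gm = beta1 * gm + (1 - beta1) * g      # Gm_{t,i,j}
    gp = beta2 * gp + (1 - beta2) * g * g  # Gp_{t,i,j}
    m_hat = gm / (1 - beta1 ** t)          # bias-corrected first moment
    v_hat = gp / (1 - beta2 ** t)          # bias-corrected second moment
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, gm, gp


w, gm, gp = adam_coordinate_step(w=1.0, g=0.5, gm=0.0, gp=0.0, t=1, lr=0.01)
print(round(w, 6))  # 0.99: with bias correction the first step has size close to lr
```

In the sparse setting, this step would run only for the coordinates of rows that appear in the mini-batch.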

Parameters:
  • params (list[dgl.distributed.DistEmbedding]) – The list of dgl.distributed.DistEmbedding.

  • lr (float) – The learning rate.

  • betas (tuple[float, float], Optional) – Coefficients used for computing running averages of gradient and its square. Default: (0.9, 0.999)

  • eps (float, Optional) – The term added to the denominator to improve numerical stability. Default: 1e-8

load(f)

Load the local state of the optimizer from the file on each rank.

NOTE: This needs to be called on all ranks.

Parameters:

f (Union[str, os.PathLike]) – The path of the file to load from.

See also

save

save(f)

Save the local state_dict to disk on each rank.

Saved dict contains 2 parts:

  • ‘params’: hyper parameters of the optimizer.

  • ‘emb_states’: partial optimizer states, each embedding contains 2 items:
    1. `ids`: global IDs of the nodes/edges stored in this rank.

    2. `states`: state data corresponding to `ids`.

NOTE: This needs to be called on all ranks.

Parameters:

f (Union[str, os.PathLike]) – The path of the file to save to.

See also

load

step()

The step function.

The step function is invoked at the end of every batch to push the gradients of the embeddings involved in a mini-batch to DGL’s servers and update the embeddings.

Distributed workload split

node_split(nodes[, partition_book, ntype, ...])

Split nodes and return a subset for the local rank.

edge_split(edges[, partition_book, etype, ...])

Split edges and return a subset for the local rank.
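At its simplest, splitting the workload gives each trainer a disjoint, roughly equal share of the training nodes. The sketch below shows only such an even split over a boolean training mask; the function name is hypothetical, and it ignores the partition book and node type that node_split and edge_split can also take (see their signatures above), so it is a simplification rather than their actual algorithm.

```python
from typing import List


def even_split(mask: List[bool], num_trainers: int, rank: int) -> List[int]:
    """Hypothetical sketch: give trainer `rank` a contiguous, near-equal slice."""
    ids = [i for i, selected in enumerate(mask) if selected]  # mask -> node IDs
    base, extra = divmod(len(ids), num_trainers)
    # The first `extra` trainers receive one additional node each.
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return ids[start:end]


train_mask = [True, False, True, True, False, True, True, True]
for r in range(3):
    print(r, even_split(train_mask, 3, r))
# 0 [0, 2]
# 1 [3, 5]
# 2 [6, 7]
```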

Distributed Sampling

Distributed DataLoader

class dgl.distributed.DistDataLoader(dataset, batch_size, shuffle=False, collate_fn=None, drop_last=False, queue_size=None)

DGL customized multiprocessing dataloader.

DistDataLoader provides a similar interface to PyTorch’s DataLoader to generate mini-batches with multiprocessing. It utilizes the worker processes created by dgl.distributed.initialize() to parallelize sampling.

Parameters:
  • dataset (a tensor) – A tensor of node IDs or edge IDs.

  • batch_size (int) – The number of samples per batch to load.

  • shuffle (bool, optional) – Set to True to have the data reshuffled at every epoch (default: False).

  • collate_fn (callable, optional) – The function is typically used to sample neighbors of the nodes in a batch or the endpoint nodes of the edges in a batch.

  • drop_last (bool, optional) – Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

  • queue_size (int, optional) – The size of the multiprocessing queue.
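Ignoring DGL's worker processes, the effect of batch_size, shuffle and drop_last on a dataset of IDs can be sketched in plain Python. The function name iter_batches is hypothetical; unlike this sequential sketch, DistDataLoader does not guarantee the order in which batches are produced.

```python
import random
from typing import Iterator, List, Optional, Sequence


def iter_batches(dataset: Sequence[int], batch_size: int, shuffle: bool = False,
                 drop_last: bool = False,
                 seed: Optional[int] = None) -> Iterator[List[int]]:
    """Hypothetical sketch of how IDs are grouped into mini-batches."""
    order = list(dataset)
    if shuffle:
        random.Random(seed).shuffle(order)  # reshuffled on every call when seed is None
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the trailing incomplete batch
        yield batch


print(list(iter_batches(range(7), 3)))                  # [[0, 1, 2], [3, 4, 5], [6]]
print(list(iter_batches(range(7), 3, drop_last=True)))  # [[0, 1, 2], [3, 4, 5]]
```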

Examples

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> # `nodes` (training node IDs) and `model` are assumed to be defined elsewhere.
>>> g = dgl.distributed.DistGraph('graph-name')
>>> def sample(seeds):
...     seeds = th.LongTensor(np.asarray(seeds))
...     frontier = dgl.distributed.sample_neighbors(g, seeds, 10)
...     return dgl.to_block(frontier, seeds)
>>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
...                                             collate_fn=sample, shuffle=True)
>>> for block in dataloader:
...     feat = g.ndata['features'][block.srcdata[dgl.NID]]
...     labels = g.ndata['labels'][block.dstdata[dgl.NID]]
...     pred = model(block, feat)

Note

When performing DGL’s distributed sampling with multiprocessing, users have to use this class instead of PyTorch’s DataLoader because DGL’s RPC requires that all processes establish connections with servers before invoking any of DGL’s distributed APIs. Therefore, this dataloader uses the worker processes created in dgl.distributed.initialize().

Note

This dataloader does not guarantee the iteration order. For example, if dataset = [1, 2, 3, 4], batch_size = 2 and shuffle = False, the order of [1, 2] and [3, 4] is not guaranteed.

Distributed Graph Sampling Operators

sample_neighbors(g, nodes, fanout[, ...])

Sample from the neighbors of the given nodes from a distributed graph.

sample_etype_neighbors(g, nodes, fanout[, ...])

Sample from the neighbors of the given nodes from a distributed graph.

find_edges(g, edge_ids)

Given an edge ID array, return the source and destination node ID array s and d from a distributed graph.

in_subgraph(g, nodes)

Return the subgraph induced on the inbound edges of the given nodes.
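On a single in-memory edge list, in_subgraph and neighbor sampling differ only in whether the inbound edges of each seed node are capped. The sketch below is hypothetical (plain Python, COO lists, invented function names) and samples inbound edges uniformly without replacement; the distributed operators instead gather edges from the partitions that own them.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Edges = List[Tuple[int, int, int]]  # (edge_id, src, dst)


def inbound_edges(src: Sequence[int], dst: Sequence[int]) -> Dict[int, Edges]:
    """Group edges by destination node."""
    by_dst: Dict[int, Edges] = defaultdict(list)
    for eid, (u, v) in enumerate(zip(src, dst)):
        by_dst[v].append((eid, u, v))
    return by_dst


def in_subgraph_local(src, dst, nodes) -> Edges:
    """Keep every edge whose destination is one of `nodes`."""
    by_dst = inbound_edges(src, dst)
    return [e for n in nodes for e in by_dst[n]]


def sample_neighbors_local(src, dst, nodes, fanout: int, seed: int = 0) -> Edges:
    """Keep at most `fanout` inbound edges per node, chosen uniformly at random."""
    rng = random.Random(seed)
    by_dst = inbound_edges(src, dst)
    out: Edges = []
    for n in nodes:
        cands = by_dst[n]
        out += cands if len(cands) <= fanout else rng.sample(cands, fanout)
    return out


src = [0, 2, 3, 4, 1]
dst = [1, 1, 1, 1, 2]
print(in_subgraph_local(src, dst, [1]))  # all 4 edges pointing to node 1
print(len(sample_neighbors_local(src, dst, [1, 2], fanout=2)))  # 3: 2 for node 1, 1 for node 2
```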

Partition

Graph partition book

class dgl.distributed.GraphPartitionBook

The base class of the graph partition book.

For distributed training, a graph is partitioned into multiple parts and is loaded on multiple machines. The partition book contains all necessary information to locate nodes and edges in the cluster.

The partition book contains various partition information, including

  • the number of partitions,

  • the partition ID that a node or edge belongs to,

  • the node IDs and the edge IDs that a partition has, and

  • the local IDs of nodes and edges in a partition.

Currently, only one class implements GraphPartitionBook: RangePartitionBook. Because nodes and edges have been relabeled so that the IDs in the same partition fall in a contiguous range, it can compute the mapping between node/edge IDs and partition IDs from a small amount of metadata.

A graph partition book is constructed automatically when a graph is partitioned. When a graph partition is loaded, a graph partition book is loaded as well. Please see partition_graph(), load_partition() and load_partition_book() for more details.
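Because each partition owns a contiguous ID range, mapping a global ID to its partition and local ID needs only the range boundaries. The sketch below illustrates the idea for a single node type; the class `ToyRangeBook` and its boundaries are hypothetical, and this is a simplification rather than RangePartitionBook's code.

```python
import bisect
from typing import List


class ToyRangeBook:
    """Hypothetical sketch: partition p owns global IDs [bounds[p], bounds[p+1])."""

    def __init__(self, bounds: List[int]) -> None:
        self.bounds = bounds  # e.g. [0, 3000, 5000] for two partitions

    def num_partitions(self) -> int:
        return len(self.bounds) - 1

    def nid2partid(self, nid: int) -> int:
        # The owning partition starts at the last boundary <= nid.
        return bisect.bisect_right(self.bounds, nid) - 1

    def nid2localnid(self, nid: int, partid: int) -> int:
        return nid - self.bounds[partid]

    def partid2nids(self, partid: int) -> range:
        return range(self.bounds[partid], self.bounds[partid + 1])


book = ToyRangeBook([0, 3000, 5000])
print(book.nid2partid(2999), book.nid2partid(3000))  # 0 1
print(book.nid2localnid(3500, 1))                    # 500
```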

property canonical_etypes

Get the list of canonical edge types

Returns:

A list of canonical etypes

Return type:

list[(str, str, str)]

eid2localeid(eids, partid, etype)

Get the local edge IDs within the given partition.

Parameters:
  • eids (tensor) – global edge IDs

  • partid (int) – partition ID

  • etype (str or (str, str, str)) – The edge type

Returns:

local edge IDs

Return type:

tensor

eid2partid(eids, etype)

From global edge IDs to partition IDs

Parameters:
  • eids (tensor) – global edge IDs

  • etype (str or (str, str, str)) – The edge type

Returns:

partition IDs

Return type:

tensor

map_to_homo_eid(ids, etype)

Map type-wise edge IDs and type IDs to homogeneous edge IDs.

Parameters:
  • ids (tensor) – Type-wise edge IDs

  • etype (str or (str, str, str)) – The edge type

Returns:

Homogeneous edge IDs.

Return type:

Tensor

map_to_homo_nid(ids, ntype)

Map type-wise node IDs and type IDs to homogeneous node IDs.

Parameters:
  • ids (tensor) – Type-wise node IDs

  • ntype (str) – node type

Returns:

Homogeneous node IDs.

Return type:

Tensor

map_to_per_etype(ids)

Map homogeneous edge IDs to type-wise IDs and edge types.

Parameters:

ids (tensor) – Homogeneous edge IDs.

Returns:

edge type IDs and type-wise edge IDs.

Return type:

(tensor, tensor)

map_to_per_ntype(ids)

Map homogeneous node IDs to type-wise IDs and node types.

Parameters:

ids (tensor) – Homogeneous node IDs.

Returns:

node type IDs and type-wise node IDs.

Return type:

(tensor, tensor)
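In the simplest layout, where the homogeneous ID space lists all nodes of type 0, then all nodes of type 1, and so on, map_to_homo_nid and map_to_per_ntype reduce to offset arithmetic. The sketch below assumes exactly that layout, with a hypothetical class `ToyTypeMap`; in a partitioned graph the IDs of each type are typically split into one range per partition, so the real partition book keeps more range metadata.

```python
import bisect
from itertools import accumulate
from typing import List, Tuple


class ToyTypeMap:
    """Hypothetical sketch: type k owns homogeneous IDs [offsets[k], offsets[k+1])."""

    def __init__(self, ntypes: List[str], counts: List[int]) -> None:
        self.ntypes = ntypes
        self.offsets = [0] + list(accumulate(counts))

    def map_to_homo_nid(self, ids: List[int], ntype: str) -> List[int]:
        base = self.offsets[self.ntypes.index(ntype)]
        return [base + i for i in ids]

    def map_to_per_ntype(self, ids: List[int]) -> Tuple[List[int], List[int]]:
        # Returns (node type IDs, type-wise node IDs), like the method above.
        type_ids = [bisect.bisect_right(self.offsets, i) - 1 for i in ids]
        per_type = [i - self.offsets[t] for i, t in zip(ids, type_ids)]
        return type_ids, per_type


m = ToyTypeMap(["user", "game"], [4, 3])  # users: 0-3, games: 4-6
print(m.map_to_homo_nid([0, 2], "game"))  # [4, 6]
print(m.map_to_per_ntype([1, 5]))         # ([0, 1], [1, 1])
```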

metadata()

Return the partition metadata.

The metadata includes:

  • The machine ID.

  • Number of nodes and edges of each partition.

Examples

>>> print(g.get_partition_book().metadata())
[{'machine_id': 0, 'num_nodes': 3000, 'num_edges': 5000},
 {'machine_id': 1, 'num_nodes': 2000, 'num_edges': 4888},
 ...]

Returns:

Metadata of each partition.

Return type:

list[dict[str, any]]

nid2localnid(nids, partid, ntype)

Get local node IDs within the given partition.

Parameters:
  • nids (tensor) – global node IDs

  • partid (int) – partition ID

  • ntype (str) – The node type

Returns:

local node IDs

Return type:

tensor

nid2partid(nids, ntype)

From global node IDs to partition IDs

Parameters:
  • nids (tensor) – global node IDs

  • ntype (str) – The node type

Returns:

partition IDs

Return type:

tensor

num_partitions()

Return the number of partitions.

Returns:

number of partitions

Return type:

int

property partid

Get the current partition ID

Returns:

The partition ID of current machine

Return type:

int

partid2eids(partid, etype)

From partition ID to global edge IDs

Parameters:
  • partid (int) – partition ID

  • etype (str or (str, str, str)) – The edge type

Returns:

edge IDs

Return type:

tensor

partid2nids(partid, ntype)

From partition ID to global node IDs

Parameters:
  • partid (int) – partition ID

  • ntype (str) – The node type

Returns:

node IDs

Return type:

tensor

shared_memory(graph_name)

Move the partition book to shared memory.

Parameters:

graph_name (str) – The graph name. This name will be used to read the partition book from shared memory in another process.

class dgl.distributed.PartitionPolicy(policy_str, partition_book)

This defines a partition policy for a distributed tensor or distributed embedding.

When DGL shards tensors and stores them in a cluster of machines, it requires partition policies that map rows of the tensors to machines in the cluster.

Although an arbitrary partition policy can be defined, DGL currently supports two partition policies for mapping nodes and edges to machines. To define a partition policy from a graph partition book, users need to specify the policy name, which starts with ‘node’ or ‘edge’ followed by the node type or canonical edge type (see policy_str below).

Parameters:
  • policy_str (str) – Partition policy name, e.g., ‘edge~_N:_E:_N’ or ‘node~_N’.

  • partition_book (GraphPartitionBook) – A graph partition book
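Judging from the two policy_str examples above, a policy name has the form 'node~<ntype>' or 'edge~<src_type>:<etype>:<dst_type>'. A hypothetical parser for that pattern is sketched below; the format is inferred from those examples rather than from a specification.

```python
from typing import Tuple, Union


def parse_policy_str(policy_str: str) -> Tuple[str, Union[str, Tuple[str, str, str]]]:
    """Hypothetical sketch: split a policy name into its kind and its type."""
    kind, _, type_str = policy_str.partition("~")
    if kind == "node":
        return kind, type_str
    if kind == "edge":
        src, etype, dst = type_str.split(":")  # canonical edge type triplet
        return kind, (src, etype, dst)
    raise ValueError(f"unknown policy kind in {policy_str!r}")


print(parse_policy_str("node~_N"))        # ('node', '_N')
print(parse_policy_str("edge~_N:_E:_N"))  # ('edge', ('_N', '_E', '_N'))
```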

get_part_size()

Get the data size of the current partition.

Returns:

data size

Return type:

int

get_size()

Get the full size of the data.

Returns:

data size

Return type:

int

property part_id

Get partition ID

Returns:

The partition ID

Return type:

int

property partition_book

Get partition book

Returns:

The graph partition book

Return type:

GraphPartitionBook

property policy_str

Get the policy name

Returns:

The name of the partition policy.

Return type:

str

to_local(id_tensor)

Map global IDs to local IDs.

Parameters:

id_tensor (tensor) – Global ID tensor

Returns:

local ID tensor

Return type:

tensor

to_partid(id_tensor)

Map global IDs to partition IDs.

Parameters:

id_tensor (tensor) – Global ID tensor

Returns:

partition ID

Return type:

tensor

Split and Load Partitions

load_partition(part_config, part_id[, ...])

Load data of a partition from the data path.

load_partition_feats(part_config, part_id[, ...])

Load node/edge feature data from a partition.

load_partition_book(part_config, part_id)

Load a graph partition book from the partition config file.

partition_graph(g, graph_name, num_parts, ...)

Partition a graph for distributed training and store the partitions in files.