dgl.unbatch
dgl.unbatch(g, node_split=None, edge_split=None)
Revert the batch operation by splitting the given graph into a list of small ones.
This is the reverse operation of dgl.batch(). If node_split or edge_split is not given, the function calls DGLGraph.batch_num_nodes() and DGLGraph.batch_num_edges() on the input graph to get the information.

If the node_split or edge_split arguments are given, the function partitions the graph according to the given segments. The caller must ensure that the partition is valid: edges of the i-th graph may only connect nodes that belong to the i-th graph. Otherwise, DGL will throw an error.

The function supports heterograph input, in which case the two split arguments must be dictionaries, similar to the DGLGraph.batch_num_nodes() and DGLGraph.batch_num_edges() attributes of a heterograph.

- Parameters
g: The batched graph to split.
node_split (optional): Number of nodes of each resulting graph; a dictionary for a heterograph.
edge_split (optional): Number of edges of each resulting graph; a dictionary for a heterograph.
- Returns
Unbatched list of graphs.
- Return type
list of DGLGraph
Examples
Unbatch a batched graph
>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> # add features
>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> f1, f2 = dgl.unbatch(bg)
>>> f1
Graph(num_nodes=4, num_edges=3,
      ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)})
>>> f2
Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)})
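To see what unbatching undoes, it can help to trace where a node of the batched graph ends up. The following is a minimal pure-Python sketch, not DGL code: it assumes that batching assigns each input graph a contiguous block of node IDs in input order, so the per-graph node counts (as reported by batch_num_nodes()) are enough to recover the graph index and local node ID. The helper name locate is hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate


def locate(batched_id, counts):
    """Map a node ID in the batched graph to (graph index, local node ID).

    Assumes graph i owns a contiguous ID range whose length is counts[i].
    """
    ends = list(accumulate(counts))  # exclusive end of each graph's ID range
    if not ends or not 0 <= batched_id < ends[-1]:
        raise IndexError(f"node ID {batched_id} is outside the batched graph")
    graph_idx = bisect_right(ends, batched_id)
    start = ends[graph_idx - 1] if graph_idx > 0 else 0
    return graph_idx, batched_id - start


# Node counts of g1 and g2 from the example above
counts = [4, 3]
located = [locate(i, counts) for i in (0, 3, 4, 6)]
print(located)
```

With counts of 4 and 3, batched node 4 is the first node of the second graph, which is the relabeling that unbatching reverses.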
With provided split arguments:
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> g3 = dgl.graph((th.tensor([0]), th.tensor([1])))
>>> bg = dgl.batch([g1, g2, g3])
>>> bg.batch_num_nodes()
tensor([4, 3, 2])
>>> bg.batch_num_edges()
tensor([3, 4, 1])
>>> # unbatch but merge g2 and g3
>>> f1, f2 = dgl.unbatch(bg, th.tensor([4, 5]), th.tensor([3, 5]))
>>> f1
Graph(num_nodes=4, num_edges=3,
      ndata_schemes={}
      edata_schemes={})
>>> f2
Graph(num_nodes=5, num_edges=5,
      ndata_schemes={}
      edata_schemes={})
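The validity requirement on custom splits (edges of the i-th graph may only connect nodes of the i-th graph) can be checked before calling dgl.unbatch. The sketch below is pure Python and does not reproduce DGL's own check; it assumes nodes and edges are laid out in contiguous segments in the order given by the split arguments. The function name split_is_valid is hypothetical, and the edge list is the batched graph from the example above written out by hand.

```python
from itertools import accumulate


def split_is_valid(src, dst, node_split, edge_split):
    """Return True if every edge in segment i connects nodes of segment i."""
    if len(src) != len(dst) or len(node_split) != len(edge_split):
        return False
    if sum(edge_split) != len(src):
        return False
    node_start = edge_start = 0
    for node_end, edge_end in zip(accumulate(node_split), accumulate(edge_split)):
        for e in range(edge_start, edge_end):
            # Both endpoints must fall inside the current node segment.
            if not (node_start <= src[e] < node_end):
                return False
            if not (node_start <= dst[e] < node_end):
                return False
        node_start, edge_start = node_end, edge_end
    return True


# Edges of the batch of g1, g2, g3 (node offsets 0, 4, 7)
src = [0, 1, 2, 4, 4, 4, 5, 7]
dst = [1, 2, 3, 4, 5, 6, 4, 8]

valid = split_is_valid(src, dst, [4, 5], [3, 5])    # merge g2 and g3
invalid = split_is_valid(src, dst, [5, 4], [3, 5])  # cuts through g2
print(valid, invalid)
```

The second split fails because node 4 would land in the first segment while its edges sit in the second, which is the situation in which the text says DGL throws an error.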
Heterograph input
>>> hg1 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> f1, f2 = dgl.unbatch(bhg)
>>> f1
Graph(num_nodes={'user': 2, 'game': 1},
      num_edges={('user', 'plays', 'game'): 2},
      metagraph=[('user', 'game', 'plays')])
>>> f2
Graph(num_nodes={'user': 1, 'game': 3},
      num_edges={('user', 'plays', 'game'): 3},
      metagraph=[('user', 'game', 'plays')])
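For a heterograph, the split arguments are dictionaries, so merging graphs during unbatching means summing counts per type. The following is a small pure-Python sketch of building such dictionaries. The helper merge_counts, the three-graph counts, and the choice of dictionary keys (node type names and canonical edge types) are assumptions for illustration; the resulting lists would still need to be converted to tensors before being passed as node_split and edge_split.

```python
def merge_counts(counts, group_sizes):
    """Sum consecutive per-graph counts so each group becomes one output graph."""
    if sum(group_sizes) != len(counts):
        raise ValueError("group sizes must cover every graph exactly once")
    merged, start = [], 0
    for size in group_sizes:
        merged.append(sum(counts[start:start + size]))
        start += size
    return merged


# Hypothetical per-type counts for a batch of three heterographs
batch_num_nodes = {'user': [2, 1, 3], 'game': [1, 3, 2]}
batch_num_edges = {('user', 'plays', 'game'): [2, 3, 4]}

group_sizes = [1, 2]  # keep the first graph, merge the last two
node_split = {ntype: merge_counts(c, group_sizes)
              for ntype, c in batch_num_nodes.items()}
edge_split = {etype: merge_counts(c, group_sizes)
              for etype, c in batch_num_edges.items()}
print(node_split, edge_split)
```

Every type uses the same grouping, so the i-th entry of each list describes the same output graph.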
See also