# dgl.sort_csr_by_tag

dgl.sort_csr_by_tag(g, tag, tag_offset_name='_TAG_OFFSET')

Return a new graph whose CSR matrix is sorted by the given tag.

Sort the internal CSR matrix of the graph so that the adjacency list of each node, which contains its out-edges, is sorted by the tag of the out-neighbors. After sorting, edges sharing the same tag are arranged in a consecutive range in a node’s adjacency list. The following example illustrates this.

Consider a graph as follows:

0 -> 0, 1, 2, 3, 4
1 -> 0, 1, 2


Given node tags [1, 1, 0, 2, 0], each node’s adjacency list will be sorted as follows:

0 -> 2, 4, 0, 1, 3
1 -> 2, 0, 1
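The reordering above can be reproduced with a small pure-Python sketch. This is only an illustration, not DGL's implementation: the `adj` dictionary and the `sort_adj_by_tag` helper are hypothetical names, and the sketch assumes neighbors with equal tags keep their original relative order, which is consistent with the example.

```python
# Hypothetical illustration of the sorting rule; not DGL's actual code.
adj = {0: [0, 1, 2, 3, 4], 1: [0, 1, 2]}  # out-neighbors of each node
tag = [1, 1, 0, 2, 0]                      # tag[v] is the tag of node v

def sort_adj_by_tag(adj, tag):
    # sorted() is stable, so neighbors with the same tag keep their
    # original relative order inside each tag segment.
    return {u: sorted(nbrs, key=lambda v: tag[v]) for u, nbrs in adj.items()}

sorted_adj = sort_adj_by_tag(adj, tag)
for u, nbrs in sorted_adj.items():
    print(u, "->", nbrs)
```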


The function also returns the starting offsets of the tag segments in a tensor of shape $$(N, max\_tag+2)$$. For node i, the out-edges connecting to nodes with tag j are stored between tag_offsets[i][j] and tag_offsets[i][j+1]. Since the offsets can be viewed as node data, they are stored in the ndata of the returned graph. Users can specify the ndata name with the tag_offset_name argument.
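As a rough sketch of how such offsets relate to the sorted adjacency lists, the pure-Python snippet below counts neighbors per tag and takes a running sum. The helper name `compute_tag_offsets` is hypothetical, and the snippet assumes each segment is the half-open range from tag_offsets[i][j] up to, but not including, tag_offsets[i][j+1].

```python
# Hypothetical illustration of the offset layout; not DGL's actual code.
sorted_adj = {0: [2, 4, 0, 1, 3], 1: [2, 0, 1]}  # adjacency lists sorted by tag
tag = [1, 1, 0, 2, 0]
num_nodes = 5

def compute_tag_offsets(sorted_adj, tag, num_nodes):
    num_tags = max(tag) + 1
    offsets = []
    for u in range(num_nodes):
        counts = [0] * num_tags
        for v in sorted_adj.get(u, []):
            counts[tag[v]] += 1
        # Running sum of the counts gives max_tag + 2 entries per node.
        row = [0]
        for c in counts:
            row.append(row[-1] + c)
        offsets.append(row)
    return offsets

tag_offsets = compute_tag_offsets(sorted_adj, tag, num_nodes)

# Out-neighbors of node 0 that carry tag 1 (assumed half-open range).
start, end = tag_offsets[0][1], tag_offsets[0][2]
neighbors_with_tag_1 = sorted_adj[0][start:end]
```

Nodes without out-edges get an all-zero row, so every tag segment for them is empty.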

Note that the function changes neither the edge IDs nor how the edge features are stored. The input graph must allow the CSR format and must be on CPU.
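One way to picture why edge IDs survive the sort is to assume the CSR keeps an edge-ID array alongside the column indices and that both are permuted together within each row. The sketch below is pure Python under that assumption; `sort_csr_rows` and the array names are hypothetical and do not reflect DGL's internals.

```python
# Hypothetical CSR sketch: indptr, column indices, and per-entry edge IDs.
indptr = [0, 5, 8, 8, 8, 8]           # 5 nodes; only nodes 0 and 1 have out-edges
indices = [0, 1, 2, 3, 4, 0, 1, 2]    # destination of each CSR entry
edge_ids = list(range(8))             # edge ID of each CSR entry
tag = [1, 1, 0, 2, 0]

def sort_csr_rows(indptr, indices, edge_ids, tag):
    new_indices, new_edge_ids = [], []
    for u in range(len(indptr) - 1):
        lo, hi = indptr[u], indptr[u + 1]
        # Stable sort of the positions in this row by the tag of the neighbor.
        order = sorted(range(lo, hi), key=lambda p: tag[indices[p]])
        new_indices.extend(indices[p] for p in order)
        new_edge_ids.extend(edge_ids[p] for p in order)
    return new_indices, new_edge_ids

new_indices, new_edge_ids = sort_csr_rows(indptr, indices, edge_ids, tag)

# Each edge ID still maps to the same destination as before the sort.
dst_before = {e: v for e, v in zip(edge_ids, indices)}
dst_after = {e: v for e, v in zip(new_edge_ids, new_indices)}
```

Because the edge IDs travel with their entries, features indexed by edge ID need no reordering.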

If the input graph is heterogeneous, it must have only one edge type and two node types (i.e., source and destination node types). In this case, the provided node tags are for the destination nodes, and the tag offsets are stored in the source node data.

The sorted graph and the calculated tag offsets are needed by certain operators that consider node tags. See sample_neighbors_biased() for an example.

Parameters
• g (DGLGraph) – The input graph.

• tag (Tensor) – Integer tensor of shape $$(N,)$$, $$N$$ being the number of (destination) nodes.

• tag_offset_name (str) – The name of the node feature to store tag offsets.

Returns

g_sorted – A new graph whose CSR is sorted. The node/edge features of the input graph are shallow-copied over.

• g_sorted.ndata[tag_offset_name] : Tensor of shape $$(N, max\_tag + 2)$$.

• If g is heterogeneous, the tag offsets are stored in g_sorted.srcdata[tag_offset_name] instead.

Return type

DGLGraph

Examples

>>> g = dgl.graph(([0,0,0,0,0,1,1,1],[0,1,2,3,4,0,1,2]))
>>> g.adjacency_matrix(scipy_fmt='csr').nonzero()
(array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32),
array([0, 1, 2, 3, 4, 0, 1, 2], dtype=int32))
>>> tag = torch.IntTensor([1,1,0,2,0])
>>> g_sorted = dgl.sort_csr_by_tag(g, tag)