# Source code for dgl.init

"""Module for common feature initializers."""
from __future__ import absolute_import

from . import backend as F

# Public initializers exported by ``from dgl.init import *``.
__all__ = [
    "base_initializer",
    "zero_initializer",
]

def base_initializer(shape, dtype, ctx, id_range):  # pylint: disable=unused-argument
    """The function signature for feature initializer.

    Any customized feature initializer should follow this signature (see
    example below). This function itself only documents the signature and
    always raises.

    Parameters
    ----------
    shape : tuple of int
        The shape of the result features. The first dimension is the
        batch dimension.
    dtype : data type object
        The data type of the returned features.
    ctx : context object
        The device context of the returned features.
    id_range : slice
        The start id and the end id of the features to be initialized.
        The id could be node or edge id depending on the scenario.
        Note that the step is always None.

    Raises
    ------
    NotImplementedError
        Always. ``base_initializer`` is a signature template, not a usable
        initializer.

    Examples
    --------
    If PyTorch is used as backend, the following code defines a feature
    initializer that initializes tensor values to 1.

    >>> import torch
    >>> import dgl
    >>> def initializer(shape, dtype, ctx, id_range):
    ...     return torch.ones(shape, dtype=dtype, device=ctx)
    >>> g = dgl.DGLGraph()
    >>> g.set_n_initializer(initializer)

    See Also
    --------
    dgl.DGLGraph.set_n_initializer
    dgl.DGLGraph.set_e_initializer
    """
    # An explicit message tells users who registered this template by
    # mistake what to do instead of surfacing a bare NotImplementedError.
    raise NotImplementedError(
        "base_initializer only defines the initializer signature; "
        "use a concrete initializer such as zero_initializer."
    )


def zero_initializer(shape, dtype, ctx, id_range):  # pylint: disable=unused-argument
    """Zero feature initializer.

    Follows the same signature as :func:`base_initializer` and delegates to
    the active backend's ``zeros`` function.

    Parameters
    ----------
    shape : tuple of int
        The shape of the result features.
    dtype : data type object
        The data type of the returned features.
    ctx : context object
        The device context of the returned features.
    id_range : slice
        The ids of the features being initialized. Unused: every feature
        is set to zero regardless of its id.

    Returns
    -------
    tensor
        The result of ``F.zeros(shape, dtype, ctx)`` from the backend
        module, i.e. a zero-filled features tensor.

    Examples
    --------
    >>> import dgl
    >>> g = dgl.DGLGraph()
    >>> g.set_n_initializer(dgl.init.zero_initializer)

    See Also
    --------
    dgl.DGLGraph.set_n_initializer
    dgl.DGLGraph.set_e_initializer
    """
    return F.zeros(shape, dtype, ctx)