TypedLinear
class dgl.nn.pytorch.TypedLinear(in_size, out_size, num_types, regularizer=None, num_bases=None)
Bases: torch.nn.modules.module.Module
Linear transformation according to types.
For each sample of the input batch \(x \in X\), apply linear transformation \(xW_t\), where \(t\) is the type of \(x\).
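As a minimal sketch of this rule, the loop below picks the weight matrix for each row's type and multiplies. It is written in plain Python (lists of lists, no PyTorch) so the arithmetic stays visible. The helper name typed_linear is hypothetical, and this is not DGL's implementation.

```python
def typed_linear(x, x_type, weights):
    """Return y with y[i] = x[i] @ weights[x_type[i]].

    x: N rows, each a list of in_size floats
    x_type: N ints, the type t of each row
    weights: num_types matrices, each in_size x out_size (lists of lists)
    """
    y = []
    for row, t in zip(x, x_type):
        W = weights[t]  # weight matrix selected by this row's type
        out_size = len(W[0])
        # Row vector times matrix: y_j = sum_k row[k] * W[k][j]
        y.append([sum(row[k] * W[k][j] for k in range(len(row)))
                  for j in range(out_size)])
    return y


# Two types, in_size = out_size = 2
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # type 0: identity
    [[2.0, 0.0], [0.0, 3.0]],  # type 1: scale features by 2 and 3
]
x = [[1.0, 2.0], [1.0, 2.0]]
x_type = [0, 1]
y = typed_linear(x, x_type, weights)
print(y)  # [[1.0, 2.0], [2.0, 6.0]]
```

The two identical input rows produce different outputs only because their types differ.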
The module supports two regularization methods (basis-decomposition and block-diagonal-decomposition) proposed in “Modeling Relational Data with Graph Convolutional Networks”.
The basis regularization decomposes \(W_t\) by:
\[W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}\]where \(B\) is the number of bases and the basis matrices \(V_b^{(l)}\), shared by all types, are linearly combined with type-specific coefficients \(a_{tb}^{(l)}\).
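The basis decomposition can be sketched the same way: every type shares the \(B\) matrices \(V_b\) and keeps only its own \(B\) coefficients, and \(W_t\) is rebuilt as their weighted sum. The helper basis_weight is a hypothetical illustration in plain Python that drops the layer superscript \((l)\).

```python
def basis_weight(coeffs, bases, t):
    """Rebuild W_t = sum_b coeffs[t][b] * bases[b], entry by entry."""
    rows, cols = len(bases[0]), len(bases[0][0])
    return [[sum(coeffs[t][b] * bases[b][i][j] for b in range(len(bases)))
             for j in range(cols)]
            for i in range(rows)]


# B = 2 shared 2x2 bases V_b
V = [
    [[1.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.0], [0.0, 1.0]],
]
# num_types = 3; each type stores only B = 2 coefficients a_tb
a = [
    [1.0, 1.0],
    [2.0, 0.0],
    [0.5, -1.0],
]
W0 = basis_weight(a, V, 0)
W2 = basis_weight(a, V, 2)
print(W0)  # [[1.0, 0.0], [0.0, 1.0]]
print(W2)  # [[0.5, 0.0], [0.0, -1.0]]
```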
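The block-diagonal-decomposition regularization builds each \(W_t\) as a block-diagonal matrix composed of \(B\) blocks. We refer to \(B\) as the number of bases:
\[W_t^{(l)} = \bigoplus_{b=1}^B Q_{tb}^{(l)}\]where \(\bigoplus\) places the blocks along the diagonal (direct sum), \(B\) is the number of bases, and \(Q_{tb}^{(l)} \in \mathbb{R}^{(d^{(l)}/B)\times(d^{(l+1)}/B)}\) are the block bases, with \(d^{(l)}\) and \(d^{(l+1)}\) the input and output feature sizes. Both sizes must therefore be divisible by \(B\).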
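For the block-diagonal case, the sketch below assembles \(W_t\) from its blocks and checks that multiplying by the full matrix equals multiplying each input chunk by its own block, which is why the zero off-diagonal entries never need to be stored. It assumes row-vector inputs, as in \(xW_t\), so each block is (in_size/B) x (out_size/B). The helpers block_diag, row_times and bdd_apply are hypothetical.

```python
def block_diag(blocks):
    """Direct sum of B blocks, each p x q, into a (B*p) x (B*q) matrix."""
    p, q = len(blocks[0]), len(blocks[0][0])
    B = len(blocks)
    W = [[0.0] * (B * q) for _ in range(B * p)]
    for b, Q in enumerate(blocks):
        for i in range(p):
            for j in range(q):
                W[b * p + i][b * q + j] = Q[i][j]  # place block b on the diagonal
    return W


def row_times(x, W):
    """Row vector x times matrix W."""
    return [sum(x[k] * W[k][j] for k in range(len(x))) for j in range(len(W[0]))]


def bdd_apply(x, blocks):
    """Same result without building W: split x into B chunks, one per block."""
    p = len(blocks[0])
    out = []
    for b, Q in enumerate(blocks):
        out.extend(row_times(x[b * p:(b + 1) * p], Q))
    return out


# in_size = out_size = 4, B = 2, so each block is 2 x 2
blocks = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 1.0], [1.0, 0.0]],
]
x = [1.0, 1.0, 5.0, 7.0]
W = block_diag(blocks)
y_full = row_times(x, W)
y_blocks = bdd_apply(x, blocks)
print(y_full)    # [4.0, 6.0, 7.0, 5.0]
print(y_blocks)  # [4.0, 6.0, 7.0, 5.0]
```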
- Parameters
in_size (int) – Input feature size.
out_size (int) – Output feature size.
num_types (int) – Total number of types.
regularizer (str, optional) –
Which weight regularizer to use, “basis” or “bdd”:
“basis” is short for basis-decomposition.
“bdd” is short for block-diagonal-decomposition.
Default: None, which applies no regularization.
num_bases (int, optional) – Number of bases. Needed when regularizer is specified. Typically smaller than num_types. Default: None.
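To see why num_bases is typically smaller than num_types, one can count the weight entries each setting implies under the formulas above. The function num_params is a hypothetical helper: it counts only the \(W_t\) weights (ignoring any bias), and its error checks mirror the documented requirement that num_bases accompany a regularizer plus the divisibility implied by the block shapes, not DGL's actual error handling.

```python
def num_params(in_size, out_size, num_types, regularizer=None, num_bases=None):
    """Count the entries of the W_t weights implied by each scheme (no bias)."""
    if regularizer is None:
        # One full in_size x out_size matrix per type
        return num_types * in_size * out_size
    if num_bases is None:
        raise ValueError("num_bases is needed when regularizer is specified")
    if regularizer == "basis":
        # B shared bases V_b plus num_types * B coefficients a_tb
        return num_bases * in_size * out_size + num_types * num_bases
    if regularizer == "bdd":
        if in_size % num_bases or out_size % num_bases:
            raise ValueError("in_size and out_size must be divisible by num_bases")
        # num_types * B blocks Q_tb, each (in_size/B) x (out_size/B)
        return num_types * num_bases * (in_size // num_bases) * (out_size // num_bases)
    raise ValueError(f"unknown regularizer: {regularizer!r}")


# Sizes from the Examples section (5 types) and a case with many types (100)
for types in (5, 100):
    print(types,
          num_params(32, 64, types),
          num_params(32, 64, types, "basis", 4),
          num_params(32, 64, types, "bdd", 4))
# 5 10240 8212 2560
# 100 204800 8592 51200
```

With only 5 types the basis scheme saves little, while with 100 types its cost barely grows, since only the coefficients scale with num_types.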
Examples
No regularization.
>>> from dgl.nn import TypedLinear
>>> import torch
>>>
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])
With basis regularization.
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])
forward(x, x_type, sorted_by_type=False)
Forward computation.
- Parameters
x (torch.Tensor) – A 2D input tensor. Shape: (N, D1), where D1 is in_size.
x_type (torch.Tensor) – A 1D integer tensor storing the type of each row of x, in one-to-one correspondence. Shape: (N,)
sorted_by_type (bool, optional) – Whether the inputs have already been sorted by type. Running forward on pre-sorted inputs may be faster. Default: False.
- Returns
y – The transformed output tensor. Shape: (N, D2), where D2 is out_size.
- Return type
torch.Tensor
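One way to see why pre-sorted input may help: when rows are grouped by type, each type's weight matrix is applied to one contiguous segment instead of being looked up row by row, and in a tensor library each segment can become a single dense matrix multiplication. The sketch below compares a per-row loop with a sort-then-segment loop in plain Python. The helper names are hypothetical, and this illustrates the idea rather than DGL's kernel.

```python
def row_times(row, W):
    """Row vector times matrix."""
    return [sum(row[k] * W[k][j] for k in range(len(row))) for j in range(len(W[0]))]


def typed_linear_per_row(x, x_type, weights):
    # Look up the weight matrix separately for every row
    return [row_times(row, weights[t]) for row, t in zip(x, x_type)]


def typed_linear_segmented(x, x_type, weights):
    # Stable ordering of row indices by type, so equal types become contiguous
    perm = sorted(range(len(x)), key=lambda i: x_type[i])
    y = [None] * len(x)
    start = 0
    while start < len(perm):
        t = x_type[perm[start]]
        end = start
        while end < len(perm) and x_type[perm[end]] == t:
            end += 1
        W = weights[t]  # fetched once for the whole segment
        for i in perm[start:end]:
            y[i] = row_times(x[i], W)  # write back to the original position
        start = end
    return y


weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # type 0: identity
    [[0.0, 1.0], [1.0, 0.0]],  # type 1: swap the two features
]
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
x_type = [1, 0, 1, 0]
y_seg = typed_linear_segmented(x, x_type, weights)
print(y_seg)  # [[2.0, 1.0], [3.0, 4.0], [6.0, 5.0], [7.0, 8.0]]
```

If the caller already guarantees that rows are grouped by type, the sorting step becomes unnecessary, which is presumably part of what sorted_by_type lets the module skip.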