#### Note

**Definition**: :func:`~dgl.batch` unions a list of $B$\n :class:`~dgl.DGLGraph`\\ s and returns a :class:`~dgl.DGLGraph` of batch \n size $B$. \n\n - The union includes all the nodes,\n edges, and their features. The order of nodes, edges, and features is\n preserved. \n\n - If graph $\\mathcal{G}_i$ has $V_i$ nodes, then node ID $j$ in graph\n $\\mathcal{G}_i$ corresponds to node ID\n $j + \\sum_{k=1}^{i-1} V_k$ in the batched graph. \n\n - Therefore, performing feature transformation and message passing on\n the batched graph is equivalent to performing them\n on all ``DGLGraph`` constituents in parallel. \n\n - Duplicate references to the same graph are\n treated as deep copies; the nodes, edges, and features are duplicated,\n and mutating one reference does not affect the other. \n\n - The batched graph keeps track of the meta\n information of the constituents so it can be\n :func:`~dgl.batched_graph.unbatch`\\ ed into a list of ``DGLGraph``\\ s.
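
The node ID offset rule can be checked without DGL. The following is a minimal plain-Python sketch, not DGL code: the helper name `batched_node_id` is hypothetical, and it indexes graphs from 0, whereas the formula above indexes them from 1.

```python
from itertools import accumulate

def batched_node_id(num_nodes_per_graph, i, j):
    # offsets[i] is the total number of nodes in all graphs before graph i (0-based).
    offsets = [0] + list(accumulate(num_nodes_per_graph))
    if not 0 <= j < num_nodes_per_graph[i]:
        raise IndexError('node ID out of range for this graph')
    return offsets[i] + j

sizes = [3, 5, 2]                      # V_1, V_2, V_3 for three small graphs
print(batched_node_id(sizes, 0, 2))    # last node of the first graph
print(batched_node_id(sizes, 2, 1))    # last node of the third graph
```

With these sizes, the last node of the third graph gets ID 3 + 5 + 1 = 9 in the batched graph.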

\n\nStep 2: Tree-LSTM cell with message-passing APIs\n------------------------------------------------\n\nResearchers have proposed two types of Tree-LSTMs: Child-Sum\nTree-LSTMs and $N$-ary Tree-LSTMs. In this tutorial you focus \non applying *Binary* Tree-LSTM to binarized constituency trees. This \napplication is also known as *Constituency Tree-LSTM*. Use PyTorch \nas a backend framework to set up the network.\n\nIn $N$-ary Tree-LSTM, each unit at node $j$ maintains a hidden\nrepresentation $h_j$ and a memory cell $c_j$. The unit\n$j$ takes the input vector $x_j$ and the hidden\nrepresentations of the child units: $h_{jl}, 1\\leq l\\leq N$ as\ninput, then updates its hidden representation $h_j$ and memory\ncell $c_j$ by: \n\n\\begin{align}i_j & = & \\sigma\\left(W^{(i)}x_j + \\sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\\right), & (1)\\\\\n f_{jk} & = & \\sigma\\left(W^{(f)}x_j + \\sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \\right), & (2)\\\\\n o_j & = & \\sigma\\left(W^{(o)}x_j + \\sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \\right), & (3) \\\\\n u_j & = & \\textrm{tanh}\\left(W^{(u)}x_j + \\sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \\right), & (4)\\\\\n c_j & = & i_j \\odot u_j + \\sum_{l=1}^{N} f_{jl} \\odot c_{jl}, &(5) \\\\\n h_j & = & o_j \\odot \\textrm{tanh}(c_j), &(6) \\\\\\end{align}\n\nThis update can be decomposed into three phases: ``message_func``,\n``reduce_func`` and ``apply_node_func``.\n\n#### Note

``apply_node_func`` is a new node UDF that has not been introduced before. In\n ``apply_node_func``, a user specifies what to do with node features,\n without considering edge features and messages. In the Tree-LSTM case,\n ``apply_node_func`` is required, because there exist (leaf) nodes with\n $0$ incoming edges, which ``reduce_func`` would never update.
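
To make the leaf case concrete, here is a minimal scalar sketch of equations (1) to (6) in plain Python. It is an illustration only, not the DGL implementation: the function `tree_lstm_unit` and all parameter values are hypothetical, and every quantity is a scalar instead of a vector. For a leaf the sums over children are empty, so the state is still computed from the input $x_j$ alone.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def tree_lstm_unit(x, children, W, U, b):
    # x: scalar input; children: list of (h, c) pairs, empty for a leaf.
    # W[g], b[g]: scalar weight and bias for gate g in 'i', 'f', 'o', 'u'.
    # U[g][l]: weight on child l for gates 'i', 'o', 'u'.
    # U['f'][k][l]: weight on child l inside the forget gate of child k.
    hs = [h for h, _ in children]
    cs = [c for _, c in children]

    def gate_input(g, weights):
        return W[g] * x + sum(w * h for w, h in zip(weights, hs)) + b[g]

    i = sigmoid(gate_input('i', U['i']))                    # equation (1)
    f = [sigmoid(gate_input('f', U['f'][k]))                # equation (2)
         for k in range(len(children))]
    o = sigmoid(gate_input('o', U['o']))                    # equation (3)
    u = math.tanh(gate_input('u', U['u']))                  # equation (4)
    c = i * u + sum(f_k * c_k for f_k, c_k in zip(f, cs))   # equation (5)
    h = o * math.tanh(c)                                    # equation (6)
    return h, c

# Hypothetical parameters for a binary (N = 2) unit.
W = {'i': 0.5, 'f': 0.5, 'o': 0.5, 'u': 0.5}
b = {'i': 0.0, 'f': 0.0, 'o': 0.0, 'u': 0.0}
U = {'i': [0.1, 0.2], 'o': [0.1, 0.2], 'u': [0.1, 0.2],
     'f': [[0.1, 0.2], [0.3, 0.4]]}

left = tree_lstm_unit(1.0, [], W, U, b)     # leaf: no children, no messages
right = tree_lstm_unit(-1.0, [], W, U, b)   # leaf
root = tree_lstm_unit(0.0, [left, right], W, U, b)
print(left, right, root)
```

In the DGL cell that follows, the same split appears as ``reduce_func`` (the child-dependent terms) and ``apply_node_func`` (the gates and state update that every node needs, leaves included).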

\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch as th\nimport torch.nn as nn\n\nclass TreeLSTMCell(nn.Module):\n    def __init__(self, x_size, h_size):\n        super(TreeLSTMCell, self).__init__()\n        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)\n        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)\n        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))\n        self.U_f = nn.Linear(2 * h_size, 2 * h_size)\n\n    def message_func(self, edges):\n        return {'h': edges.src['h'], 'c': edges.src['c']}\n\n    def reduce_func(self, nodes):\n        # concatenate h_jl for equation (1), (2), (3), (4)\n        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)\n        # equation (2)\n        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())\n        # second term of equation (5)\n        c = th.sum(f * nodes.mailbox['c'], 1)\n        return {'iou': self.U_iou(h_cat), 'c': c}\n\n    def apply_node_func(self, nodes):\n        # equation (1), (3), (4)\n        iou = nodes.data['iou'] + self.b_iou\n        i, o, u = th.chunk(iou, 3, 1)\n        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)\n        # equation (5)\n        c = i * u + nodes.data['c']\n        # equation (6)\n        h = o * th.tanh(c)\n        return {'h': h, 'c': c}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 3: Define traversal\n------------------------\n\nAfter you define the message-passing functions, induce the\nright order to trigger them. This is a significant departure from models\nsuch as GCN, where all nodes pull messages from upstream ones\n*simultaneously*.\n\nIn the case of Tree-LSTM, messages start from the leaves of the tree and\nare propagated and processed upward until they reach the roots. A visualization\nis as follows:\n\n.. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif\n :alt:\n\nDGL defines a generator to perform the topological sort. Each item is a\ntensor recording the nodes of one level, from the bottom level to the roots. You can\nappreciate the degree of parallelism by comparing the following\ntwo traversals:\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# to heterogeneous graph\ntrv_a_tree = dgl.graph(a_tree.edges())\nprint('Traversing one tree:')\nprint(dgl.topological_nodes_generator(trv_a_tree))\n\n# to heterogeneous graph\ntrv_graph = dgl.graph(graph.edges())\nprint('Traversing many trees at the same time:')\nprint(dgl.topological_nodes_generator(trv_graph))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import dgl.function as fn\nimport torch as th\n\ntrv_graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)\ntraversal_order = dgl.topological_nodes_generator(trv_graph)\ntrv_graph.prop_nodes(traversal_order,\n                     message_func=fn.copy_src('a', 'a'),\n                     reduce_func=fn.sum('a', 'a'))\n\n# the following is syntactic sugar that does the same\n# dgl.prop_nodes_topo(graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Note

Before you call :meth:`~dgl.DGLGraph.prop_nodes`, specify a\n `message_func` and a `reduce_func`. For demonstration, the example uses the built-in\n copy-from-source function as the message function and the built-in sum function\n as the reduce function.
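
What a frontier-by-frontier propagation with a copy-from-source message and a sum reduce computes can be imitated in plain Python. This is a rough sketch under stated assumptions, not how DGL implements it: the helpers `topological_frontiers` and `propagate_copy_sum` are hypothetical, edges point from child to parent, and nodes without incoming edges keep their initial value.

```python
def topological_frontiers(num_nodes, edges):
    # Yield lists of nodes level by level: leaves first, roots last.
    indegree = [0] * num_nodes
    successors = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        indegree[dst] += 1
        successors[src].append(dst)
    frontier = [v for v in range(num_nodes) if indegree[v] == 0]
    while frontier:
        yield frontier
        next_frontier = []
        for v in frontier:
            for dst in successors[v]:
                indegree[dst] -= 1
                if indegree[dst] == 0:
                    next_frontier.append(dst)
        frontier = next_frontier

def propagate_copy_sum(values, edges):
    # Message: copy the source value. Reduce: sum the incoming messages.
    values = list(values)
    incoming = {}
    for src, dst in edges:
        incoming.setdefault(dst, []).append(src)
    for frontier in topological_frontiers(len(values), edges):
        for v in frontier:
            if v in incoming:  # leaves receive no messages and are left as-is
                values[v] = sum(values[src] for src in incoming[v])
    return values

# A small tree: leaves 0, 1, 2; node 3 is the parent of 0 and 1; node 4 is the root.
edges = [(0, 3), (1, 3), (3, 4), (2, 4)]
print(list(topological_frontiers(5, edges)))
print(propagate_copy_sum([1, 1, 1, 1, 1], edges))
```

Starting from all ones, each node ends up holding the number of leaves in its subtree, and all nodes in one frontier can be processed in parallel.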

\n\nPutting it together\n-------------------\n\nHere is the complete code that specifies the ``Tree-LSTM`` class.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class TreeLSTM(nn.Module):\n    def __init__(self,\n                 num_vocabs,\n                 x_size,\n                 h_size,\n                 num_classes,\n                 dropout,\n                 pretrained_emb=None):\n        super(TreeLSTM, self).__init__()\n        self.x_size = x_size\n        self.embedding = nn.Embedding(num_vocabs, x_size)\n        if pretrained_emb is not None:\n            print('Using glove')\n            self.embedding.weight.data.copy_(pretrained_emb)\n            self.embedding.weight.requires_grad = True\n        self.dropout = nn.Dropout(dropout)\n        self.linear = nn.Linear(h_size, num_classes)\n        self.cell = TreeLSTMCell(x_size, h_size)\n\n    def forward(self, batch, h, c):\n        \"\"\"Compute tree-lstm prediction given a batch.\n\n        Parameters\n        ----------\n        batch : dgl.data.SSTBatch\n            The data batch.\n        h : Tensor\n            Initial hidden state.\n        c : Tensor\n            Initial cell state.\n\n        Returns\n        -------\n        logits : Tensor\n            The prediction of each node.\n        \"\"\"\n        g = batch.graph\n        # to heterogeneous graph\n        g = dgl.graph(g.edges())\n        # feed embedding\n        embeds = self.embedding(batch.wordid * batch.mask)\n        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)\n        g.ndata['h'] = h\n        g.ndata['c'] = c\n        # propagate\n        dgl.prop_nodes_topo(g,\n                            message_func=self.cell.message_func,\n                            reduce_func=self.cell.reduce_func,\n                            apply_node_func=self.cell.apply_node_func)\n        # compute logits\n        h = self.dropout(g.ndata.pop('h'))\n        logits = self.linear(h)\n        return logits"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Main Loop\n---------\n\nFinally, write the training loop in PyTorch.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\ndevice = th.device('cpu')\n# hyper parameters\nx_size = 256\nh_size = 256\ndropout = 0.5\nlr = 0.05\nweight_decay = 1e-4\nepochs = 10\n\n# create the model\nmodel = TreeLSTM(trainset.num_vocabs,\n                 x_size,\n                 h_size,\n                 trainset.num_classes,\n                 dropout)\nprint(model)\n\n# create the optimizer\noptimizer = th.optim.Adagrad(model.parameters(),\n                             lr=lr,\n                             weight_decay=weight_decay)\n\ndef batcher(dev):\n    # build a collate function that moves the batch tensors to `dev`\n    def batcher_dev(batch):\n        batch_trees = dgl.batch(batch)\n        return SSTBatch(graph=batch_trees,\n                        mask=batch_trees.ndata['mask'].to(dev),\n                        wordid=batch_trees.ndata['x'].to(dev),\n                        label=batch_trees.ndata['y'].to(dev))\n    return batcher_dev\n\ntrain_loader = DataLoader(dataset=tiny_sst,\n                          batch_size=5,\n                          collate_fn=batcher(device),\n                          shuffle=False,\n                          num_workers=0)\n\n# training loop\nfor epoch in range(epochs):\n    for step, batch in enumerate(train_loader):\n        g = batch.graph\n        n = g.number_of_nodes()\n        h = th.zeros((n, h_size))\n        c = th.zeros((n, h_size))\n        logits = model(batch, h, c)\n        logp = F.log_softmax(logits, 1)\n        loss = F.nll_loss(logp, batch.label, reduction='sum')\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        pred = th.argmax(logits, 1)\n        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)\n        print(\"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |\".format(\n            epoch, step, loss.item(), acc))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To train the model on a full dataset with different settings (such as CPU or GPU),\nrefer to the `PyTorch example