{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n\nCapsule Network\n===========================\n\n**Author**: Jinjing Zhou, Jake Zhao, Zheng Zhang, Jinyang Li\n\nIn this tutorial, you learn how to describe one of the more classical models in terms of graphs. The approach\noffers a different perspective. The tutorial describes how to implement a Capsule model for the\n`capsule network <https://arxiv.org/abs/1710.09829>`__.\n\n#### Warning\n\nThe tutorial aims at gaining insights into the paper, with code as a means\nof explanation. The implementation is thus NOT optimized for running\nefficiency. For a recommended implementation, please refer to the official\nexamples.\n\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Key ideas of Capsule\n--------------------\n\nThe Capsule model offers two key ideas: richer representation and dynamic routing.\n\n**Richer representation** -- In classic convolutional networks, a scalar\nvalue represents the activation of a given feature. By contrast, a\ncapsule outputs a vector. The vector's length represents the probability\nof a feature being present. The vector's orientation represents the\nvarious properties of the feature (such as pose, deformation, and\ntexture).\n\n|image0|\n\n**Dynamic routing** -- The output of a capsule is sent to\ncertain parents in the layer above based on how well the capsule's\nprediction agrees with that of a parent. Such dynamic\nrouting-by-agreement generalizes the static routing of max-pooling.\n\nDuring training, routing is accomplished iteratively. Each iteration adjusts\nrouting weights between capsules based on their observed agreements,\nin a manner similar to the k-means algorithm or competitive learning.\n\nIn this tutorial, you see how a capsule's dynamic routing algorithm can be\nnaturally expressed as a graph algorithm. The implementation is adapted\nfrom the one by Cedric Chee, replacing only the routing layer. This\nversion achieves similar speed and accuracy.\n\nModel implementation\n----------------------\nStep 1: Setup and graph initialization\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe connectivity between two layers of capsules forms a directed,\nbipartite graph, as shown in the figure below.\n\n|image1|\n\nEach node $j$ is associated with feature $v_j$,\nrepresenting its capsule\u2019s output. Each edge is associated with\nfeatures $b_{ij}$ and $\\hat{u}_{j|i}$. $b_{ij}$\ndetermines routing weights, and $\\hat{u}_{j|i}$ represents the\nprediction of capsule $i$ for $j$.\n\nHere's how we set up the graph and initialize node and edge features.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch.nn as nn\nimport torch as th\nimport torch.nn.functional as F\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport dgl\n\n\ndef init_graph(in_nodes, out_nodes, f_size):\n    # Build a complete bipartite graph: every in-capsule i points to every out-capsule j.\n    u = np.repeat(np.arange(in_nodes), out_nodes)\n    v = np.tile(np.arange(in_nodes, in_nodes + out_nodes), in_nodes)\n    g = dgl.DGLGraph((u, v))\n    # Initialize states: capsule output v on nodes, routing logit b on edges.\n    g.ndata['v'] = th.zeros(in_nodes + out_nodes, f_size)\n    g.edata['b'] = th.zeros(in_nodes * out_nodes, 1)\n    return g" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Step 2: Define message passing functions\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThis is the pseudocode for Capsule's routing algorithm.\n\n|image2|\nImplement pseudocode lines 4-7 in the class DGLRoutingLayer with the following steps (a plain dense-tensor sketch of the same four steps is given right after this list):\n\n1. Calculate coupling coefficients.\n\n   - Coefficients are the softmax over all out-edges of each in-capsule:\n     $\\textbf{c}_{i,j} = \\text{softmax}(\\textbf{b}_{i,j})$.\n\n2. Calculate a weighted sum over all in-capsules.\n\n   - The output of a capsule is the weighted sum of the predictions made by its in-capsules:\n     $s_j=\\sum_i c_{ij}\\hat{u}_{j|i}$\n\n3. Squash outputs.\n\n   - Squash the length of a capsule's output vector into the range (0, 1), so it can represent the probability of some feature being present.\n   - $v_j=\\text{squash}(s_j)=\\frac{||s_j||^2}{1+||s_j||^2}\\frac{s_j}{||s_j||}$\n\n4. Update weights by the amount of agreement.\n\n   - The scalar product $\\hat{u}_{j|i}\\cdot v_j$ measures how well capsule $i$ agrees with $j$. It is used to update\n     $b_{ij}=b_{ij}+\\hat{u}_{j|i}\\cdot v_j$\n\n" ] },
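{ "cell_type": "markdown", "metadata": {}, "source": [ "Before the graph-based version, here is a minimal dense-tensor sketch of the same four steps in plain PyTorch. This cell is an illustrative addition to the tutorial: the helper name routing_once and the (in_nodes, out_nodes, f_size) layout of u_hat are assumptions made for this sketch, not part of the paper or of DGL. The DGLRoutingLayer below computes the same thing, but expresses the weighted sum as message passing on the bipartite graph.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Illustrative sketch (not the tutorial's implementation): one routing iteration\n# on dense tensors, following lines 4-7 of the pseudocode.\ndef routing_once(b, u_hat):\n    # b: (in_nodes, out_nodes) routing logits\n    # u_hat: (in_nodes, out_nodes, f_size) predictions of in-capsules for out-capsules\n    c = F.softmax(b, dim=1)                       # step 1: coupling coefficients\n    s = (c.unsqueeze(-1) * u_hat).sum(dim=0)      # step 2: weighted sum, (out_nodes, f_size)\n    sq = (s ** 2).sum(dim=1, keepdim=True)        # step 3: squash lengths into (0, 1)\n    v = sq / (1.0 + sq) * s / th.sqrt(sq)\n    b = b + (u_hat * v.unsqueeze(0)).sum(dim=-1)  # step 4: update logits by agreement\n    return b, v\n\n\nb_dense = th.zeros(20, 10)\nu_hat_dense = th.randn(20, 10, 4)\nfor _ in range(3):\n    b_dense, v_dense = routing_once(b_dense, u_hat_dense)" ] },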
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import dgl.function as fn\n\nclass DGLRoutingLayer(nn.Module):\n    def __init__(self, in_nodes, out_nodes, f_size):\n        super(DGLRoutingLayer, self).__init__()\n        self.g = init_graph(in_nodes, out_nodes, f_size)\n        self.in_nodes = in_nodes\n        self.out_nodes = out_nodes\n        self.in_indx = list(range(in_nodes))\n        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))\n\n    def forward(self, u_hat, routing_num=1):\n        self.g.edata['u_hat'] = u_hat\n\n        for r in range(routing_num):\n            # step 1 (line 4): softmax of b over the out-edges of each in-capsule\n            edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)\n            self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)\n            self.g.edata['c u_hat'] = self.g.edata['c'] * self.g.edata['u_hat']\n\n            # step 2 (line 5): weighted sum over in-capsules via message passing\n            self.g.update_all(fn.copy_e('c u_hat', 'm'), fn.sum('m', 's'))\n\n            # step 3 (line 6): squash the outputs of the out-capsules\n            self.g.nodes[self.out_indx].data['v'] = self.squash(self.g.nodes[self.out_indx].data['s'], dim=1)\n\n            # step 4 (line 7): update b on each edge by the agreement u_hat . v\n            v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0)\n            self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)\n\n    @staticmethod\n    def squash(s, dim=1):\n        sq = th.sum(s ** 2, dim=dim, keepdim=True)\n        s_norm = th.sqrt(sq)\n        s = (sq / (1.0 + sq)) * (s / s_norm)\n        return s" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Step 3: Testing\n~~~~~~~~~~~~~~~\n\nMake a simple 20x10 capsule layer.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "in_nodes = 20\nout_nodes = 10\nf_size = 4\nu_hat = th.randn(in_nodes * out_nodes, f_size)\nrouting = DGLRoutingLayer(in_nodes, out_nodes, f_size)" ] },
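{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check (an illustrative addition, reusing the layer class and u_hat defined above), run one routing pass on a fresh DGLRoutingLayer and inspect the lengths of the out-capsule vectors. Because of the squash step, every length should lie strictly between 0 and 1.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Illustrative addition: verify that squashed output lengths lie in (0, 1).\n# A fresh layer is used so the tutorial's `routing` object is left untouched.\ncheck = DGLRoutingLayer(in_nodes, out_nodes, f_size)\ncheck(u_hat)\nv_out = check.g.nodes[check.out_indx].data['v']\nlengths = v_out.norm(dim=1)\nprint(lengths.min().item(), lengths.max().item())" ] },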
{ "cell_type": "markdown", "metadata": {}, "source": [ "You can visualize a Capsule network's behavior by monitoring the entropy\nof the coupling coefficients. It should start high and then drop, as the\nweights gradually concentrate on fewer edges.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "entropy_list = []\ndist_list = []\n\nfor i in range(10):\n    routing(u_hat)\n    dist_matrix = routing.g.edata['c'].view(in_nodes, out_nodes)\n    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)\n    entropy_list.append(entropy.data.numpy())\n    dist_list.append(dist_matrix.data.numpy())\n\nstds = np.std(entropy_list, axis=1)\nmeans = np.mean(entropy_list, axis=1)\nplt.errorbar(np.arange(len(entropy_list)), means, stds, marker='o')\nplt.ylabel(\"Entropy of Weight Distribution\")\nplt.xlabel(\"Number of Routing Iterations\")\nplt.xticks(np.arange(len(entropy_list)))\nplt.close()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "|image3|\n\nAlternatively, you can also watch the evolution of the histograms.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import seaborn as sns\nimport matplotlib.animation as animation\n\nfig = plt.figure(dpi=150)\nfig.clf()\nax = fig.subplots()\n\n\ndef dist_animate(i):\n    ax.cla()\n    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)\n    ax.set_xlabel(\"Weight Distribution Histogram\")\n    ax.set_title(\"Routing: %d\" % i)\n\n\nani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), interval=500)\nplt.close()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "|image4|\n\nYou can monitor how lower-level capsules gradually attach to one of the\nhigher-level ones.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import networkx as nx\nfrom networkx.algorithms import bipartite\n\ng = routing.g.to_networkx()\nX, Y = bipartite.sets(g)\nheight_in = 10\nheight_out = height_in * 0.8\nheight_in_y = np.linspace(0, height_in, in_nodes)\nheight_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)\npos = dict()\n\nfig2 = plt.figure(figsize=(8, 3), dpi=150)\nfig2.clf()\nax = fig2.subplots()\npos.update((n, (i, 1)) for i, n in zip(height_in_y, X))  # place in-capsules (X) on the row y=1\npos.update((n, (i, 2)) for i, n in zip(height_out_y, Y))  # place out-capsules (Y) on the row y=2\n\n\ndef weight_animate(i):\n    ax.cla()\n    ax.axis('off')\n    ax.set_title(\"Routing: %d\" % i)\n    dm = dist_list[i]\n    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes), node_color='r', node_size=100, ax=ax)\n    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes, in_nodes + out_nodes), node_color='b', node_size=100, ax=ax)\n    for edge in g.edges():\n        nx.draw_networkx_edges(g, pos, edgelist=[edge], width=dm[edge[0], edge[1] - in_nodes] * 1.5, ax=ax)\n\n\nani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), interval=500)\nplt.close()" ] },
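{ "cell_type": "markdown", "metadata": {}, "source": [ "The animation objects ani and ani2 above are closed with plt.close(), as the static tutorial build expects. If you are running this notebook interactively, one way to render them inline (an illustrative addition here, using matplotlib's standard Animation.to_jshtml together with IPython's HTML display) is:\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from IPython.display import HTML\n\n# Illustrative addition: render one of the routing animations inline.\n# Requires a live Jupyter kernel; to_jshtml() packages the frames as JavaScript.\nHTML(ani2.to_jshtml())" ] },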
{ "cell_type": "markdown", "metadata": {}, "source": [ "|image5|\n\nThe full code of this visualization, as well as the complete code that\ntrains on MNIST, is available on GitHub.\n\n.. |image0| image:: https://i.imgur.com/55Ovkdh.png\n.. |image1| image:: https://i.imgur.com/9tc6GLl.png\n.. |image2| image:: https://i.imgur.com/mv1W9Rv.png\n.. |image3| image:: https://i.imgur.com/dMvu7p3.png\n.. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif\n.. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 0 }