Fixing a Major Weakness in Machine Learning of Images with Hinton’s Capsule Networks
We explore Geoffrey Hinton's capsule networks to deal with rotational variance in images.
By Kevin Vu, ExxactCorp.
Have a look at this:
Now see this:
Even if you’ve never been to the moon, you can probably recognize the subject of the images above as NASA’s Lunar Roving Vehicle, or at least as being two instances of an identical vehicle at slightly different orientations. You probably have an intuitive idea of how you could manipulate the viewpoint of one image to approximate the view of the other. This sort of cognitive transformation is effortlessly intuitive for a human, but turns out to be very difficult for a convolutional neural network without explicit training examples.
Limitations of Convolutional and Max Pooling Layers
Standard convolutional neural networks are made up of, as the name suggests, a series of convolution operations that hierarchically extract image features like edges, points, and corners. Each convolution multiplies the image by a sliding window of pixel weights, aka a convolution kernel, and there may be tens to thousands of kernels in each layer. Often, we perform a pooling operation in between each convolution, decreasing image dimensions. Pooling not only decreases the size of the layers (saving memory), but provides some translation invariance so that a given network can classify an image subject regardless of where it resides in the image. This may be more of a bug than a feature, however, as pooling operations confuse information about where something is in an image (driving the development of skip connections in U-nets) and fare poorly coping with image transformations other than translation.
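As a rough illustration of the point above, the toy sketch below applies a single convolution kernel followed by global max pooling. The "corner" kernel, the 12x12 image size, and the helper name `pooled_response` are all assumptions made for this example, not part of the original tutorial. The pooled response is identical when the pattern is shifted, but weaker when the same pattern is rotated.

```python
import torch
import torch.nn.functional as F

def pooled_response(image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Convolve a single-channel image with one kernel, then global max-pool."""
    # conv2d expects [batch, channels, H, W] inputs and
    # [out_channels, in_channels, kH, kW] weights.
    features = F.conv2d(image[None, None], kernel[None, None])
    # Taking the global maximum discards where the strongest response occurred.
    return features.max()

# Hypothetical 3x3 "corner" detector and a patch that matches it exactly.
kernel = torch.tensor([[1., 1., 0.],
                       [1., 0., 0.],
                       [0., 0., 0.]])
patch = kernel.clone()

image_a = torch.zeros(12, 12)
image_a[1:4, 1:4] = patch                     # pattern near the top-left corner

image_b = torch.zeros(12, 12)
image_b[7:10, 6:9] = patch                    # same pattern, translated

image_c = torch.zeros(12, 12)
image_c[1:4, 1:4] = torch.rot90(patch, k=1)   # same pattern, rotated 90 degrees

resp_a = pooled_response(image_a, kernel)     # 3.0
resp_b = pooled_response(image_b, kernel)     # 3.0: translation goes unnoticed
resp_c = pooled_response(image_c, kernel)     # 2.0: rotation weakens the match
print(resp_a.item(), resp_b.item(), resp_c.item())
```

Real networks pool over small windows rather than the whole image, so the invariance is only local, but the same trade-off applies: position is thrown away, and other transformations are not handled at all.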
Translation invariance in conv-nets with pooling falls short of object transformation equivariance, a more generalized cognitive ability that seems to be closer to our own approach to making sense of the world. The fact that conv-nets perform pretty well at a wide variety of computer vision tasks glosses over this shortcoming. Consider the classic example of the MNIST hand-written digits dataset. LeNet-5, a relatively shallow and simple conv-net design by today’s standards, quickly learns to correctly identify 98% of the digits in the test dataset.
Apply a simple 35 degree rotation to the test images, however, and the test performance drops precipitously.
A so-called “Capsule Network” does somewhat better with rotated data:
The standard approach to mitigate the problem is data augmentation, that is, adding rotations, mirroring, distortion, etc. to synthetically enlarge the dataset to cover a larger distribution of possible examples. This improves performance on a given vision task, but it’s clearly a kludge, and, as they say, “intellectually unsatisfying.”
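A minimal sketch of rotation augmentation, assuming image batches shaped [N, C, H, W] and using only core PyTorch. The function name `random_rotate` and the 35 degree default are choices made for this illustration; the original tutorial may well have used a different routine.

```python
import math
import torch
import torch.nn.functional as F

def random_rotate(images: torch.Tensor, max_degrees: float = 35.0) -> torch.Tensor:
    """Rotate each image in a [N, C, H, W] batch by its own random angle.

    Angles are drawn uniformly from [-max_degrees, +max_degrees].
    """
    n = images.size(0)
    angles = (torch.rand(n) * 2 - 1) * math.radians(max_degrees)
    cos, sin = torch.cos(angles), torch.sin(angles)
    zeros = torch.zeros_like(angles)
    # One 2x3 affine matrix per image: pure rotation about the image center.
    affine = torch.stack([
        torch.stack([cos, -sin, zeros], dim=1),
        torch.stack([sin, cos, zeros], dim=1),
    ], dim=1).to(images.dtype)
    grid = F.affine_grid(affine, images.shape, align_corners=False)
    # Pixels that rotate in from outside the frame are filled with zeros,
    # so apply this before any mean/std normalization.
    return F.grid_sample(images, grid, align_corners=False)

batch = torch.rand(8, 1, 28, 28)     # stand-in for a batch of MNIST digits
augmented = random_rotate(batch)
print(augmented.shape)
```

Calling a function like this on every training batch exposes the model to a fresh set of orientations each epoch, which is exactly the brute-force coverage the paragraph above describes.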
For many years Geoffrey Hinton has been outspoken in his dislike for pooling operations, and has been trying to replace the happenstance translational invariance of pooling with a more universal equivariance using what he terms “capsules,” a representation of scene contents created by reshaping the features extracted by convolution into multidimensional vectors. The concept of capsule networks has evolved alongside the upsurge in conv-nets, from transforming autoencoders (2011) to a dynamic routing method for training capsules (2017), and most recently to an updated training algorithm termed expectation maximization (2018).
Capsules to the Rescue?
In capsule networks, each vector learns to represent some aspect of the image, such as shape primitives, with vector length corresponding to the probability of the object existing at a given point, and the direction of the vector describing the object’s characteristics. In the 2017 implementation, the first layer of capsules each try to predict the correct probabilities for the next layer of capsules via dynamic routing (e.g. in a face detection CapsNet the “eye” and “nose” capsule values will each contribute to the prediction of the “face” capsule in the next layer for each point). Consider the simplified example of 2D capsule vectors detecting polygons that make up cartoon doorways. These capsules represent the presence and orientation of two shapes, blocks and quarter circles, and together they will try to predict the correct classification in the next capsule layer, which learns to detect a properly oriented doorway.
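The two ingredients described above, vector length as probability and routing between capsule layers, can be sketched as follows. This is a simplified reading of the 2017 routing-by-agreement procedure, not the code used in this tutorial; the tensor layout [batch, input capsules, output capsules, dimension] and the three routing iterations are assumptions made for the example.

```python
import torch
import torch.nn.functional as F

def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """Shrink vectors so their length lies in [0, 1) while keeping direction."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / torch.sqrt(sq_norm + eps)

def dynamic_routing(u_hat: torch.Tensor, num_iterations: int = 3):
    """Route lower-level predictions to higher-level capsules by agreement.

    u_hat: [batch, num_in, num_out, out_dim], the prediction each input
    capsule makes for each output capsule.
    """
    logits = torch.zeros(u_hat.shape[:3], dtype=u_hat.dtype, device=u_hat.device)
    for _ in range(num_iterations):
        # Each input capsule distributes its vote across the output capsules.
        coupling = F.softmax(logits, dim=2)
        # Weighted sum of predictions for every output capsule.
        s = (coupling.unsqueeze(-1) * u_hat).sum(dim=1)      # [B, out, dim]
        v = squash(s)
        # Predictions that point the same way as the output gain weight.
        agreement = (u_hat * v.unsqueeze(1)).sum(dim=-1)     # [B, in, out]
        logits = logits + agreement
    return v, coupling

u_hat = torch.randn(2, 6, 3, 4)   # e.g. 6 part capsules voting for 3 wholes
v, coupling = dynamic_routing(u_hat)
print(v.norm(dim=-1))             # every length is below 1
```

In the face example, `u_hat` would hold the "eye" and "nose" capsules' guesses about the "face" capsule, and the coupling weights grow only where those guesses agree.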
Whereas in a conv-net the mere presence of the correct features (in orientations that are represented in the training data) is enough to trigger a classification regardless of their spatial relationship to one another, capsule vectors all have to be in strong agreement to predict the whole from its parts. We should also take note that a capsule can only detect one instance of a given object at a time, so a pile of blocks would be indistinguishable and CapsNet models can get confused by overlapping parts of the same type. This shortcoming is often compared to crowding in human perception.
Tutorial Section: Training and Testing LeNet5 vs. Dynamic Routing CapsNet for Rotated MNIST Classification
Even better than talking about capsules is tinkering with them. To keep things simple, we’ll be working with the popular MNIST handwritten digits dataset. The code in this section provides a hackable foundation for understanding CapsNets in the context of a familiar dataset and a familiar machine learning model, the 5-layer LeNet5 conv-net. After getting a general overview of CapsNet performance on MNIST, we’d recommend adding different training data augmentation routines to see how well each model takes to learning various transformations.
First we’ll define the dataset we want to work with and the preprocessing we need, using PyTorch’s transform library.
Defining our CNN: LeNet5
We’ll start by implementing a small convolutional neural network called LeNet5 in PyTorch. This model gives us test set accuracy in the high 90s after only a few training epochs, and consists of just 2 convolutional and 3 fully connected layers.
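One possible PyTorch definition matching that description (2 convolutional and 3 fully connected layers) is sketched below. The layer sizes follow the classic LeNet-5 layout adapted to 28x28 inputs; the use of ReLU and max pooling is an assumption, and the original tutorial's choices may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    """Small conv-net: 2 convolutional layers followed by 3 linear layers."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # padding=2 keeps the 28x28 MNIST input at 28x28 after the first conv.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)   # [N, 6, 14, 14]
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)   # [N, 16, 5, 5]
        x = torch.flatten(x, 1)                      # [N, 400]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                           # raw class scores (logits)

model = LeNet5()
logits = model(torch.zeros(4, 1, 28, 28))
print(logits.shape)
```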
Training and Test Routines for LeNet5
We’ll use Adam optimization to minimize cross-entropy error during training. Again, this functionality is readily accessible via PyTorch.
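A hedged sketch of such routines is given below. The function names `train_one_epoch` and `evaluate` are invented for this example, and they accept any `nn.Module` that returns class logits along with any iterable of (images, labels) batches.

```python
import torch
import torch.nn as nn

def train_one_epoch(model, loader, optimizer, device="cpu"):
    """Run one pass over `loader`, minimizing cross-entropy; return mean loss."""
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss, seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    return total_loss / seen

@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Return the fraction of correctly classified examples in `loader`."""
    model.eval()
    correct, seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        seen += labels.size(0)
    return correct / seen
```

A typical call would create the optimizer with `torch.optim.Adam(model.parameters(), lr=1e-3)` (the learning rate here is an assumed default, not a value from the article), then alternate `train_one_epoch` on the training loader with `evaluate` on both the upright and rotated test loaders.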
The dynamic routing algorithm for training capsule networks is more computationally demanding than for conv-nets. We’ll definitely want to train on a GPU if we want to finish in a reasonable amount of time. We’ve also pre-trained a CapsNet for those who find themselves between GPUs at the moment or just want to skip to testing. For training and testing a capsule network, we forked and modified the implementation at https://github.com/gram-ai/capsule-networks by Kenta Iwasaki. Clone the version used in this tutorial by entering (in the command line):
After that, you’ll probably want to spin up a PyTorch visdom server for visualization purposes by entering (in a separate command line window):
python -m visdom.server
Finally, you can train and test the CapsNet by entering the code below into an interactive python session (still in the capsule_networks_rotated_MNIST directory), or save it as a .py to play around with and run it from the command line with:
python run_capsnet.py
where run_capsnet.py is the name of the newly saved script file.
Capsule Networks provide an extension of the universal feature extraction properties of convolutional neural networks. By training each primary capsule to predict the output of the next layer’s capsules, the model can be encouraged to learn to recognize the relationships between parts, wholes, and the importance of their instantiation characteristics such as location and orientation. In many ways this feels like a more natural way to recognize the objects in a scene, as orientations and other attributes can be learned as parameters of a scene object represented by the capsules, and modifying the characteristics can give us realistic changes in viewpoint, scale, etc. Convolution activations start to seem like a pretty crude level of feature representation by comparison.
The dynamic routing algorithm used for training can be painfully slow (one epoch can take over five minutes vs. 25 seconds for LeNet5 on the same hardware), however, and in practice it can take a bit of selective representation (aka cherry-picking) to find situations where CapsNets are decidedly better than a comparable conv-net. Data augmentation can yield greater than 98% accuracy across training and (rotated) test MNIST datasets with a simple conv-net like LeNet5, and it may be fairer to compare CapsNets to conv-nets based on training time required rather than model size. Overall, the difference between 98% and upper 99% accuracy may not seem like much, but it’s those last few percentage points of error that matter most in terms of solving a problem rather than learning an approximate heuristic.
There’s still plenty of room for improvements to training CapsNets, and the high level of interest ensures that they will receive plenty of development effort. We’ll probably see CapsNets gain utility in a similar way that conv-nets did, first being demonstrated on toy problems like MNIST before application to more relevant domains. One thing that’s sure to yield exciting results with CapsNets is a combination of faster hardware accelerators and better training algorithms/software libraries to allow “Deep CapsNets” to become practical.
Images of the Lunar Roving Vehicle in the public domain (generated by NASA) obtained from https://commons.wikimedia.org/wiki/File:Apollo_15_Lunar_Rover_final_resting_place.jpg and https://commons.wikimedia.org/wiki/Category:Lunar_Roving_Vehicle#/media/File:Apollo_17_lunar_rover_near_station_8_AS17-146-22367HR.jpg
The PyTorch implementation of a dynamic routing CapsNet was forked and modified from a public repository by Kenta Iwasaki @ Gram.AI: https://github.com/gram-ai/capsule-networks
Original. Reposted with permission.