Multi-Task Learning in Tensorflow: Part 1
A discussion and step-by-step tutorial on how to use Tensorflow graphs for multi-task learning.
By Jonathan Godwin, University College London.
A Jupyter notebook accompanies this blog post.
Why Multi-Task Learning
When people learn to do new things, they often use their experience and knowledge of the world to speed up the learning process. When I learn a new language, especially a related one, I use my knowledge of languages I already speak to make shortcuts. The process works the other way too - learning a new language can help you understand and speak your own better.
Our brains learn to do multiple different tasks at the same time - we have the same brain architecture whether we are translating English to German or English to French. If we were to use a Machine Learning algorithm to do both of these tasks, we might call that ‘multi-task’ learning.
It is likely to be one of the most interesting and exciting areas of Machine Learning research in the coming years, with the potential to radically reduce the amount of data required to learn new concepts. One of the great promises of Deep Learning is that, with the power of the models and simple ways to share parameters between tasks, we should be able to make significant progress in multi-task learning.
As I started to experiment in this area I came across a bit of a roadblock: while it was easy to understand the architecture changes required to implement multi-task learning, it was harder to figure out how to implement them in Tensorflow. Building anything beyond standard nets in Tensorflow requires a good understanding of how it works, but most of the stock examples don’t provide helpful guidance. I hope the following tutorial explains some key concepts simply, and helps those who are struggling.
What We Are Going To Do
- Understand Tensorflow Computation Graphs With An Example. Doing multi-task learning with Tensorflow requires understanding how computation graphs work - skip this if you already know.
- Understand How We Can Use Graphs For Multi-Task Learning. We’ll go through an example of how to adapt a simple graph to do multi-task learning.
- Build A Graph for POS Tagging and Shallow Parsing. We’ll fill in a template that trains a net for two related linguistic tasks. Don’t worry, you don’t need to know what they are!
- Train A Net Jointly and Separately. We’ll actually train a model in two different ways. You should be able to do this on your laptop.
Understanding Computation Graphs With A Toy Example
The Computation Graph is the thing that makes Tensorflow (and other similar packages) fast. It’s an integral part of the machinery of Deep Learning, but can be confusing.
There are some neat features of a graph that make multi-task learning very easy to implement, but first we’ll keep things simple and explain the key concepts.
Definition: Computation Graph
The Computation Graph is a template for the computation (i.e. the algorithm) you are going to run. Defining it doesn’t perform any calculations, but because the whole computation is laid out in advance, Tensorflow can work out the gradients needed for backpropagation automatically and run them efficiently.
If you ask Tensorflow for the result of a calculation, it will only perform the calculations required to produce that result, not the whole graph.
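To make the idea concrete without depending on any particular Tensorflow version, here is a toy, pure-Python sketch of a lazily evaluated graph. The `Node` class and `evaluate` function are hypothetical illustrations, not Tensorflow's API: building the nodes performs no arithmetic, and asking for one node computes only the nodes it depends on.

```python
# A toy, pure-Python model of a lazily evaluated computation graph.
# Node and evaluate() are hypothetical illustrations, not Tensorflow's API.


class Node:
    """A graph node: an operation plus the upstream nodes it reads from."""

    def __init__(self, name, op, inputs=()):
        self.name = name
        self.op = op          # function applied to the input values
        self.inputs = inputs  # nodes whose values this node needs


def evaluate(node, cache):
    """Return the value of `node`, computing only what it depends on.

    `cache` maps each node already computed to its value, so a node shared
    by several consumers is computed once.
    """
    if node in cache:
        return cache[node]
    values = [evaluate(parent, cache) for parent in node.inputs]
    cache[node] = node.op(*values)
    return cache[node]


# Building the graph performs no arithmetic: it only wires nodes together.
a = Node("a", lambda: 3.0)
b = Node("b", lambda: 4.0)
total = Node("total", lambda x, y: x + y, (a, b))
expensive = Node("expensive", lambda x: x ** 10, (a,))

cache = {}
print(evaluate(total, cache))   # 7.0
print([n.name for n in cache])  # ['a', 'b', 'total']; 'expensive' never ran
```

Asking for `total` never touches `expensive`, which mirrors the way Tensorflow skips the parts of the graph a requested result does not need.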
A Toy Example - Linear Transformation: Setting Up The Graph
We’re going to look at the graph for a simple calculation: a linear transformation of our inputs, followed by the square loss.
There are a few things to emphasise about this graph:
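In symbols, the calculation looks like this. The notation is our own, not necessarily the variable names used in the accompanying notebook: $x$ is a row of inputs, $W$ and $b$ are the weights and bias, and $y$ holds the targets (averaging instead of summing the squared errors is an equally common choice).

```latex
\hat{y} = xW + b, \qquad L = \sum_{i} \left( y_i - \hat{y}_i \right)^2
```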
- If we were to run this code right now, we would get no output. Remember that a Computation Graph is just a template - it doesn’t do anything. If we want an answer, we have to tell Tensorflow to run the computation using a Session.
- We haven’t explicitly created a graph object. You might expect that we would have to create a graph object somewhere in order for Tensorflow to know that we wanted to create a graph. In fact, Tensorflow keeps a default graph, and every Tensorflow operation we create is added to it automatically; simply by using Tensorflow operations, we are telling Tensorflow which parts of our code belong in the graph.
Tip: Keep Your Graph Separate. You’ll typically be doing a fair amount of data manipulation and computation outside of the graph, which can make it a bit confusing to keep track of what is and isn’t available inside Python. I like to put my graph in a separate file, and often in a separate class, to keep concerns separated, but this isn’t required.
A Toy Example - Linear Transformation: Getting Results
Computations on your Graph are conducted inside a Tensorflow Session. To get results from your session you need to provide it with two things: Target Results and Inputs.
- Target Results or Operations. You tell Tensorflow which parts of the graph you want values for, and it will automatically figure out which calculations within the graph need to be run. You can also run operations, for example to initialise your variables.
- Inputs As Required (‘Feed Dict’). In most calculations you will provide the input data ad hoc. In this case, you construct the graph with a placeholder for this data, and feed it in at computation time. Not all calculations or operations will require an input - for many, all the information is already contained in the graph.
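The two ingredients of a session call, the targets to fetch and a feed dictionary for placeholders, can be sketched with the linear transformation and square loss from the toy example. The `Placeholder`, `Variable`, `Op` and `run` names below are hypothetical stand-ins that only loosely mimic Tensorflow's `Session.run(fetches, feed_dict)`, and the weights are fixed numbers chosen for illustration.

```python
# Toy stand-ins for fetches and a feed dict (hypothetical, not Tensorflow's API).


class Placeholder:
    """A value that must be supplied at run time through the feed dict."""

    def __init__(self, name):
        self.name = name


class Variable:
    """A value stored inside the graph itself."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class Op:
    """A calculation on the values of other nodes."""

    def __init__(self, name, fn, *inputs):
        self.name = name
        self.fn = fn
        self.inputs = inputs


def run(fetches, feed_dict=None):
    """Compute each node in `fetches`, using `feed_dict` for placeholders."""
    feed_dict = feed_dict or {}
    cache = {}

    def value_of(node):
        if node in cache:
            return cache[node]
        if isinstance(node, Placeholder):
            if node not in feed_dict:
                raise KeyError(f"placeholder {node.name!r} must be fed")
            result = feed_dict[node]
        elif isinstance(node, Variable):
            result = node.value
        else:
            result = node.fn(*(value_of(parent) for parent in node.inputs))
        cache[node] = result
        return result

    return [value_of(node) for node in fetches]


# Linear transformation and square loss, with illustrative fixed weights.
x = Placeholder("x")
y = Placeholder("y")
w = Variable("w", 2.0)
b = Variable("b", 1.0)
y_hat = Op("y_hat", lambda xs, w_, b_: [w_ * xi + b_ for xi in xs], x, w, b)
loss = Op(
    "loss",
    lambda preds, ys: sum((p - t) ** 2 for p, t in zip(preds, ys)),
    y_hat,
    y,
)

print(run([y_hat], {x: [1.0, 2.0]}))                # [[3.0, 5.0]]
print(run([loss], {x: [1.0, 2.0], y: [3.0, 4.0]}))  # [1.0]
```

Fetching `y_hat` needs only `x`, so the target placeholder `y` can be left out of the feed; fetching `loss` without feeding `y` raises an error, much as Tensorflow complains when a required placeholder has no value.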
How To Use Graphs for Multi-Task Learning
When we create a Neural Net that performs multiple tasks we want to have some parts of the network that are shared, and other parts of the network that are specific to each individual task. When we’re training, we want information from each task to be transferred through the shared parts of the network.
So, to start, let’s draw a diagram of a simple two-task network that has a shared layer and a specific layer for each individual task. We’re going to feed the outputs of this into our loss function with our targets. I’ve labelled where we’re going to want to create placeholders in the graph.
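As a sketch of this architecture, here is the forward pass of a tiny two-task network in plain Python. The layer sizes, the weights, the ReLU on the shared layer and the helper names are all our own illustrative assumptions; `x`, `y1` and `y2` play the role of the values we would feed through placeholders.

```python
# Hypothetical two-task network in pure Python (not Tensorflow):
# one shared layer feeding two task-specific output layers.


def matvec(matrix, vector):
    # Multiply a matrix (list of rows) by a vector (list of numbers).
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def relu(vector):
    return [max(0.0, v) for v in vector]


def forward(x, params):
    shared = relu(matvec(params["shared"], x))  # used by both tasks
    out1 = matvec(params["task1"], shared)      # Task 1 layer
    out2 = matvec(params["task2"], shared)      # Task 2 layer
    return out1, out2


def squared_loss(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target))


params = {
    "shared": [[1.0, 0.0], [0.0, 1.0]],
    "task1": [[1.0, 1.0]],
    "task2": [[1.0, -1.0]],
}
x, y1, y2 = [2.0, 1.0], [3.0], [0.0]  # values we would feed via placeholders
out1, out2 = forward(x, params)
print(out1, out2)                                      # [3.0] [1.0]
print(squared_loss(out1, y1), squared_loss(out2, y2))  # 0.0 1.0
```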
When we are training this network, we want the parameters of the Task 1 layer to stay the same no matter how wrong our Task 2 predictions are, while the parameters of the shared layer change with both tasks. This might seem a little difficult - normally you only have one optimiser in a graph, because you only optimise one loss function. Thankfully, the properties of the graph make it very easy to train this sort of model, and we can do it in two ways.
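The key property can be checked on a toy model with one scalar weight per layer: a shared weight `s` and task weights `a` (Task 1) and `c` (Task 2). The Task 2 loss does not depend on `a`, so its gradient with respect to `a` is zero and a step on the Task 2 loss leaves `a` untouched while still updating `s`. This sketch uses hand-derived gradients and plain gradient descent in place of Tensorflow's automatic differentiation and optimisers, and alternates one step per task; all names are hypothetical.

```python
# Toy two-task model with scalar weights (hypothetical, for illustration):
#   shared layer: h = s * x
#   Task 1 head:  pred1 = a * h      Task 2 head: pred2 = c * h
# Each task has its own squared loss and its own gradient-descent step.


def loss1(p, x, y1):
    return (p["a"] * p["s"] * x - y1) ** 2


def loss2(p, x, y2):
    return (p["c"] * p["s"] * x - y2) ** 2


def grads_task1(p, x, y1):
    # Loss 1 depends on the shared weight s and the Task 1 weight a only.
    err = p["a"] * p["s"] * x - y1
    return {"a": 2 * err * p["s"] * x, "s": 2 * err * p["a"] * x}


def grads_task2(p, x, y2):
    # Loss 2 depends on the shared weight s and the Task 2 weight c only.
    err = p["c"] * p["s"] * x - y2
    return {"c": 2 * err * p["s"] * x, "s": 2 * err * p["c"] * x}


def sgd_step(p, grads, lr):
    # Update only the parameters that received a gradient from this loss.
    for name, g in grads.items():
        p[name] -= lr * g


def train_alternating(p, x, y1, y2, lr=0.05, rounds=1000):
    # Alternate one Task 1 step and one Task 2 step per round.
    for _ in range(rounds):
        sgd_step(p, grads_task1(p, x, y1), lr)
        sgd_step(p, grads_task2(p, x, y2), lr)
    return p


params = {"s": 1.0, "a": 1.0, "c": 1.0}
sgd_step(params, grads_task2(params, 1.0, 3.0), lr=0.25)
print(params)  # {'s': 2.0, 'a': 1.0, 'c': 2.0}: the Task 1 weight is untouched
```

With `x = 1.0`, `y1 = 2.0` and `y2 = 3.0`, `train_alternating` drives both losses close to zero: the shared weight receives updates from both tasks, while each task weight is only ever updated by its own loss.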