Recursive (not Recurrent!) Neural Networks in TensorFlow
Learn how to implement recursive neural networks in TensorFlow. These networks can be used to learn tree-like structures or, more generally, directed acyclic graphs.
By Alireza Nejati, University of Auckland.
For the past few days I’ve been working on how to implement recursive neural networks in TensorFlow. Recursive neural networks (which I’ll call TreeNets from now on to avoid confusion with recurrent neural nets) can be used for learning tree-like structures (more generally, directed acyclic graph structures). They are highly useful for parsing natural scenes and language; see the work of Richard Socher (2011) for examples. More recently, in 2014, Ozan İrsoy used a deep variant of TreeNets to obtain some interesting NLP results.
The best way to explain TreeNet architecture is, I think, to compare it with other kinds of architectures, for example with RNNs:
In RNNs, at each time step the network takes as input its previous state s(t-1) and its current input x(t) and produces an output y(t) and a new hidden state s(t). TreeNets, on the other hand, don’t have a simple linear structure like that. With RNNs, you can ‘unroll’ the net and think of it as a large feedforward net with inputs x(0), x(1), …, x(T), initial state s(0), and outputs y(0), y(1), …, y(T), with T varying depending on the input data stream, and the weights in each of the cells tied with each other. You can also think of TreeNets by unrolling them – the weights in each branch node are tied with each other, and the weights in each leaf node are tied with each other. The TreeNet illustrated above has different numbers of inputs in the branch nodes. Usually, we just restrict the TreeNet to be a binary tree – each node has either one or two input nodes. There may be different types of branch nodes, but branch nodes of the same type have tied weights.
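To make the weight-tying comparison concrete, here is a minimal sketch in plain Python. It assumes scalar weights instead of matrices, a tanh nonlinearity, and a hypothetical tuple format for trees; none of these come from the post. The RNN reuses one set of weights at every time step, while the TreeNet reuses one set of weights for every branch node of the same type (here a binary type "pair" and a unary type "wrap") and one leaf weight for every leaf.

```python
import math


def rnn_unroll(xs, s0, w_s, w_x, b):
    """Unrolled RNN: the same (w_s, w_x, b) is applied at every time step."""
    s = s0
    states = []
    for x in xs:
        s = math.tanh(w_s * s + w_x * x + b)  # s(t) from s(t-1) and x(t)
        states.append(s)
    return states


def treenet(node, leaf_w, branch_weights):
    """Recursive TreeNet evaluation with tied weights.

    A leaf is a number. A branch is a tuple (kind, child, ...) with one or two
    children; every branch of the same kind shares branch_weights[kind], and
    every leaf shares leaf_w.
    """
    if not isinstance(node, tuple):
        return math.tanh(leaf_w * node)
    kind, *children = node
    child_ws, bias = branch_weights[kind]
    total = bias + sum(w * treenet(c, leaf_w, branch_weights)
                       for w, c in zip(child_ws, children))
    return math.tanh(total)


# Hypothetical weights: one entry per branch-node type.
weights = {
    "pair": ((0.8, -0.4), 0.1),  # binary type: one weight per child, plus bias
    "wrap": ((1.5,), 0.0),       # unary type
}
tree = ("pair", ("wrap", 1.0), ("pair", 2.0, 3.0))
print(treenet(tree, 0.5, weights))
print(rnn_unroll([1.0, 2.0, 3.0], 0.0, 0.5, 1.0, 0.0))
```

Both "pair" nodes in the example tree read the same weights, just as every RNN step reads the same (w_s, w_x, b); only the shape of the unrolled computation differs.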
The advantage of TreeNets is that they can be very powerful in learning hierarchical, tree-like structure. The disadvantages are, firstly, that the tree structure of every input sample must be known at training time. We will represent the tree structure like this (lisp-like notation):
In each sub-expression, the type of the sub-expression must be given – in this case, we are parsing a sentence, and the type of the sub-expression is simply the part-of-speech (POS) tag. You can see that expressions with three elements (one head and two tail elements) correspond to binary operations, whereas those with four elements (one head and three tail elements) correspond to ternary operations, etc.
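As a rough illustration of reading this notation, the sketch below parses a lisp-like string into nested lists of the form [head, tail1, tail2, ...] and reports each sub-expression's arity (number of tail elements). The example sentence and its tags are a hypothetical stand-in, and `parse` and `arity` are illustrative helpers, not part of the post's code.

```python
import re


def tokenize(text):
    """Split lisp-like text into '(', ')' and atom tokens."""
    return re.findall(r"\(|\)|[^\s()]+", text)


def parse(text):
    """Parse one s-expression into nested lists: [head, tail1, tail2, ...]."""
    tokens = tokenize(text)
    pos = 0

    def read():
        nonlocal pos
        if pos >= len(tokens):
            raise ValueError("unexpected end of input")
        tok = tokens[pos]
        pos += 1
        if tok == ")":
            raise ValueError("unexpected ')'")
        if tok != "(":
            return tok  # an atom: a word or a symbol
        expr = []
        while True:
            if pos >= len(tokens):
                raise ValueError("missing ')'")
            if tokens[pos] == ")":
                pos += 1
                return expr
            expr.append(read())

    tree = read()
    if pos != len(tokens):
        raise ValueError("trailing tokens after expression")
    return tree


def arity(expr):
    """Number of tail elements: 2 = binary, 3 = ternary, and so on."""
    return len(expr) - 1


sentence = "(S (NP (DT the) (JJ big) (NN cat)) (VP (VBZ sleeps)))"
tree = parse(sentence)
print(tree[0], arity(tree))        # S 2
print(tree[1][0], arity(tree[1]))  # NP 3
```

The head of each list is the node type (the POS tag here), which is what would select the set of tied weights for that node.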
The second disadvantage of TreeNets is that training is hard: the tree structure changes from one training sample to the next, so it’s not easy to map training onto mini-batches and similar techniques.
Implementation in TensorFlow
There are a few methods for training TreeNets. The method we’re going to use is probably the conceptually simplest: it consists of simply assigning a tensor to every single intermediate form. So, for instance, imagine that we want to train on simple mathematical expressions, and our input expressions are the following (in lisp-like notation):
Now we have seven intermediate forms, which we’ll call a through g; for example, f = (* 1 2) and g = (+ (* 1 2) (+ 2 1)). We can see that all of our intermediate forms are simple expressions of other intermediate forms (or inputs). Each of these corresponds to a separate sub-graph in our TensorFlow graph. So, for instance, for *, we would have two matrices, W_times_l and W_times_r, and one bias vector, bias_times. And for computing f, we would have:
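A minimal sketch of the sub-graph for f = (* 1 2), written in plain Python rather than TensorFlow so it runs on its own. Several details are assumptions for illustration: tanh as the nonlinearity, W_times_l as the name of the left-operand matrix (the counterpart of W_times_r), arbitrary untrained weight values, and one-hot inputs with symbol 1 as [1, 0] and symbol 2 as [0, 1]. In TensorFlow these would be variables combined with matrix multiplications.

```python
import math


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def binary_node(W_l, W_r, bias, left, right):
    """tanh(W_l @ left + W_r @ right + bias); tanh is an assumed nonlinearity."""
    pre = [l + r + b for l, r, b in zip(matvec(W_l, left), matvec(W_r, right), bias)]
    return [math.tanh(v) for v in pre]


# Assumed one-hot encodings: symbol 1 -> [1, 0], symbol 2 -> [0, 1].
one = [1.0, 0.0]
two = [0.0, 1.0]

# Hypothetical (untrained) weights for the '*' operator. Every '*' node in
# the graph would reuse these same three tensors.
W_times_l = [[0.5, -0.2], [0.1, 0.3]]
W_times_r = [[0.0, 0.4], [-0.6, 0.2]]
bias_times = [0.1, -0.1]

# f = (* 1 2): left operand is symbol 1, right operand is symbol 2.
f = binary_node(W_times_l, W_times_r, bias_times, one, two)
print(f)
```

Any other intermediate form is computed the same way, using the weights of its own operator and the vectors of its operands.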
Similarly, for computing d we would have:
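The same pattern applies to d and to every other intermediate form. As a hedged sketch of how the intermediate forms could be enumerated (a hypothetical helper, not the post's code), the function below walks input expressions written as nested Python tuples and gives each distinct sub-expression exactly one node, ordered so that operands come before the forms that use them. With this deduplication, (* 1 2) gets a single node even though it is both f and part of g.

```python
def build_graph(exprs):
    """Give every distinct sub-expression one node, children before parents.

    Leaves are symbols (strings); compound expressions are tuples
    (op, operand, ...). Returns (nodes, roots), where each node is
    (label, operand_indices) and roots holds the node index of each input.
    """
    index = {}  # sub-expression -> node index
    nodes = []

    def visit(e):
        if e in index:  # already has a node: reuse it
            return index[e]
        if isinstance(e, tuple):
            label, *operands = e
            operand_ids = tuple(visit(o) for o in operands)
        else:
            label, operand_ids = e, ()
        index[e] = len(nodes)
        nodes.append((label, operand_ids))
        return index[e]

    roots = [visit(e) for e in exprs]
    return nodes, roots


g = ("+", ("*", "1", "2"), ("+", "2", "1"))  # g = (+ (* 1 2) (+ 2 1))
nodes, roots = build_graph([("*", "1", "2"), g])
for i, (label, operands) in enumerate(nodes):
    print(i, label, operands)
```

Each node in `nodes` corresponds to one sub-graph: leaves become one-hot inputs, and every other node applies its operator's tied weights to the nodes listed in `operand_indices`. The list grows with the total size of the input, which is the cost discussed next.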
The full intermediate graph (excluding input and loss calculation) looks like:
For training, we simply initialize our inputs and outputs as one-hot vectors (here, we’ve set the symbol 1 to [1, 0] and the symbol 2 to [0, 1]) and perform gradient descent over all W and bias matrices in our graph. The advantage of this method is that, as I said, it’s straightforward and easy to implement. The disadvantage is that our graph complexity grows as a function of the input size. This isn’t as bad as it seems at first, because no matter how big our data set becomes, there will only ever be one training example (since the entire data set is trained simultaneously), and so even though the size of the graph grows, we only need a single pass through the graph per training epoch. However, if our graph grows very large (millions of data points), it seems likely that we will need to look at batch training.
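A minimal, dependency-free sketch of this training scheme. Its assumptions go beyond the text: 2-dimensional vectors, tanh at each branch node, a softmax output layer with cross-entropy loss, a tiny hypothetical parity data set, and central finite differences standing in for TensorFlow's automatic differentiation (which is far cheaper). What it does share with the description above is that all expressions reuse one W/bias set per operator, identical sub-expressions are computed once, and the loss over the entire data set is evaluated in a single pass per step.

```python
import math
import random

DIM = 2
ONE_HOT = {"1": [1.0, 0.0], "2": [0.0, 1.0]}  # symbol -> input vector
OPS = ("+", "*")
OP_SIZE = 2 * DIM * DIM + DIM  # W_l, W_r and bias for one operator
N_PARAMS = len(OPS) * OP_SIZE + DIM * DIM + DIM  # plus output layer U, c


def rows(theta, start, n_rows, n_cols):
    """View a slice of the flat parameter list as a matrix."""
    return [theta[start + r * n_cols:start + (r + 1) * n_cols] for r in range(n_rows)]


def affine(W, x, b):
    return [sum(w * xi for w, xi in zip(row, x)) + bj for row, bj in zip(W, b)]


def encode(expr, theta, cache):
    """Vector for expr; the cache gives each distinct sub-expression one value."""
    if expr not in cache:
        if isinstance(expr, str):
            cache[expr] = ONE_HOT[expr]
        else:
            op, left, right = expr
            base = OPS.index(op) * OP_SIZE  # tied weights for this operator
            W_l = rows(theta, base, DIM, DIM)
            W_r = rows(theta, base + DIM * DIM, DIM, DIM)
            bias = theta[base + 2 * DIM * DIM:base + OP_SIZE]
            l = affine(W_l, encode(left, theta, cache), bias)
            r = affine(W_r, encode(right, theta, cache), [0.0] * DIM)
            cache[expr] = [math.tanh(a + b) for a, b in zip(l, r)]
    return cache[expr]


def total_loss(theta, data):
    """Softmax cross-entropy summed over the whole data set (one 'example')."""
    out = len(OPS) * OP_SIZE
    U, c = rows(theta, out, DIM, DIM), theta[out + DIM * DIM:]
    cache, loss = {}, 0.0
    for expr, label in data:
        logits = affine(U, encode(expr, theta, cache), c)
        m = max(logits)
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        loss += log_z - logits[label]
    return loss


def gradient_step(theta, data, lr=0.05, eps=1e-5):
    """One gradient-descent step; central finite differences stand in for autodiff."""
    grad = []
    for i in range(len(theta)):
        plus, minus = theta[:], theta[:]
        plus[i] += eps
        minus[i] -= eps
        grad.append((total_loss(plus, data) - total_loss(minus, data)) / (2 * eps))
    return [t - lr * g for t, g in zip(theta, grad)]


# Hypothetical parity data set: (expression, parity label).
data = [("1", 1), ("2", 0), (("+", "1", "2"), 1),
        (("*", "1", "2"), 0), (("+", ("*", "1", "2"), ("+", "2", "1")), 1)]

random.seed(0)
theta = [random.uniform(-0.5, 0.5) for _ in range(N_PARAMS)]
before = total_loss(theta, data)
for _ in range(50):
    theta = gradient_step(theta, data)
after = total_loss(theta, data)
print(before, after)
```

Adding expressions to `data` lengthens the work done inside `total_loss` but never adds a second "example"; that is the trade-off between simplicity and graph size described above.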
Batch training actually isn’t that hard to implement; it just makes it a bit harder to see the flow of information. We can represent a ‘batch’ as a list of variables: [a, b, c]. So, in our previous example, we could replace the operations with two batch operations:
You’ll immediately notice that even though we’ve rewritten it in a batch way, the order of variables inside the batches is arbitrary and inconsistent. This is the problem with batch training in this model: the batches need to be constructed separately for each pass through the network. If we think of the input as being a huge matrix where each row (or column) of the matrix is the vector corresponding to each intermediate form (so [a, b, c, d, e, f, g]), then we can pick out the variables corresponding to each batch using TensorFlow’s tf.gather function. So, for instance, gathering the indices [1, 0, 3] from [a, b, c, d, e, f, g] would give [b, a, d], which is one of the sub-batches we need. The total number of sub-batches we need is two for every binary operation and one for every unary operation in the model.
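A small sketch of how those gather indices could be built. The assignment of c, d and e below is a hypothetical one, chosen only to be consistent with f = (* 1 2), g = (+ (* 1 2) (+ 2 1)) and the [1, 0, 3] example; `gather` is a plain-Python stand-in for tf.gather(params, indices) along axis 0, and `sub_batches` is an illustrative helper. For each binary operator it collects the output indices and the left and right operand indices, i.e. the two sub-batches per binary operation.

```python
def gather(params, indices):
    """Plain-Python stand-in for tf.gather(params, indices) along axis 0."""
    return [params[i] for i in indices]


# Hypothetical assignment of the seven intermediate forms (indices 0..6).
# a = 1 and b = 2 are inputs; the rest are (op, left_index, right_index).
names = ["a", "b", "c", "d", "e", "f", "g"]
forms = {
    2: ("+", 0, 1),  # c = (+ a b)
    3: ("+", 1, 0),  # d = (+ b a) = (+ 2 1)
    4: ("*", 3, 1),  # e = (* d b)
    5: ("*", 0, 1),  # f = (* a b) = (* 1 2)
    6: ("+", 5, 3),  # g = (+ f d) = (+ (* 1 2) (+ 2 1))
}


def sub_batches(forms):
    """For each binary op, collect output indices plus left and right operand indices."""
    batches = {}
    for out, (op, left, right) in sorted(forms.items()):
        b = batches.setdefault(op, {"out": [], "left": [], "right": []})
        b["out"].append(out)
        b["left"].append(left)
        b["right"].append(right)
    return batches


batches = sub_batches(forms)
print(batches["+"]["right"])                 # [1, 0, 3]
print(gather(names, batches["+"]["right"]))  # ['b', 'a', 'd']
```

In this hypothetical assignment, e depends on d and g depends on f, so the "+" and "*" groups still have to be evaluated in dependency order (for example, level by level); the sketch only builds the index lists that tf.gather would consume.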
For the sake of simplicity, I’ve only implemented the first (non-batch) version in TensorFlow, and my early experiments show that it works. For example, consider predicting the parity (even or odd-ness) of a number given as an expression. So 1 would have parity 1, (+ 1 1) (which is equal to 2) would have parity 0, (+ 1 (* (+ 1 1) (+ 1 1))) (which is equal to 5) would have parity 1, and so on. Training a TreeNet on a small set of such training examples seems to be enough for it to ‘get the point’ of parity: it correctly predicts the parity of much more complicated inputs with very high accuracy (>99.9%), with accuracy diminishing only once the inputs become very large.

The code is just a single Python file, which you can download and run here. I’ll give some more updates on more interesting problems in the next post and also release more code.
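Labels for the parity task described above can be generated by evaluating each expression and taking its value modulo 2. The sketch below is a hypothetical helper, not the post's code: `evaluate` assumes well-formed input using only + and *, and the random generator's depth and leaf probability are arbitrary assumptions rather than the examples used in the experiment.

```python
import math
import random
import re


def evaluate(text):
    """Evaluate a well-formed lisp-like expression over integers, + and *."""
    tokens = re.findall(r"\(|\)|[^\s()]+", text)
    pos = 0

    def read():
        nonlocal pos
        tok = tokens[pos]
        pos += 1
        if tok != "(":
            return int(tok)
        op = tokens[pos]
        pos += 1
        args = []
        while tokens[pos] != ")":
            args.append(read())
        pos += 1  # consume ')'
        if op == "+":
            return sum(args)
        if op == "*":
            return math.prod(args)
        raise ValueError(f"unknown operator {op!r}")

    return read()


def parity(text):
    """1 for odd, 0 for even."""
    return evaluate(text) % 2


def random_expr(depth, rng):
    """Hypothetical generator of random expressions built from 1, + and *."""
    if depth == 0 or rng.random() < 0.3:
        return "1"
    op = rng.choice("+*")
    return f"({op} {random_expr(depth - 1, rng)} {random_expr(depth - 1, rng)})"


rng = random.Random(42)
for _ in range(3):
    expr = random_expr(3, rng)
    print(expr, "->", parity(expr))
```

Raising the `depth` argument produces the kind of larger inputs on which, according to the experiment above, accuracy eventually starts to drop.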
Bio: Al Nejati is a research fellow at the University of Auckland. He completed his PhD in engineering science in 2015. He is interested in machine learning, image/signal processing, Bayesian statistics, and biomedical engineering.
Original. Reposted with permission.