Attention Craving RNNs: Building Up To Transformer Networks
RNNs let us model sequences in neural networks. While there are other ways of modeling sequences, RNNs are particularly useful. The two most popular RNN variants are LSTMs (Hochreiter et al, 1997) and GRUs (Cho et al, 2014).
Ok, now that we’ve covered all the prereqs, let’s get to the good stuff.
As you may have noticed in the previous animation, the decoder only looked at the last hidden vector generated by the encoder.
Turns out, it’s hard for the RNN to remember everything that happened over the sequence in this single vector (Bahdanau et al, 2015). For example, the word “NYU” could have been forgotten by the encoder by the time it finishes processing the input sequence.
Attention tries to solve this problem.
When you give a model an attention mechanism you allow it to look at ALL the h’s produced by the encoder at EACH decoding step.
To do this, we use a separate network, usually a single fully connected layer, which calculates how much the decoder should look at each of the h's. This is called the attention mechanism.
So imagine that, of all the h's we generated, we're actually going to take only a bit of each. Adding those pieces up gives a weighted sum, $c = 0.3h_1 + 0.2h_2 + 0.4h_3 + 0.1h_4$, called the context vector c.
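The scalars 0.3, 0.2, 0.4 and 0.1 are called attention weights. In the original paper, you'll find the same equation on page 3, written as $c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j$, where the $\alpha_{ij}$ are the attention weights.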
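Those weights are generated by a small neural network: a fully connected layer combines the token being decoded with each h, a dot product with learnable weights V turns each result into a single score (an attention energy), and a softmax normalizes the energies so they sum to 1.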
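Here is a minimal NumPy sketch of that weight-generation step. It assumes one fully connected layer over the concatenation of the current token embedding and each h, followed by a tanh (borrowed from Bahdanau et al.'s scoring function, since the exact layer isn't spelled out here), a dot product with V, and a softmax. All sizes and the parameter names `W`, `b`, `V` and `attn_dim` are hypothetical and randomly initialized, so the resulting weights are meaningless except for their shape and normalization.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

rng = np.random.default_rng(0)
hidden_dim, token_dim, attn_dim, n = 8, 8, 16, 4  # made-up sizes

all_h = rng.normal(size=(n, hidden_dim))  # stand-in encoder hidden states h1..h4
token = rng.normal(size=token_dim)        # stand-in embedding of the token being decoded

# Hypothetical attention parameters: one fully connected layer (W, b) plus V.
W = rng.normal(size=(attn_dim, token_dim + hidden_dim))
b = np.zeros(attn_dim)
V = rng.normal(size=attn_dim)

def attention_weights(token, all_h):
    """Score every h against the current token, then normalize with softmax."""
    energies = []
    for h in all_h:  # one attention energy per h
        hidden = np.tanh(W @ np.concatenate([token, h]) + b)
        energies.append(V @ hidden)  # dot product with V gives a scalar
    return softmax(np.array(energies))

attn_weights = attention_weights(token, all_h)
print(attn_weights, attn_weights.sum())  # four positive weights summing to 1
```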
Now that we have the weights, we use them to pull out the h's that are most relevant to the token currently being decoded:
context_vector = attn_weights.dot(all_h) # this is now a vector which mixes a bit of all the h's
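The one-liner above assumes `attn_weights` is a NumPy vector and `all_h` is a matrix with one h per row. A self-contained version, using the example weights 0.3, 0.2, 0.4, 0.1 and four made-up 3-dimensional h's, looks like this:

```python
import numpy as np

# Four made-up encoder hidden states, one per row.
all_h = np.array([
    [1.0, 0.0, 0.0],  # h1
    [0.0, 1.0, 0.0],  # h2
    [0.0, 0.0, 1.0],  # h3
    [1.0, 1.0, 1.0],  # h4
])
attn_weights = np.array([0.3, 0.2, 0.4, 0.1])  # the example attention weights

# (4,) . (4, 3) -> (3,): computes 0.3*h1 + 0.2*h2 + 0.4*h3 + 0.1*h4
context_vector = attn_weights.dot(all_h)
print(context_vector)  # [0.4 0.3 0.5]
```

Because the weights sum to 1, the context vector stays on the same scale as the individual h's.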
Let’s break it down into steps:
- We encoded the full input sequence and generated a list of h’s.
- We started decoding with the decoder using greedy search.
- Instead of giving the decoder h4, we gave it a context vector.
- To generate the context vector, we used another network and learnable weights V to score how relevant each h was to the current token being decoded.
- We normalized those attention energies and used them to mix all the h's into a single vector that hopefully captures the relevant parts of every h, i.e., a context vector.
- Now we perform the decoding step again, but this time, using the context vector instead of h4.
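The steps above can be sketched as a greedy decoding loop that recomputes attention at every step. This is an untrained toy under stated assumptions: the decoder is a plain tanh RNN cell (not an LSTM or GRU), the context vector is fed in next to the token embedding and previous state, and every size, the special token ids `BOS`/`EOS`, and all parameter names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, emb_dim, hid, attn_dim = 10, 6, 8, 12  # made-up sizes
BOS, EOS, MAX_LEN = 0, 1, 5                        # hypothetical special tokens

# Hypothetical, randomly initialized parameters (an untrained toy model).
E = rng.normal(size=(vocab_size, emb_dim))           # token embeddings
W_att = rng.normal(size=(attn_dim, emb_dim + hid))   # attention FC layer
V = rng.normal(size=attn_dim)                        # learnable scoring weights
W_dec = rng.normal(size=(hid, emb_dim + hid + hid))  # simple RNN decoder cell
W_out = rng.normal(size=(vocab_size, hid))           # projection to vocabulary

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def decode_step(token, s, all_h):
    """One decoding step: attend over all h's, then update the decoder state."""
    x = E[token]
    energies = np.array([V @ np.tanh(W_att @ np.concatenate([x, h])) for h in all_h])
    alphas = softmax(energies)  # normalized attention energies
    context = alphas @ all_h    # context vector used instead of only h4
    s = np.tanh(W_dec @ np.concatenate([x, s, context]))
    return W_out @ s, s, alphas

def greedy_decode(all_h):
    s, token, output = all_h[-1], BOS, []  # start from the last encoder state
    for _ in range(MAX_LEN):               # the decoder runs one step at a time
        logits, s, _ = decode_step(token, s, all_h)
        token = int(np.argmax(logits))     # greedy: keep only the best token
        if token == EOS:
            break
        output.append(token)
    return output

all_h = rng.normal(size=(4, hid))  # stand-in for encoder outputs h1..h4
print(greedy_decode(all_h))
```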
The attention weights tell us how important each h is. This means we can also visualize the weights at each decoding step. Here’s an example from the original attention paper:
In the first row, to translate "L'", the network put nearly all of its attention weight (alpha) on the word "The" and almost none on the rest.
To generate the word "économique", the network spread its weight across "European" and "Economic" and gave almost none to the rest. This shows attention's usefulness when the translation relationship is many-to-one or one-to-many.
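To see how such a picture is built, stack the alphas from every decoding step into a matrix with one row per generated word and one column per source word. The sketch below renders that matrix as a crude text heatmap. The weights are made up to mimic the pattern described above; they are NOT the values from the paper, and the `render` helper is hypothetical.

```python
import numpy as np

SHADES = " .:-=+*#%@"  # light-to-dark characters for a text "heatmap"

def render(alphas, source, target):
    """Return one text row per decoding step; darker cells mean more attention."""
    lines = [" " * 13 + " ".join(f"{w[:8]:>8}" for w in source)]
    for word, row in zip(target, alphas):
        cells = " ".join(SHADES[min(int(a * len(SHADES)), len(SHADES) - 1)] * 8 for a in row)
        lines.append(f"{word:>12} {cells}")
    return lines

source = ["The", "European", "Economic", "Area"]
target = ["L'", "zone", "économique", "européenne"]

# Made-up weights for illustration only (NOT the values from the paper).
# Row i holds the alphas used while generating target[i]; each row sums to 1.
alphas = np.array([
    [0.94, 0.02, 0.02, 0.02],
    [0.05, 0.05, 0.10, 0.80],
    [0.02, 0.45, 0.50, 0.03],
    [0.03, 0.85, 0.10, 0.02],
])

print("\n".join(render(alphas, source, target)))
for word, row in zip(target, alphas):
    print(word, "->", source[int(row.argmax())])  # most-attended source word
```

In practice you would plot the same matrix with a plotting library; the text version just keeps the example dependency-free beyond NumPy.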
Attention Can Get Complicated
Types of Attention
The attention described so far uses only the h's generated by the encoder. There's a lot of research on improving that process. For example:
- Use only some of the h’s, maybe the h’s around the time step you’re currently decoding (local attention).
- In addition to the encoder's h's, also use the hidden states generated by the decoder, which we were throwing away before.
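Local attention can be sketched by masking out every h outside a window before the softmax, so those positions get exactly zero weight. This sketch assumes the window is centered on a position you pass in (for example, the current decoding step); variants that learn to predict the center are not shown. The `local_attention_weights` helper and the energies are hypothetical.

```python
import numpy as np

def local_attention_weights(energies, center, window):
    """Softmax over only the h's within `window` steps of `center`; others get 0."""
    positions = np.arange(len(energies))
    mask = np.abs(positions - center) <= window
    masked = np.where(mask, energies, -np.inf)  # exclude h's outside the window
    e = np.exp(masked - masked[mask].max())     # exp(-inf) == 0
    return e / e.sum()

energies = np.array([2.0, 0.5, 1.0, 3.0, 0.1, 1.5])  # made-up attention energies
w = local_attention_weights(energies, center=3, window=1)
print(np.round(w, 3))  # only positions 2, 3 and 4 receive weight
```

Note that h1 has a high energy (2.0) but still gets zero weight, because it falls outside the window.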
How to calculate attention energies
Another research area deals with how to calculate the attention scores. Instead of a dot product with V, researchers have also tried:
- Scaled dot products.
- Cosine similarity between the decoder state s and each h, cos(s, h).
- Dropping the V matrix and applying the softmax directly to the output of the fully connected layer.
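The first two alternatives are easy to compare side by side with a plain dot product. In this sketch, s stands for the decoder state and h for one encoder state; both vectors are made up, and the scaling factor is the square root of the vector dimension, as in the Transformer paper.

```python
import numpy as np

def dot_score(s, h):
    return s @ h

def scaled_dot_score(s, h):
    return (s @ h) / np.sqrt(len(h))  # divide by sqrt(dim) to keep scores moderate

def cosine_score(s, h):
    return (s @ h) / (np.linalg.norm(s) * np.linalg.norm(h))  # ignores vector length

s = np.array([1.0, 2.0, 2.0, 0.0])  # made-up decoder state
h = np.array([2.0, 0.0, 1.0, 2.0])  # made-up encoder state
for fn in (dot_score, scaled_dot_score, cosine_score):
    print(fn.__name__, round(float(fn(s, h)), 4))
# dot_score 4.0
# scaled_dot_score 2.0
# cosine_score 0.4444
```

Whichever score you pick, the scores for all h's still go through a softmax to become attention weights.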
What to use when calculating attention energies
This final area of research looks at what exactly should be compared against the h vectors.
To build some intuition about what I mean, think about calculating attention like a key-value dictionary. The key is what you give the attention network to “look up” the most relevant context. The value is the most relevant context.
The method I described here uses only the current token and each h to compute the attention score; that is, the score for each h depends on just those two inputs.
But really, we could give it anything we think might help the attention network make the best decision. Maybe we also give it the previous context vector!
Or maybe we give it something different, such as a token that lets it know it's decoding Spanish.
The possibilities are endless!
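In code, "giving the attention network more information" usually just means concatenating extra vectors onto its input, which changes the input size of its fully connected layer. A minimal sketch, assuming the `attention_input` helper and the 4-dimensional stand-in vectors (including a hypothetical Spanish language embedding) are illustrative only:

```python
import numpy as np

def attention_input(token_emb, h, prev_context=None, lang_emb=None):
    """Build the vector the attention network scores (hypothetical helper).

    By default only the current token and h are used; optional extras
    such as the previous context vector or a language embedding are appended.
    """
    parts = [token_emb, h]
    if prev_context is not None:
        parts.append(prev_context)
    if lang_emb is not None:
        parts.append(lang_emb)
    return np.concatenate(parts)

rng = np.random.default_rng(0)
token_emb, h, prev_c, spanish = (rng.normal(size=4) for _ in range(4))
print(attention_input(token_emb, h).shape)                   # (8,)
print(attention_input(token_emb, h, prev_c).shape)           # (12,)
print(attention_input(token_emb, h, prev_c, spanish).shape)  # (16,)
```

Whatever you add, the attention layer's weight matrix must be sized for the combined input.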
Here are some tips to think about if you decide to implement your own.
- Use Facebook's implementation, which is already highly optimized.
OK, fine, that was a cop-out. Here are actual tips.
- Remember that a seq2seq model has two parts: an encoder RNN and a decoder RNN. These two are separate networks.
- The bulk of the work goes into building the decoder. The encoder side is simply a matter of running the encoder RNN over the full input sequence.
- Remember the decoder RNN operates one step at a time. This is key!
- Remember the decoder RNN operates one step at a time. Worth saying twice ;)
- You have two options for the decoding algorithm: greedy search or beam search. Greedy is easier to implement, but beam search will give you better results most of the time.
- Attention is optional! BUT… the impact is huge when you have it…
- Attention is a separate network… Think about the network as the dictionary, where the key is a collection of things you want the network to use in deciding how relevant each particular h is.
- Remember you are calculating attention for each h. That means you have a for loop for [h1, …, hn].
- You can make the attention network's embedding dim as large as you like, but a large dim WILL blow up your memory. Make sure to put it on a separate GPU or keep the dim small.
- A trick to get large models going is to put the encoder on one GPU, the decoder on a second GPU and the attention network on a third GPU. This way you keep the memory footprint on each GPU low.
- If you actually deploy this model, you’ll need to implement it batched. Everything I explained here was for batch size=1, but you can scale to bigger batches by changing to tensor products and being smart about your linear algebra. I explain this in detail here.
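As a sketch of what batching means for the attention network, the per-h for loop can be replaced with broadcasting and an `einsum`. This assumes the same single-layer scorer as before, with its weight matrix split into a query part `W_q` and an h part `W_h` (for a concatenated input, `W @ [q; h]` equals `W_q @ q + W_h @ h`). Padding masks for sequences of different lengths are left out, and all names and sizes are hypothetical.

```python
import numpy as np

def batched_attention(query, all_h, W_q, W_h, V):
    """Attention for a whole batch at once, without a Python loop over h's.

    query: (batch, q_dim)     e.g. current token embeddings
    all_h: (batch, n, h_dim)  encoder states for every sequence in the batch
    """
    # Broadcasting adds each query's projection to every h position.
    hidden = np.tanh((query @ W_q.T)[:, None, :] + all_h @ W_h.T)  # (batch, n, attn_dim)
    energies = hidden @ V                                         # (batch, n)
    energies = energies - energies.max(axis=1, keepdims=True)     # numerical stability
    alphas = np.exp(energies)
    alphas /= alphas.sum(axis=1, keepdims=True)                   # softmax over the h's
    context = np.einsum("bn,bnd->bd", alphas, all_h)              # (batch, h_dim)
    return context, alphas

rng = np.random.default_rng(0)
B, n, q_dim, h_dim, a_dim = 3, 5, 4, 6, 7  # made-up sizes
query = rng.normal(size=(B, q_dim))
all_h = rng.normal(size=(B, n, h_dim))
W_q = rng.normal(size=(a_dim, q_dim))
W_h = rng.normal(size=(a_dim, h_dim))
V = rng.normal(size=a_dim)
context, alphas = batched_attention(query, all_h, W_q, W_h, V)
print(context.shape, alphas.shape)  # (3, 6) (3, 5)
```

The result should match running the one-example-at-a-time loop on each batch element, which makes a handy unit test when you convert your own code.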
Again, most of the time, you should just use an open source implementation, but it’s a great learning experience to do your own!
Life After Attention
Turns out… the attention network by itself was shown to be really powerful.
So much so that researchers decided to get rid of the RNNs in the sequence-to-sequence approach. They instead created something called a Transformer model.
At a high level, a Transformer still has an encoder and a decoder, except that the layers are fully connected and look at the full input at once. As the input moves through the network, attention heads are applied to focus on what's important.
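For a feel of "looking at the full input at once", here is a sketch of a single head of scaled dot-product self-attention, softmax(QK^T / sqrt(d_k)) V, as defined in the Transformer paper (Vaswani et al., 2017). It is a simplification: multi-head attention, masking, positional encodings and the feed-forward sublayers are omitted, the value matrix is called `values` to avoid clashing with this article's V, and all weights are random stand-ins.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over a whole sequence.

    X: (seq_len, d_model). Every position attends to every position at once.
    """
    Q, K, values = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])         # (seq_len, seq_len)
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over each row
    return weights @ values, weights

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 8, 4  # made-up sizes
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape, weights.shape)  # (5, 4) (5, 5)
```

Unlike the RNN decoder, there is no step-by-step loop here: one matrix product scores every position against every other position.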
Bio: William Falcon is a startup founder, AI researcher and AI writer for Forbes. He is passionate about encouraging LatinX inclusion in STEM and using AI for social impact. His research interest is in the intersection of AI and neuroscience, with a focus on biologically inspired fundamental research in deep learning and reinforcement learning, with applications to social problems, neuroscience, and NLP.
Original. Reposted with permission.