Attention Craving RNNs: Building Up To Transformer Networks

RNNs let us model sequences in neural networks. While there are other ways of modeling sequences, RNNs are particularly useful. RNNs come in two popular flavors, LSTMs (Hochreiter et al, 1997) and GRUs (Cho et al, 2014).

By William Falcon, PhD Researcher, AI researcher and AI writer for Forbes


Adding attention to your neural networks is a bit like wanting to take an afternoon nap at work. You know it’s better for you, everyone wants to do it, but everyone’s too scared to.

My goal today is to assume nothing, explain the details with animations, and make math great again (MMGA? ugh…)

Here we’ll cover:

  1. Short RNN review.
  2. Short sequence to sequence model review.
  3. Attention in RNNs.
  4. Improvements to attention.
  5. Transformer network introduction.


Recurrent Neural Networks (RNN)


RNNs let us model sequences in neural networks. While there are other ways of modeling sequences, RNNs are particularly useful. RNNs come in two popular flavors, LSTMs (Hochreiter et al, 1997) and GRUs (Cho et al, 2014). For a deep tutorial, check out Chris Olah’s tutorial.

Let’s look at machine translation for a concrete example of an RNN.

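  1. Imagine we have an RNN with 56 hidden units.
    # RNNCell is a generic stand-in for your framework's cell class
    # input_dim matches the 10-dimensional embeddings used below
    rnn_cell = RNNCell(input_dim=10, output_dim=56)
  2. We have a word “NYU”, which is represented by the integer 11, meaning it’s the 12th word (counting from zero) in the vocab I created.
    # 'NYU' is the 12th word in my vocab (index 11, counting from 0)
    word = 'NYU'
    word = VOCAB[word]
    # 11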

Except we don’t feed an integer into the RNN, we use a higher dimensional representation which we currently obtain through embeddings. An embedding lets us map a sequence of discrete tokens into continuous space (Bengio et al, 2003).

embedding_layer = Embedding(vocab_size=120, embedding_dim=10)

# project our word to 10 dimensions
x = embedding_layer(word)
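If the embedding layer feels like magic, here is a minimal NumPy sketch of what it does under the hood: it is just a table lookup. The names `embedding_table` and `embed` are made up for illustration, and the table is randomly initialized here (in a real model those weights are learned).

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, embedding_dim = 120, 10

# one row per word in the vocab; random here, learned in a real model
embedding_table = rng.normal(size=(vocab_size, embedding_dim))

def embed(token_id):
    """Look up the dense vector for an integer token id."""
    return embedding_table[token_id]

x = embed(11)      # 'NYU' is index 11 in the toy vocab
print(x.shape)     # (10,)
```

So "embedding a word" amounts to picking row 11 out of a 120 x 10 matrix.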

An RNN cell takes in two inputs, a word x, and a hidden state from the previous time step h. At every time step, it outputs a new h.


RNN CELL: next_h = f(x, prev_h).

*Tip: For the first step, h is normally just zeros.

# 1 word, RNN has 56 hidden units
# note: np.zeros takes the shape as a single tuple
h_0 = np.zeros((1, 56))
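To make f(x, prev_h) less abstract, here is a rough sketch of one step of the simplest possible cell, a vanilla tanh RNN. This is an assumption for illustration only: LSTMs and GRUs use more elaborate gated versions of f, and the weight names `W_x`, `W_h` and `b` are mine, not from any framework.

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, hidden_dim = 10, 56

# randomly initialized parameters (learned during training in practice)
W_x = rng.normal(scale=0.1, size=(input_dim, hidden_dim))
W_h = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
b = np.zeros(hidden_dim)

def rnn_cell(x, prev_h):
    """One step of a vanilla RNN: next_h = tanh(x W_x + prev_h W_h + b)."""
    return np.tanh(x @ W_x + prev_h @ W_h + b)

x = rng.normal(size=(1, input_dim))   # one embedded word
h_0 = np.zeros((1, hidden_dim))       # initial hidden state
h_1 = rnn_cell(x, h_0)
print(h_1.shape)                      # (1, 56)
```

The new h has the same shape as the old one, which is what lets us feed it straight back in at the next time step.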

This is important: RNN cell is DIFFERENT from an RNN.

There’s a MAJOR point of confusion in RNN terminology. In deep learning frameworks like PyTorch and TensorFlow, the RNN CELL is the unit that performs this computation:

h1 = rnn_cell(x, h0)

The RNN NETWORK, on the other hand, uses a for loop to run the cell over the time steps:

def RNN(sentence):
  prev_h = h_0

  all_h = []
  for word in sentence:
    # use the RNN CELL at each time step
    current_h = rnn_cell(embed(word), prev_h)

    # keep this step's h and feed it into the next step
    all_h.append(current_h)
    prev_h = current_h

  # RNNs output a hidden vector h at each time step
  return all_h
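If you want to actually run the loop, here is a self-contained sketch that glues the pieces together. It assumes a 4-word toy vocab, random (untrained) weights, and a vanilla tanh cell, so the numbers are meaningless; the point is only the shapes and the flow of h from step to step.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = {"NYU": 0, "NLP": 1, "rocks": 2, "!": 3}
embedding_dim, hidden_dim = 10, 56

# toy, untrained parameters
embedding_table = rng.normal(size=(len(vocab), embedding_dim))
W_x = rng.normal(scale=0.1, size=(embedding_dim, hidden_dim))
W_h = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))

def rnn_cell(x, prev_h):
    # the CELL: one time step
    return np.tanh(x @ W_x + prev_h @ W_h)

def rnn(sentence):
    # the NETWORK: the same cell applied at every time step
    prev_h = np.zeros(hidden_dim)
    all_h = []
    for word in sentence:
        x = embedding_table[vocab[word]]
        prev_h = rnn_cell(x, prev_h)
        all_h.append(prev_h)
    return all_h

all_h = rnn(["NYU", "NLP", "rocks", "!"])
print(len(all_h), all_h[0].shape)   # 4 (56,)
```

Four words in, four hidden vectors out, each with 56 entries.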

Here’s an illustration of an RNN moving the same RNN cell over time:


The RNN moves the RNN cell over time. For attention, we’ll use ALL the h’s produced at each timestep.


Sequence To Sequence Models (Seq2Seq)


Now you’re a pro at RNNs, but let’s take it easy for a minute.



RNNs can be used as building blocks in larger deep learning systems.

One such system is the Seq2Seq model introduced by Bengio’s group (Cho et al, 2014) and Google (Sutskever et al, 2014), which can be used to translate one sequence into another. You can frame a lot of problems as translation:

  1. Translate English to Spanish.
  2. Translate a video sequence into another sequence.
  3. Translate a sequence of instructions into programming code.
  4. Translate user behavior into future user behavior.
  5. The only limit is your creativity!

A seq2seq model is nothing more than 2 RNNs, an encoder (E), and decoder (D).

class Seq2Seq(object):

  def __init__(self):
      self.encoder = RNN(...)
      self.decoder = RNN(...)

The seq2seq model has 2 major steps:

Step 1: Encode a sequence:

sentence = ["NYU", "NLP", "rocks", "!"]
model = Seq2Seq()
all_h = model.encoder(sentence)

# all_h now has 4 h (activations), one per token



Step 2: Decode to generate a “translation.”

This part gets really involved. The encoder in the previous step processed the full sequence at once (i.e., it was a vanilla RNN).

In this second step, we run the decoder RNN one step at a time to generate predictions autoregressively (this is fancy for using the output of the previous step as the input to the next step).

There are two major ways of doing the decoding:

Option 1: Greedy Decoding

  1. Run 1 step of the decoder.
  2. Pick the highest probability output.
  3. Use this output as the input to the next step.

# you have to seed the first x since there are no predictions yet
# SOS means start of sentence
current_X_token = 'SOS'

# we also use the last hidden output of the encoder (or set to zero)
h_option_1 = all_h[-1]
h_option_2 = zeros(...)

# let's use option 1 where it's the last h produced by the encoder
dec_h = h_option_1

# run greedy search until the RNN generates an End-of-Sentence token
while current_X_token != 'EOS':

   # keep the output h for next step
   next_h = decoder(dec_h, current_X_token)

   # use new h to find most probable next word using classifier
   # (argmax returns the word's index; max would only return its probability)
   next_token = argmax(softmax(fully_connected_layer(next_h)))

   # *KEY* prepare for next pass by updating pointers
   current_X_token = next_token
   dec_h = next_h

It’s called greedy because we always go with the highest probability next word.
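Here is a runnable sketch of that greedy loop, under some loud assumptions: a 6-word toy vocab, random untrained weights, a vanilla tanh decoder cell, and a zero vector standing in for the encoder's last h. Because an untrained model may never emit EOS, I also cap the output length with a hypothetical `max_len` parameter.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["SOS", "EOS", "NYU", "NLP", "rocks", "!"]
hidden_dim, embedding_dim = 8, 4

# toy, untrained parameters
embedding_table = rng.normal(size=(len(vocab), embedding_dim))
W_x = rng.normal(scale=0.5, size=(embedding_dim, hidden_dim))
W_h = rng.normal(scale=0.5, size=(hidden_dim, hidden_dim))
W_out = rng.normal(scale=0.5, size=(hidden_dim, len(vocab)))  # fully connected layer

def decoder_cell(token_id, prev_h):
    return np.tanh(embedding_table[token_id] @ W_x + prev_h @ W_h)

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

def greedy_decode(encoder_last_h, max_len=10):
    token_id = vocab.index("SOS")   # seed the first input
    dec_h = encoder_last_h          # start from the encoder's last h
    output = []
    for _ in range(max_len):        # cap length: the model may never emit EOS
        dec_h = decoder_cell(token_id, dec_h)
        probs = softmax(dec_h @ W_out)
        token_id = int(np.argmax(probs))   # greedy: take the single best word
        if vocab[token_id] == "EOS":
            break
        output.append(vocab[token_id])
    return output

translation = greedy_decode(np.zeros(hidden_dim))
print(translation)
```

The output is gibberish since nothing is trained, but the control flow is the same as in a real system: predict, pick the argmax, feed it back in.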

Option 2: Beam Search

There’s a better technique called Beam Search, which considers multiple paths through the decoding process. Colloquially, a beam search of width 5 means we consider 5 possible sequences with the maximum log likelihood (math talk for 5 most probable sequences).

At a high level, instead of taking the single highest probability prediction, we keep the top k (beam size = k). Notice below, at each step we have 5 options (the 5 with the highest probability).


Beam search figure found here

This YouTube video has a detailed beam search tutorial!
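For a feel of the mechanics, here is a small beam search sketch. To keep it deterministic, I replace the decoder with a hypothetical lookup table `BIGRAM` that gives the next-word probabilities based only on the previous word; a real model would run the decoder RNN and softmax inside `next_log_probs` instead.

```python
import math

# Hypothetical stand-in for one decoder step: next-word probabilities
# given only the previous word.
BIGRAM = {
    "SOS": {"a": 0.6, "b": 0.4},
    "a":   {"a": 0.1, "b": 0.4, "EOS": 0.5},
    "b":   {"a": 0.05, "b": 0.05, "EOS": 0.9},
}

def next_log_probs(prefix):
    return {tok: math.log(p) for tok, p in BIGRAM[prefix[-1]].items()}

def beam_search(k=2, max_len=5):
    beams = [(0.0, ["SOS"])]   # (total log-probability, tokens so far)
    finished = []
    for _ in range(max_len):
        # extend every live beam by every possible next word
        candidates = []
        for score, tokens in beams:
            for tok, lp in next_log_probs(tokens).items():
                candidates.append((score + lp, tokens + [tok]))
        # keep only the k most probable sequences
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, tokens in candidates[:k]:
            if tokens[-1] == "EOS":
                finished.append((score, tokens))
            else:
                beams.append((score, tokens))
        if not beams:
            break
    return max(finished + beams, key=lambda c: c[0])

best_score, best_tokens = beam_search(k=2)
print(best_tokens, round(math.exp(best_score), 2))
```

In this toy table, greedy decoding would pick "a" first (0.6) and then EOS (0.5), for a total probability of 0.30. A beam of width 2 also keeps "b" alive and ends up with "b EOS" at 0.4 x 0.9 = 0.36, which is why keeping more than one path can pay off.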

So, the full seq2seq process with greedy decoding as an animation to translate “NYU NLP is awesome” into Spanish looks like this:


Seq2Seq is made up of 2 RNNs: an encoder and a decoder.

This model has various parts:

  1. Blue RNN is the encoder.
  2. Red RNN is the decoder.
  3. The blue rectangle on top of the decoder is a fully connected layer with a softmax. This picks the most likely next word.