Deep Learning For Chatbots, Part 2 – Implementing A Retrieval-Based Model In TensorFlow
Check out part 2 of this tutorial on building chatbots with deep neural networks. This part gets practical, using Python and TensorFlow to implement a retrieval-based model.
Boilerplate Training Code
Before writing the actual neural network code, I like to write the boilerplate code for training and evaluating the model. That's because, as long as you adhere to the right interfaces, it's easy to swap out what kind of network you are using. Let's assume we have a model function model_fn that takes as inputs our batched features, labels, and mode (train or evaluation) and returns the predictions. Then we can write general-purpose code to train our model as follows:
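The actual script uses TensorFlow's Estimator machinery, which is not reproduced here. As a hedged, framework-free sketch of the interface the text describes (all names are illustrative), the train/evaluate loop looks roughly like this:

```python
# Sketch only: mirrors the model_fn / input function / eval_every interface
# described above, without the TensorFlow Estimator plumbing.

def train_loop(model_fn, train_input_fn, eval_input_fn, eval_every, max_steps):
    """Train for max_steps, evaluating every eval_every steps."""
    eval_results = []
    for step in range(1, max_steps + 1):
        features, labels = train_input_fn()
        model_fn(features, labels, mode="train")  # one training step
        if step % eval_every == 0:
            eval_features, eval_labels = eval_input_fn()
            preds = model_fn(eval_features, eval_labels, mode="eval")
            eval_results.append((step, preds))
    return eval_results
```

The point of the abstraction is that nothing in this loop depends on what kind of network model_fn wraps.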
Here we create an estimator for our model_fn, two input functions for training and evaluation data, and our evaluation metrics dictionary. We also define a monitor that evaluates our model every FLAGS.eval_every steps during training. Finally, we train the model. The training runs indefinitely, but TensorFlow automatically saves checkpoint files in MODEL_DIR, so you can stop the training at any time. A fancier technique would be early stopping, which means you automatically stop training when a validation set metric stops improving (i.e. when you are starting to overfit). You can see the full code in udc_train.py.
Two things I want to mention briefly are the usage of FLAGS and hparams. FLAGS is a way to give command line parameters to the program (similar to Python's argparse). hparams is a custom object we create in hparams.py that holds the hyperparameters, knobs we can tweak, of our model. This hparams object is given to the model when we instantiate it.
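For readers unfamiliar with the pattern, here is what the equivalent looks like with Python's standard argparse (the flag names below are examples, not the tutorial's exact set):

```python
import argparse

# Example flags analogous to those the tutorial defines via FLAGS.
parser = argparse.ArgumentParser()
parser.add_argument("--eval_every", type=int, default=2000,
                    help="Evaluate on the validation set every N training steps.")
parser.add_argument("--model_dir", type=str, default="./runs",
                    help="Directory for checkpoints and summaries.")

# An empty list here just demonstrates the defaults; a real script
# would call parser.parse_args() to read sys.argv.
FLAGS = parser.parse_args([])
```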
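A hyperparameter container like this can be as simple as a namedtuple. A hypothetical sketch of the kind of thing hparams.py holds (field names and default values here are illustrative, not the repo's actual ones):

```python
from collections import namedtuple

# Illustrative only: a minimal hyperparameter container.
HParams = namedtuple("HParams", [
    "learning_rate",   # optimizer step size
    "batch_size",      # examples per training batch
    "embedding_size",  # word embedding dimensionality
    "rnn_dim",         # LSTM cell state size
])

def create_hparams():
    """Build an HParams instance with example defaults."""
    return HParams(learning_rate=0.001, batch_size=128,
                   embedding_size=100, rnn_dim=256)
```

Because the object is immutable and named, every model function can read hyperparameters without guessing positional arguments.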
Creating The Model
Now that we have set up the boilerplate code around inputs, parsing, evaluation, and training, it's time to write code for our Dual LSTM neural network. Because we have different formats of training and evaluation data, I've written a create_model_fn wrapper that takes care of bringing the data into the right format for us. It takes a model_impl argument, which is a function that actually makes predictions. In our case it's the Dual Encoder LSTM we described above, but we could easily swap it out for some other neural network. Let's see what that looks like:
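The real wrapper lives in the repository and operates on TensorFlow tensors; the sketch below only illustrates the pattern in plain Python (dispatch on mode, delegate the actual scoring to model_impl), with hypothetical feature keys:

```python
def create_model_fn(hparams, model_impl):
    """Wrap model_impl so the training loop sees a uniform interface.

    Sketch only: model_impl scores one (context, utterance) pair;
    during evaluation the wrapper scores the context against each
    candidate utterance. The real code reshapes tensors instead.
    """
    def model_fn(features, labels, mode):
        if mode == "train":
            # Training data: one (context, utterance) pair per example.
            return model_impl(hparams, features["context"], features["utterance"])
        else:
            # Evaluation data: one context vs. several candidate
            # utterances; return a score for each candidate.
            return [model_impl(hparams, features["context"], cand)
                    for cand in features["candidates"]]
    return model_fn
```

Swapping in a different network means passing a different model_impl; the wrapper and the training boilerplate stay unchanged.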
That's it! We can now run python udc_train.py and it should start training our network, occasionally evaluating recall on our validation data (you can choose how often to evaluate using the --eval_every switch). To get a complete list of all available command line flags that we defined using hparams, you can run python udc_train.py --help.
Evaluating The Model
After you've trained the model, you can evaluate it on the test set using python udc_test.py --model_dir=$MODEL_DIR_FROM_TRAINING, e.g. python udc_test.py --model_dir=~/github/chatbot-retrieval/runs/1467389151. This will run the recall@k evaluation metrics on the test set instead of the validation set. Note that you must call udc_test.py with the same parameters you used during training. So, if you trained with --embedding_size=128, you need to call the test script with the same value.
After training for about 20,000 steps (around an hour on a fast GPU) our model gets the following results on the test set:
While recall@1 is close to our TFIDF model, recall@2 and recall@5 are significantly better, suggesting that our neural network assigns higher scores to the correct answers. The original paper reported higher numbers (up to 0.92 for recall@5), but I haven't been able to reproduce scores quite as high. Perhaps additional data preprocessing or hyperparameter optimization might bump scores up a bit more.
You can modify and run udc_predict.py to get probability scores for unseen data. For example, python udc_predict.py --model_dir=./runs/1467576365/ outputs:
You could imagine feeding in 100 potential responses to a context and then picking the one with the highest score.
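Picking the best of N candidates is then a one-liner; a hypothetical sketch, where score_fn stands in for the trained dual encoder's scoring function:

```python
def best_response(context, candidates, score_fn):
    """Return the candidate the model scores highest for this context.

    score_fn(context, candidate) is a stand-in for the trained
    model's probability output (hypothetical interface).
    """
    return max(candidates, key=lambda cand: score_fn(context, cand))
```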
In this post we've implemented a retrieval-based neural network model that can assign scores to potential responses given a conversation context. There is still a lot of room for improvement, however. One can imagine that other neural networks do better on this task than a dual LSTM encoder, and there is also a lot of room for hyperparameter optimization or improvements to the preprocessing step. The code and data for this tutorial are on Github, so check them out.
Original. Reposted with permission.