Lit BERT: NLP Transfer Learning In 3 Steps
PyTorch Lightning is a lightweight framework which allows anyone using PyTorch to scale deep learning code easily while making it reproducible. In this tutorial we’ll use Huggingface's implementation of BERT to do a finetuning task in Lightning.
By William Falcon, AI Researcher
BERT (Devlin et al., 2018) is perhaps the most popular NLP approach to transfer learning. Huggingface's implementation offers a lot of nice features and hides the details behind a clean API.
PyTorch Lightning is a lightweight framework (really more like refactoring your PyTorch code) which allows anyone using PyTorch, such as students, researchers and production teams, to scale deep learning code easily while making it reproducible. It also provides 42+ advanced research features via trainer flags.
Lightning does not add abstractions on top of PyTorch, which means it plays nicely with other great packages like Huggingface! In this tutorial we’ll use their implementation of BERT to do a finetuning task in Lightning.
In this tutorial we’ll do transfer learning for NLP in 3 steps:
- We’ll import BERT from the Huggingface library.
- We’ll create a LightningModule which finetunes using features extracted by BERT.
- We’ll train the BertMNLIFinetuner using the Lightning Trainer.
If you’d rather see this in actual code, copy this Colab notebook!
Finetuning (aka transfer learning)
If you’re a researcher trying to improve on the NYU GLUE benchmark, or a data scientist trying to understand product reviews to recommend new content, you’re looking for a way to extract a representation of a piece of text so you can solve a different task.
For transfer learning you generally have two steps. First, you use dataset X to pretrain your model. Then you use that pretrained model to carry that knowledge into solving dataset Y. In this case, BERT has been pretrained on BookCorpus and English Wikipedia. The downstream task is the one you actually care about, such as solving a GLUE task or classifying product reviews.
The benefit of pretraining is that the downstream task needs far less labeled data to reach strong results than training from scratch would.
Finetuning with PyTorch Lightning
In general, we can finetune with PyTorch Lightning using the following abstract approach:
For transfer learning we define two core parts inside the LightningModule.
- The pretrained model (i.e., the feature extractor)
- The finetune model
You can think of the pretrained model as a feature extractor. It lets you represent objects or inputs (a sentence, a document) far more richly than, say, a boolean flag or a hand-built tabular mapping.
For instance, if you have a collection of documents, you could run each through the pretrained model and use the output vectors to compare documents to each other.
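A sketch of that document comparison, assuming the documents are already tokenized into id tensors. To run offline it builds a tiny, randomly initialized `BertModel` from a `BertConfig`, so the similarity values themselves are meaningless; with real data you would load pretrained weights and the matching tokenizer. Mean-pooling the token vectors is one common way to get a single document vector, not the only one.

```python
import torch
import torch.nn.functional as F
from transformers import BertConfig, BertModel

# Assumption: a tiny random BERT keeps this sketch offline and fast. For real
# features, use BertModel.from_pretrained("bert-base-uncased") and its tokenizer.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
bert = BertModel(config).eval()  # eval() disables dropout, so features are deterministic


@torch.no_grad()
def embed(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool BERT's token vectors into one vector per document."""
    hidden = bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    mask = attention_mask.unsqueeze(-1).float()       # (batch, seq, 1)
    summed = (hidden * mask).sum(dim=1)               # padding positions contribute zero
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts                            # (batch, hidden_size)


def similarity_matrix(doc_vectors: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between document vectors."""
    normed = F.normalize(doc_vectors, dim=-1)
    return normed @ normed.T


# Three hypothetical documents, already converted to token ids (0 = padding).
ids = torch.tensor([[5, 17, 42, 8, 0, 0],
                    [5, 17, 42, 9, 3, 0],
                    [77, 61, 23, 4, 90, 31]])
mask = (ids != 0).long()
sims = similarity_matrix(embed(ids, mask))
print(sims)
```

Row `i`, column `j` of `sims` scores how close documents `i` and `j` are; with pretrained weights, the highest off-diagonal entries point to the most similar documents.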
The finetune model can be arbitrarily complex. It could be a deep network, or it could be a simple linear model or an SVM.
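A sketch of finetune models of different complexity over the same feature vectors. 768 is the hidden size of BERT-base; the MLP layer sizes are arbitrary choices. Training the linear head with `nn.MultiMarginLoss`, a multi-class hinge loss, gives an SVM-style classifier instead of a softmax one.

```python
import torch
import torch.nn as nn

FEATURE_DIM, NUM_CLASSES = 768, 3  # 768 = hidden size of BERT-base

# Option 1: a single linear layer on top of the extracted features.
linear_head = nn.Linear(FEATURE_DIM, NUM_CLASSES)

# Option 2: a deeper MLP head (layer sizes are illustrative).
mlp_head = nn.Sequential(
    nn.Linear(FEATURE_DIM, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, NUM_CLASSES),
)

features = torch.randn(8, FEATURE_DIM)          # stand-in for a batch of BERT features
labels = torch.randint(0, NUM_CLASSES, (8,))

for head in (linear_head, mlp_head):
    print(type(head).__name__, tuple(head(features).shape))

# Option 3: the linear head trained with a multi-class hinge loss (SVM-style).
svm_style_loss = nn.MultiMarginLoss()(linear_head(features), labels)
```

All three share the same interface (features in, one score per class out), so swapping heads does not touch the feature extractor.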
Finetuning with BERT
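Here we’ll use a pretrained BERT to finetune on a task called MNLI (Multi-Genre Natural Language Inference). Each example is a premise/hypothesis sentence pair, and the task is to classify the pair into one of three categories: entailment, neutral, or contradiction. Here’s the LightningModule: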
In this case we’re using the pretrained BERT from the Huggingface library and adding our own simple linear classifier to classify a given text input into one of three classes.
However, we still need to define the validation loop, which calculates our validation accuracy, and the test loop, which calculates our test accuracy.
Finally, we define the optimizer and the dataset we’ll operate on. This should be the downstream dataset you’re trying to solve.
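For MNLI, preparing the dataset means encoding each premise/hypothesis pair as a single BERT input (`[CLS] premise [SEP] hypothesis [SEP]`), with `token_type_ids` marking which segment each token belongs to. The sketch assumes a toy vocabulary written to a temporary file so it runs offline, and a hypothetical label mapping; `build_mnli_loader` is an illustrative name.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer

# Assumption: a toy vocabulary keeps this offline. In practice:
#   tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]",
         "a", "man", "is", "playing", "guitar", "nobody", "music", "the", "sleeping"]
vocab_path = os.path.join(tempfile.mkdtemp(), "vocab.txt")
with open(vocab_path, "w") as f:
    f.write("\n".join(vocab) + "\n")
tokenizer = BertTokenizer(vocab_file=vocab_path)

LABELS = {"entailment": 0, "neutral": 1, "contradiction": 2}  # hypothetical label ids


def build_mnli_loader(premises, hypotheses, labels, batch_size=32, max_length=16, shuffle=True):
    """Encode premise/hypothesis pairs as BERT sentence pairs and wrap them in a DataLoader."""
    enc = tokenizer(premises, hypotheses, padding="max_length", truncation=True,
                    max_length=max_length, return_tensors="pt")
    y = torch.tensor([LABELS[label] for label in labels])
    dataset = TensorDataset(enc["input_ids"], enc["attention_mask"], enc["token_type_ids"], y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


loader = build_mnli_loader(
    ["a man is playing guitar", "a man is playing guitar"],
    ["a man is playing music", "nobody is playing"],
    ["entailment", "contradiction"],
    batch_size=2, shuffle=False,
)
input_ids, attention_mask, token_type_ids, y = next(iter(loader))
```

Inside a LightningModule, `train_dataloader` can return such a loader, and `configure_optimizers` can return, for example, `torch.optim.AdamW([p for p in self.parameters() if p.requires_grad], lr=2e-5)`. The 2e-5 is one of the finetuning learning rates suggested in the BERT paper, not a value taken from this article.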
The full LightningModule looks like this.
Here we learned to use the Huggingface BERT as a feature extractor inside a LightningModule. This approach means you can leverage a really strong text representation to do things like:
- Sentiment analysis
- Suggested replies to chatbots
- Build recommendation engines using NLP
- Improve the Google Search algorithm
- Create embeddings for documents for similarity search
- Anything you can creatively think about!
Bio: William Falcon is an AI Researcher, startup founder, CTO, Google DeepMind Fellow, and current PhD AI research intern at Facebook AI.
Original. Reposted with permission.