Building Multimodal Models: Using the widedeep PyTorch package
This article gets you started on the open-source widedeep PyTorch framework developed by Javier Rodriguez Zaurin.
By Rajiv Shah, Data Scientist at Snorkel.ai
Image by the Author
Are you faced with modeling on larger datasets, multimodal data, or complex targets such as multiclass, multi-target, or multitask problems? A flexible, modular deep learning architecture can be well suited to those problems. pytorch-widedeep is an open-source deep learning package built for multimodal problems.
Widedeep was developed by Javier Rodriguez Zaurin and is a popular PyTorch package with over 600 GitHub stars. It is built to be easy to use, has a modular architecture, and has been continually updated to include the latest models like SAINT, Perceiver, and FastFormer. I found this package when I was looking into explainability for multimodal deep learning approaches. I wrote this post to introduce the package and help more data scientists become familiar with widedeep.
Want to jump into the code? Grab the companion notebooks and start using them!
This article covers the following topics:
- Is PyTorch WideDeep right for you?
- Preprocessing your data
- Defining a model
- Training time
- Explaining your model
- Getting predictions and saving your model
Is PyTorch WideDeep right for you?
PyTorch-widedeep is built for when you have multimodal data (wide) and want to use deep learning to find complex relationships in your data (deep). For example, you might predict the value of a house from images of the house, tabular data (e.g., number of rooms, floor area), and text data (e.g., a detailed description). With widedeep you can bring all of those disparate types of data into one deep learning model.
Generally, you shouldn't start modeling by jumping into deep learning. If you are just getting started, you are better off using an approach like gradient boosted machines, which typically perform better on tabular data. See Szilard for an overview of approaches on tabular data, or Javier's post for a comparison of GBMs versus deep learning using widedeep. Nevertheless, teams at companies like Pinterest and Lyft are moving to deep learning models for some applications.
This post focuses on widedeep, which is my favorite open-source package for building multimodal models. But the PyTorch ecosystem has lots of other great packages I would recommend, including fast.ai, pytorch-tabular, and pytorch-forecasting.
Preprocessing your data
Widedeep has a couple of different types of preprocessors. You tell a preprocessor which columns are categorical and which are continuous, and widedeep handles the preprocessing. There are options for scaling, setting the size of embeddings for categorical features, using a CLS token with transformers, and the rest of the functionality you would expect from a preprocessor.
There are sensible default settings, so as a starting point, you can build a preprocessor using the defaults with the following code:
from pytorch_widedeep.preprocessing import TabPreprocessor
tab_preprocessor = TabPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)  # lists of column names in df
X_tab = tab_preprocessor.fit_transform(df)
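To make the defaults less of a black box, here is a minimal pure-Python sketch of the kind of work a tabular preprocessor does: label-encode categorical columns (reserving index 0 for unseen values), standardize continuous columns, and choose an embedding size per categorical column. This is not widedeep's implementation; the toy rows, helper names, and the fastai-style embedding-size rule of thumb are assumptions for illustration.

```python
import statistics

def fit_label_encoder(values):
    # Map each distinct category to an integer index; 0 is reserved for unseen values.
    return {cat: i + 1 for i, cat in enumerate(sorted(set(values)))}

def embedding_size(n_categories):
    # Assumed fastai-style rule of thumb; widedeep's own default may differ.
    return int(min(600, round(1.6 * n_categories ** 0.56)))

def fit_scaler(values):
    # Store mean and population std so new data is scaled the same way.
    return statistics.mean(values), statistics.pstdev(values) or 1.0

# Hypothetical toy rows.
rows = [
    {"city": "Chicago", "rooms": 3, "area": 120.0},
    {"city": "Austin", "rooms": 2, "area": 80.0},
    {"city": "Chicago", "rooms": 4, "area": 150.0},
]

encoder = fit_label_encoder([r["city"] for r in rows])
scalers = {col: fit_scaler([r[col] for r in rows]) for col in ("rooms", "area")}

def transform(row):
    # Categorical column -> integer index; continuous columns -> z-scores.
    encoded = [encoder.get(row["city"], 0)]
    for col, (mean, std) in scalers.items():
        encoded.append((row[col] - mean) / std)
    return encoded

X = [transform(r) for r in rows]
emb_dim = embedding_size(len(encoder) + 1)  # +1 for the unseen-value slot
```

The integer indices are what an embedding layer later looks up, which is why the preprocessor also needs to decide the embedding size.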
Defining a model
Widedeep provides a lot of flexibility for specifying a model. There are four main components that make up a Wide and Deep model: wide, deeptabular, deeptext, and deepimage.
Widedeep offers models for each of those components. For example, for deepimage there are pre-trained ResNet models available. For deeptabular there is a large set of options, including transformer-based models such as SAINT, TabPerceiver, and TabFastFormer. To dig deeper, check out the notebooks that focus on deeptabular models or on the transformer models.
In this code snippet, I am using the FastFormer model. I can set various modeling parameters, including the inputs, dropout, activation, whether to use batchnorm, and many others. This provides an easy way to specify and review your model definition.
If I were bringing in images or text, I would pass them as additional inputs through the deepimage or deeptext arguments of WideDeep. WideDeep is the model component that ties all the other models together for tackling multimodal problems.
from pytorch_widedeep.models import TabFastFormer, WideDeep
tabfastformer = TabFastFormer(column_idx=tab_preprocessor.column_idx,
                              continuous_cols=tab_preprocessor.continuous_cols,
                              embed_input=tab_preprocessor.embeddings_input,
                              n_blocks=2, n_heads=4)
model = WideDeep(deeptabular=tabfastformer)
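Conceptually, WideDeep lets each component map its own input to the prediction space and then combines the results into one prediction. Here is a minimal sketch of that late-fusion idea, assuming a single regression output where the component outputs are simply added. The component functions and their weights are hypothetical stand-ins, and widedeep's actual combination (including its optional deephead) may differ.

```python
def wide_component(x_wide):
    # Hypothetical linear model over one-hot / crossed features.
    weights = [0.5, -0.2]
    return sum(w * x for w, x in zip(weights, x_wide))

def deeptabular_component(x_tab):
    # Hypothetical stand-in for a deep tabular network's final output.
    return 0.3 * sum(x_tab)

def deeptext_component(x_text):
    # Hypothetical stand-in for a text model's output.
    return 0.1 * len(x_text.split())

def wide_deep_forward(x_wide, x_tab, x_text=None):
    # Late fusion: each component produces an output in the same space,
    # and the outputs are added. A missing modality is simply skipped.
    out = wide_component(x_wide) + deeptabular_component(x_tab)
    if x_text is not None:
        out += deeptext_component(x_text)
    return out

pred_tab_only = wide_deep_forward([1, 0], [1.0, 2.0])
pred_multimodal = wide_deep_forward([1, 0], [1.0, 2.0], "sunny three bedroom house")
```

Because the components are independent until their outputs are combined, adding a modality only means adding another term, which is what makes this kind of architecture modular.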
Training time
Let's train the model. Like other PyTorch libraries, widedeep uses a trainer to automate the model-building process. This is where you specify the loss function (including a custom loss), optimizer, learning rate scheduler, metrics, and much more.
After specifying the trainer, you have the familiar fit and predict methods you would expect. Again, the abstraction and automation keep the code very clean. For someone who started with deep learning many years ago, this code is much easier to work with.
from pytorch_widedeep import Trainer
fasttab_model = Trainer(model, objective="rmse")
fasttab_model.fit(X_tab=X_tab_train, target=y_train, n_epochs=50, batch_size=256, val_split=0.2)
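To see what a trainer saves you from writing, here is a minimal pure-Python sketch of the kind of loop it automates: shuffle the data, hold out a validation fraction (the analogue of val_split), run mini-batch gradient descent on a mean squared error loss for a number of epochs, and report RMSE. The one-feature linear model, synthetic data, and learning rate are assumptions for illustration, not what widedeep does internally.

```python
import math
import random

random.seed(0)
# Hypothetical synthetic data: y = 2x + 1 with a little noise.
xs = [i / 100 for i in range(200)]
ys = [2 * x + 1 + random.gauss(0, 0.05) for x in xs]

# Hold out 20% of the shuffled rows for validation (analogous to val_split=0.2).
data = list(zip(xs, ys))
random.shuffle(data)
split = int(len(data) * 0.8)
train, val = data[:split], data[split:]

def rmse(w, b, rows):
    return math.sqrt(sum((w * x + b - y) ** 2 for x, y in rows) / len(rows))

w, b, lr, batch_size = 0.0, 0.0, 0.1, 32
for epoch in range(50):  # analogous to n_epochs=50
    random.shuffle(train)
    for start in range(0, len(train), batch_size):
        batch = train[start:start + batch_size]
        # Gradients of the batch mean squared error with respect to w and b.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / len(batch)
        grad_b = sum(2 * (w * x + b - y) for x, y in batch) / len(batch)
        w -= lr * grad_w
        b -= lr * grad_b

val_rmse = rmse(w, b, val)
```

In a real model the gradients come from autograd and the update from an optimizer object, but the structure of the loop is the same, which is why a trainer can hide it behind a single fit call.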
Explaining your model
After the model is built, my next step is to use explainability tools to understand how the model works. WideDeep provides hooks to get attention weights, and it's also easy to use well-known explainability libraries like Captum and SHAP.
Let's start with permutation-based feature importance. This method measures how much each feature affects the model's predictions overall, which is known as global feature importance.
from captum.attr import FeaturePermutation
feature_perm = FeaturePermutation(model.deeptabular.eval())
attr_fic = feature_perm.attribute(X_tab_test, target=0)  # X_tab_test: torch tensor of preprocessed rows
attr_fic
These results can easily be visualized:
Image by the Author
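Permutation importance is simple enough to write by hand, which helps when interpreting library output. This sketch shuffles one feature column at a time and measures how much the mean squared error increases; the toy model and data are hypothetical. Note that Captum's FeaturePermutation differs in detail: it permutes features within the batch passed to attribute() and reports the change in model output per example rather than the change in an error metric.

```python
import random

def model(row):
    # Hypothetical trained model: strong dependence on feature 0, weak on 1, none on 2.
    return 3.0 * row[0] + 0.5 * row[1]

def mse(rows, targets):
    return sum((model(r) - t) ** 2 for r, t in zip(rows, targets)) / len(rows)

def permutation_importance(rows, targets, n_repeats=10, seed=0):
    rng = random.Random(seed)
    baseline = mse(rows, targets)
    importances = []
    for j in range(len(rows[0])):
        increases = []
        for _ in range(n_repeats):
            column = [r[j] for r in rows]
            rng.shuffle(column)
            # Replace column j with its shuffled version, keeping other columns intact.
            permuted = [r[:j] + [v] + r[j + 1:] for r, v in zip(rows, column)]
            increases.append(mse(permuted, targets) - baseline)
        importances.append(sum(increases) / n_repeats)
    return importances

data_rng = random.Random(42)
X = [[data_rng.random() for _ in range(3)] for _ in range(200)]
y = [model(r) for r in X]
importances = permutation_importance(X, y)
```

Shuffling breaks the link between a feature and the target while keeping the feature's distribution, so a feature the model ignores scores exactly zero.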
To get Shapley value explanations via sampling with the Captum library:
from captum.attr import ShapleyValueSampling
shapv = ShapleyValueSampling(model.deeptabular.eval())
shapv_attr_test = shapv.attribute(X_tab_test, target=0)
shapv_attr_test
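Shapley value sampling estimates each feature's Shapley value by averaging its marginal contribution over random feature orderings, with features that have not yet been added held at a baseline (Captum uses zero as the default baseline). Here is a minimal sketch of that Monte Carlo procedure on a hypothetical model with an interaction term; it is not Captum's code.

```python
import random

def model(x):
    # Hypothetical model with an interaction between features 0 and 1.
    return 2.0 * x[0] + x[1] + x[0] * x[1] + 0.5 * x[2]

def shapley_sampling(f, x, baseline, n_samples=2000, seed=0):
    rng = random.Random(seed)
    n = len(x)
    phi = [0.0] * n
    for _ in range(n_samples):
        order = list(range(n))
        rng.shuffle(order)
        current = list(baseline)
        prev_out = f(current)
        # Add features one at a time in random order; credit each with the change it causes.
        for j in order:
            current[j] = x[j]
            out = f(current)
            phi[j] += out - prev_out
            prev_out = out
    return [p / n_samples for p in phi]

x = [1.0, 2.0, 4.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_sampling(model, x, baseline)
```

The estimates always sum to f(x) minus f(baseline), and the interaction term ends up split roughly evenly between features 0 and 1 (exact values here: 3, 3, and 2).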
To get SHAP values based on expected gradients (an extension of Integrated Gradients) using the shap library:
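import shap
import torch
# background: a sample of preprocessed training rows used as the reference distribution
explainer = shap.GradientExplainer(model.deeptabular.eval(), torch.Tensor(background))
shap_values_gradient = explainer.shap_values(X_tab_test)
shap_values_gradient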
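GradientExplainer builds on expected gradients, which extends Integrated Gradients by averaging over baselines drawn from the background data. The core Integrated Gradients computation is shown below as a minimal sketch: average the model's gradient along a straight path from a baseline to the input, then multiply by (input - baseline). The model and its hand-written gradient are hypothetical; real libraries obtain gradients through autograd.

```python
import math

def model(x):
    # Hypothetical differentiable model.
    return math.tanh(x[0]) + x[0] * x[1]

def grad(x):
    # Analytic gradient of the model above.
    return [1.0 - math.tanh(x[0]) ** 2 + x[1], x[0]]

def integrated_gradients(x, baseline, steps=200):
    n = len(x)
    total = [0.0] * n
    # Midpoint Riemann sum of gradients along the path baseline + alpha * (x - baseline).
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad(point)
        for i in range(n):
            total[i] += g[i]
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]

x = [1.0, 2.0]
baseline = [0.0, 0.0]
attributions = integrated_gradients(x, baseline)
```

A useful sanity check is completeness: the attributions sum (up to the approximation error) to model(x) minus model(baseline).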
The shap library also lets you visualize an individual explanation:
SHAP explanations focus on the effect of features on individual observations (known as a local explanation). You can also aggregate the Shapley values across observations to get global feature importance.
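A common way to turn local SHAP values into a global ranking is to average their absolute values per feature, so that positive and negative effects on different rows do not cancel each other out. A minimal sketch, using a small hypothetical matrix of local attributions (one row per observation, one column per feature):

```python
def global_importance(shap_values):
    # shap_values: one row of per-feature attributions per observation.
    n_rows = len(shap_values)
    n_features = len(shap_values[0])
    # Mean absolute value, so positive and negative effects do not cancel out.
    return [sum(abs(row[j]) for row in shap_values) / n_rows for j in range(n_features)]

# Hypothetical local attributions for three observations and three features.
local = [
    [0.8, -0.1, 0.0],
    [-0.6, 0.2, 0.1],
    [0.7, -0.3, 0.0],
]
importance = global_importance(local)
ranking = sorted(range(len(importance)), key=lambda j: importance[j], reverse=True)
```

Here a plain sum of feature 0's values would be 0.9, understating a feature whose per-row effects are large but change sign.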
This is just a start. Captum offers many more explanation techniques; the plot below compares ten of them on one dataset. You can try all of these methods in the deep-dive companion notebook on explainability.
Image by the Author
One caveat when using categorical features in neural networks is that explainability varies by method. Categorical features are usually fed to the network as integer labels that index an embedding table rather than as actual values, so gradient-based techniques like Integrated Gradients cannot show the effect of a categorical feature directly. It is possible to design a network so that categorical effects can be attributed, but this needs to be considered during model design. The Captum package has a more detailed explanation of the limits of the Integrated Gradients method.
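One common workaround, shown here as a minimal sketch rather than widedeep's or Captum's method, is to attribute to the embedding vector that a category index looks up (Captum offers LayerIntegratedGradients for this style of layer attribution) and then sum the attributions over the embedding dimensions to get one score for the categorical feature. The embedding table, the linear head, and the zero baseline below are hypothetical.

```python
# Hypothetical embedding table: 3 categories, 2-dimensional embeddings.
embedding = {0: [0.0, 0.0], 1: [0.5, -1.0], 2: [1.5, 0.25]}
w = [2.0, 1.0]  # hypothetical weights of a linear head on top of the embedding

def head(vec):
    return sum(wi * vi for wi, vi in zip(w, vec))

def head_grad(vec):
    # Gradient of the linear head; it is constant, so the path integral is exact.
    return list(w)

def categorical_attribution(category, steps=50):
    # The integer index itself has no gradient, so integrate along the
    # embedding vector instead, from a zero baseline to the looked-up vector.
    target = embedding[category]
    baseline = [0.0] * len(target)
    per_dim = [0.0] * len(target)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (t - b) for t, b in zip(target, baseline)]
        g = head_grad(point)
        for i in range(len(target)):
            per_dim[i] += g[i] * (target[i] - baseline[i]) / steps
    # Sum over embedding dimensions to get a single score for the feature.
    return sum(per_dim)

score = categorical_attribution(2)
```

The choice of baseline embedding matters here: a zero vector is convenient but does not necessarily correspond to any real category.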
Finally, widedeep supports exporting attention weights, which you can then process for insights. The advantage of attention weights is that the model produces them as part of its forward pass, so extracting insights requires little extra computation. However, I would not rely on attention weights alone to explain a model. I have worked with models where attention weights were less useful than model-agnostic techniques like permutation-based importance.
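To make "attention weights" concrete, here is a minimal sketch of generic scaled dot-product attention weights, softmax(q · k / sqrt(d)), for one query over a few feature tokens, averaged across two heads. The query and key vectors are hypothetical, and TabFastFormer in particular uses an additive form of attention, so this illustrates the concept rather than that model's internals.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(query, keys):
    # Scaled dot-product attention: softmax(q . k / sqrt(d)) over all keys.
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    return softmax(scores)

# Hypothetical key vectors, one per input feature token.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
head_1 = attention_weights([2.0, 0.0], keys)
head_2 = attention_weights([0.0, 2.0], keys)
# Average across heads to get one weight per feature for this observation.
avg = [(a + b) / 2 for a, b in zip(head_1, head_2)]
```

Each head's weights sum to one, which makes them easy to read as "where the model looked", but they describe how information is mixed inside the network, not how much each feature changes the prediction.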
Getting predictions and saving your models
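Widedeep follows general conventions for performing predictions and saving your models. Because the model above was trained with a regression objective (rmse), use predict to get predictions; for a binary or multiclass objective, predict_proba returns prediction probabilities:
results = fasttab_model.predict(X_tab=X_tab)
results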
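For classification objectives, class probabilities come from passing the model's raw outputs (logits) through a sigmoid (binary) or a softmax (multiclass). The sketch below shows that conversion on hypothetical logits; the two-column layout for binary probabilities is an assumption about how such results are commonly returned.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical raw model outputs (logits) for two observations.
binary_logits = [0.0, 2.0]
binary_proba = [[1 - sigmoid(z), sigmoid(z)] for z in binary_logits]

multiclass_logits = [[1.0, 2.0, 0.5], [0.2, 0.1, 3.0]]
multiclass_proba = [softmax(row) for row in multiclass_logits]
predicted_class = [row.index(max(row)) for row in multiclass_proba]
```

Taking the argmax of the probabilities gives the same class as taking the argmax of the logits; probabilities matter when you need calibrated scores or a custom decision threshold.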
You can serialize the preprocessor using pickle, and save the model along with its training history using the trainer's save method.
import pickle
# save the fitted preprocessor
with open("tab_preproc.pkl", "wb") as wp:
    pickle.dump(tab_preprocessor, wp)
# save the model and training history
fasttab_model.save(path=".", model_filename="model_saved.pt")
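The point of pickling the preprocessor is that new data at prediction time goes through exactly the same fitted transformation as the training data. A minimal round-trip sketch, where a plain dict of hypothetical fitted state stands in for the preprocessor object to keep the example self-contained:

```python
import os
import pickle
import tempfile

# Hypothetical fitted preprocessing state standing in for a preprocessor object.
fitted_state = {"city_mapping": {"Austin": 1, "Chicago": 2}, "rooms_mean": 3.0, "rooms_std": 0.8}

def transform(state, city, rooms):
    # Unseen cities map to 0, matching the convention used at fit time.
    return [state["city_mapping"].get(city, 0), (rooms - state["rooms_mean"]) / state["rooms_std"]]

path = os.path.join(tempfile.mkdtemp(), "tab_preproc.pkl")
with open(path, "wb") as wp:
    pickle.dump(fitted_state, wp)

# Later, e.g. in a serving process: reload and apply the same transformation.
with open(path, "rb") as rp:
    loaded_state = pickle.load(rp)

restored = transform(loaded_state, "Boston", 3.0)
```

When pickling a real class instance such as a preprocessor, pickle stores a reference to the class, so the same package (ideally the same version) must be importable wherever the file is loaded.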
Summary and next steps
You now have a good introduction to the open-source widedeep package. You should know when widedeep is worth trying, and you understand how it fits into a typical modeling workflow: preprocessing, defining a model, training a model, explaining a model, and then getting predictions.
Please check out the two companion notebooks to start diving deeper into what was covered in this post. You can even run these on Google Colab in your browser, so get started now!
Bio: Rajiv Shah is a Data Scientist at Snorkel.ai. Previously, Rajiv has been part of data science teams at DataRobot, Caterpillar and State Farm. He enjoys data science and spends time mentoring data scientists, speaking at events, and having fun with blog posts. He has a Ph.D. from the University of Illinois at Urbana-Champaign.
Original. Reposted with permission.