The Best Machine Learning Frameworks & Extensions for TensorFlow

Check out this curated list of useful frameworks and extensions for TensorFlow.

TensorFlow has a large ecosystem of libraries and extensions. If you’re a developer, you can easily add them to your ML work instead of building that functionality yourself.

In this article, we will explore some of the TensorFlow extensions that you can start using right away.

To start, let’s check out domain-specific pre-trained models from TensorFlow Hub.

Let’s get to it!


TensorFlow Hub

TensorFlow Hub is a repository with hundreds of trained and ready-to-use models. You can find models for:

  • natural language processing
  • object detection
  • image classification
  • style transfer
  • video action detection
  • sound classification
  • pitch recognition

To use a model, first find it on TensorFlow Hub and read its documentation. For example, here are the instructions for loading an ImageNet classification model:

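import tensorflow as tf
import tensorflow_hub as hub

# handle: the model URL shown on the model's TF Hub page.
model = tf.keras.Sequential([hub.KerasLayer(handle)])
model.build([None, 224, 224, 3])  # batch of 224x224 RGB images; check the model's docs for its input size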

Models can be used as they are, or you can fine-tune them. The model’s documentation offers instructions on how to do this.

For example, we can fine-tune the above model by passing ‘trainable=True’ to ‘hub.KerasLayer’.

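# arguments are forwarded to the model; batch_norm_momentum applies only to models that document it.
model = tf.keras.Sequential([
    hub.KerasLayer(handle,
                   trainable=True, arguments=dict(batch_norm_momentum=0.997))
])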


TensorFlow Model Optimization Toolkit

This is a collection of tools that you can use to optimize models for execution and deployment.

Why is this important?

  • it reduces the latency of models on mobile devices,
  • it reduces inference costs in the cloud and makes models small enough to deploy on edge devices.

Optimizing models might lead to a reduction in accuracy. Depending on the problem, you’ll need to decide if a slightly less accurate model is worth the advantage of model optimization.

Optimization can be applied to pre-trained models as well as to models you have trained yourself. You can also download models that have already been optimized.

One of the techniques for model optimization is pruning. In this technique, unnecessary values in the weight tensor are eliminated. This results in smaller models, with accuracy that’s very close to the baseline model.
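To make the idea concrete, here is a minimal NumPy sketch of magnitude-based pruning, the approach suggested by the name prune_low_magnitude used later in this section. magnitude_prune is a hypothetical helper for illustration only; the toolkit itself applies pruning masks during training according to a pruning schedule.

```python
import numpy as np

def magnitude_prune(weights: np.ndarray, sparsity: float) -> np.ndarray:
    """Zero out the `sparsity` fraction of entries with the smallest absolute value."""
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))  # number of weights to zero
    if k == 0:
        return weights.copy()
    # Indices of the k smallest-magnitude weights.
    prune_idx = np.argsort(flat, kind="stable")[:k]
    mask = np.ones(flat.size, dtype=bool)
    mask[prune_idx] = False
    return np.where(mask.reshape(weights.shape), weights, 0.0)

w = np.array([[0.9, -0.05], [0.3, -0.7]])
# With 50% sparsity, the two smallest-magnitude entries (-0.05 and 0.3) are zeroed.
print(magnitude_prune(w, 0.5))
```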

The first step in pruning a model is to define the pruning parameters.

Setting a sparsity of 50% means that 50% of the weights will be zeroed. The ‘PruningSchedule’ is responsible for controlling pruning during training.

from tensorflow_model_optimization.sparsity.keras import ConstantSparsity

pruning_params = {
    # Hold sparsity at 50%, starting from training step 0.
    'pruning_schedule': ConstantSparsity(0.5, 0),
    'block_size': (1, 1),
    'block_pooling_type': 'AVG'
}

After that, you can prune the entire model using the above parameters.

import tensorflow as tf
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude

# X_train is assumed to be your 2-D training feature matrix.
model_to_prune = prune_low_magnitude(
    tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)),
        tf.keras.layers.Dense(1, activation='relu')
    ]), **pruning_params)

Another option is quantization-aware training, which emulates lower-precision arithmetic during training, for example 8-bit integers instead of 32-bit floats.
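As a rough illustration of what “8-bit instead of 32-bit float” means, the sketch below applies a simple affine (scale and zero-point) quantization to a float array and maps it back. quantize_uint8 and dequantize are hypothetical helpers and a simplified scheme, not the exact one TFLite uses. Quantization-aware training inserts this kind of quantize/dequantize round trip into the forward pass so the model learns to tolerate the rounding error.

```python
import numpy as np

def quantize_uint8(x: np.ndarray):
    """Affine-quantize a float array to uint8; returns (q, scale, zero_point)."""
    x_min = min(float(x.min()), 0.0)  # range must include 0 so 0.0 is exactly representable
    x_max = max(float(x.max()), 0.0)
    scale = (x_max - x_min) / 255.0 or 1.0  # fall back to 1.0 for an all-zero input
    zero_point = int(round(-x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point

def dequantize(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map uint8 codes back to approximate float values."""
    return (q.astype(np.float32) - zero_point) * scale

weights = np.array([-1.0, -0.25, 0.0, 0.5, 1.0], dtype=np.float32)
q, scale, zp = quantize_uint8(weights)
recovered = dequantize(q, scale, zp)
# Each value now takes 1 byte instead of 4; the error is about scale / 2 at most.
print(q, np.abs(recovered - weights).max())
```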

import tensorflow_model_optimization as tfmot
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)

At this point, you’ll have a model that’s quantization aware, but not yet quantized.

After you compile and train the model, you can create the quantized model using the TFLite Converter.

converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()

You can also quantize certain layers of the model.

A third optimization strategy is weight clustering. In this technique, weights are grouped into clusters and each weight is replaced by its cluster’s centroid, which reduces the number of unique weight values.


TensorFlow Recommenders

TensorFlow Recommenders (TFRS) is a library for building recommender system models.

You can use it for preparing data, formulating the model, training, evaluation, and deployment. This notebook contains a full example of how to use TFRS.


TensorFlow Federated

TensorFlow Federated (TFF) is an open-source library for machine learning on decentralized data. In federated learning, devices collaboratively learn a shared model.

The model is first trained on a server using proxy data. Each device then downloads the model and improves it using the data stored on that device.

What’s good about this approach is that sensitive user data is never uploaded to the server. One way this has been used is in phone keyboards.
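To illustrate the idea (this is not TFF’s API), here is a minimal NumPy sketch of federated averaging for linear regression: each simulated device takes a few gradient steps on its own data, and the server only receives updated weights, which it averages weighted by each device’s dataset size. The device data, learning rate, and number of rounds are assumptions.

```python
import numpy as np

def local_update(w, X, y, lr=0.1, steps=5):
    """Run a few steps of gradient descent on one device's data (MSE loss)."""
    w = w.copy()
    for _ in range(steps):
        grad = 2 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w

def federated_averaging(w, clients, rounds=20):
    """Each round: devices train locally, the server averages weights by data size."""
    for _ in range(rounds):
        updates = [local_update(w, X, y) for X, y in clients]
        sizes = np.array([len(y) for _, y in clients], dtype=float)
        # Only weights leave the devices; the raw X and y never do.
        w = np.average(updates, axis=0, weights=sizes)
    return w

# Three simulated devices whose private data share one true model (assumption).
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
clients = []
for n in (20, 50, 30):
    X = rng.normal(size=(n, 2))
    clients.append((X, X @ true_w + 0.01 * rng.normal(size=n)))

w = federated_averaging(np.zeros(2), clients)
print(w)  # close to [2.0, -1.0]
```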

TensorFlow Federated is made up of two layers:

  • Federated Learning (FL) API
  • Federated Core (FC) API

Using the Federated Learning (FL) API, developers can apply federated training and evaluation on existing TensorFlow models.

The Federated Core (FC) API is a system of low-level interfaces for writing federated algorithms.

If you’re interested, check out official TensorFlow Federated tutorials to learn more.


TensorFlow Graphics

To build more efficient neural network architectures, you can insert differentiable graphic layers.

Adding geometric priors and constraints to neural networks leads to architectures that can be trained more robustly and efficiently.

Combining computer graphics and computer vision lets us use unlabeled data in machine learning problems. TensorFlow Graphics provides a suite of differentiable graphics and geometry layers, as well as 3D viewer functionality.

Here’s a code snippet from the official docs that loads a mesh, renders it, rotates it by 45 degrees around the y axis, and renders it again.

import numpy as np
import tensorflow as tf
import trimesh

import tensorflow_graphics.geometry.transformation as tfg_transformation
from tensorflow_graphics.notebooks import threejs_visualization

# Assumes cow.obj has already been downloaded to the working directory.
# Load the mesh.
mesh = trimesh.load("cow.obj")
mesh = {"vertices": mesh.vertices, "faces": mesh.faces}
# Visualize the original mesh.
threejs_visualization.triangular_mesh_renderer(mesh, width=400, height=400)
# Set the axis and angle parameters.
axis = np.array((0., 1., 0.))  # y axis.
angle = np.array((np.pi / 4.,))  # 45 degree angle.
# Rotate the mesh.
mesh["vertices"] = tfg_transformation.axis_angle.rotate(mesh["vertices"], axis, angle)
# Visualize the rotated mesh.
threejs_visualization.triangular_mesh_renderer(mesh, width=400, height=400)

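For intuition about what an axis-angle rotation does to each vertex, here is our own NumPy version of Rodrigues’ rotation formula, v·cos θ + (k × v)·sin θ + k(k·v)(1 − cos θ) for a unit axis k. It is an illustration, not TensorFlow Graphics’ implementation, and rotate_axis_angle is a hypothetical name. Rotating by 45 degrees around the y axis, as in the snippet, moves (1, 0, 0) to roughly (0.707, 0, −0.707) and leaves points on the axis unchanged.

```python
import numpy as np

def rotate_axis_angle(points, axis, angle):
    """Rotate an (N, 3) array of points around `axis` by `angle` radians."""
    k = axis / np.linalg.norm(axis)  # make sure the axis has unit length
    cos, sin = np.cos(angle), np.sin(angle)
    return (points * cos
            + np.cross(k, points) * sin
            + np.outer(points @ k, k) * (1.0 - cos))

vertices = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
rotated = rotate_axis_angle(vertices, np.array([0.0, 1.0, 0.0]), np.pi / 4)
print(np.round(rotated, 3))
```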


TensorFlow Privacy 

This library is for training machine learning models while protecting the privacy of the training data. Tutorials include:

  • training a language model with differential privacy
  • a convolutional neural network on MNIST with differential privacy
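TF Privacy’s tutorials train with differentially private optimizers based on DP-SGD, which clips each example’s gradient and adds Gaussian noise before the update. The NumPy sketch below shows one such step. dp_sgd_step is a hypothetical helper, not TF Privacy’s API, although its l2_norm_clip and noise_multiplier parameters mirror the argument names used by TF Privacy’s DP optimizers.

```python
import numpy as np

def dp_sgd_step(w, per_example_grads, l2_norm_clip, noise_multiplier, lr, rng):
    """One DP-SGD update: clip each example's gradient, sum, add Gaussian noise, average."""
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    # Scale down any gradient whose L2 norm exceeds the clipping bound.
    clipped = per_example_grads / np.maximum(1.0, norms / l2_norm_clip)
    # Noise standard deviation is proportional to the clipping bound.
    noise = rng.normal(0.0, noise_multiplier * l2_norm_clip, size=w.shape)
    noisy_mean = (clipped.sum(axis=0) + noise) / len(per_example_grads)
    return w - lr * noisy_mean

rng = np.random.default_rng(0)
grads = np.array([[3.0, 4.0], [0.3, 0.4]])  # L2 norms 5.0 and 0.5
w = dp_sgd_step(np.zeros(2), grads, l2_norm_clip=1.0, noise_multiplier=1.1,
                lr=0.1, rng=rng)
print(w)
```

Clipping bounds how much any single example can move the model, and the noise hides whatever influence remains.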

The strength of a differential privacy guarantee is expressed with two parameters, epsilon (ε) and delta (δ). Lower values mean stronger privacy.
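For reference, the standard definition: a randomized training mechanism M is (ε, δ)-differentially private if, for every pair of datasets D and D′ that differ in a single example and every set of possible outputs S,

```latex
\Pr[M(D) \in S] \;\le\; e^{\varepsilon}\,\Pr[M(D') \in S] + \delta
```

In words, adding or removing one person’s data changes the probability of any outcome by at most a factor of e^ε, except with probability δ.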



This is a library of models and datasets aimed at making deep learning more accessible and accelerating research in machine learning.


TensorFlow Probability

According to the official docs:

“TensorFlow Probability is a library for probabilistic reasoning and statistical analysis in TensorFlow”

You can use the library to encode domain knowledge, and it also provides:

  • support for many probability distributions
  • tools for building deep probabilistic models
  • variational inference and Markov chain Monte Carlo
  • optimizers such as Nelder-Mead, BFGS, and SGLD

Here’s an example model based on the Bernoulli distribution:

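import tensorflow as tf
import tensorflow_probability as tfp
# features (1-D float tensor) and labels (0/1 values) are assumed to be defined.
model = tfp.glm.Bernoulli()
coeffs, linear_response, is_converged, num_iter = tfp.glm.fit(
    model_matrix=features[:, tf.newaxis],
    response=tf.cast(labels, dtype=tf.float32),
    model=model)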


TensorFlow Extended (TFX)

TensorFlow Extended (TFX) is a platform that you can use to bring your machine learning pipeline to production.

In addition, TensorFlow Serving’s ModelServer lets you access your model through a RESTful API.

Assuming you have it installed and configured, the server can be started by running:

$ tensorflow_model_server --rest_api_port=8000 \
      --model_config_file=models.config \
      --model_config_file_poll_wait_seconds=300

The API will be available on port 8000 on localhost. Setting up this server requires some knowledge of server administration.
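Once the server is running, clients send JSON to the REST endpoint /v1/models/<model name>:predict. The sketch below builds such a request using only the Python standard library. "my_model" is a hypothetical model name that would need to match an entry in models.config, and the feature values are placeholders.

```python
import json
import urllib.request

def build_predict_request(instances, model_name, host="localhost", port=8000):
    """Build a POST request for TensorFlow Serving's REST predict endpoint."""
    url = f"http://{host}:{port}/v1/models/{model_name}:predict"
    body = json.dumps({"instances": instances}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST")

def predict(instances, model_name, **kwargs):
    """Send the request and return the "predictions" field of the JSON reply."""
    request = build_predict_request(instances, model_name, **kwargs)
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())["predictions"]

# "my_model" is a hypothetical name; it must match a model in models.config.
request = build_predict_request([[1.0, 2.0, 3.0]], "my_model")
print(request.full_url)
# With the server running: predictions = predict([[1.0, 2.0, 3.0]], "my_model")
```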


TensorBoard

TensorBoard is TensorFlow’s open-source visualization toolkit. You can add it as a callback during model training to track metrics such as loss and accuracy. TensorBoard also provides several tools for experimentation. Among other things, you can use it to:

  • visualize images
  • check model weights and biases
  • visualize the architecture of the model
  • see the performance of your application via profiling
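Here is a minimal sketch of the Keras TensorBoard callback on synthetic data (an assumption). histogram_freq=1 additionally logs weight histograms every epoch. After training, run tensorboard --logdir logs and open the URL it prints.

```python
import os

import numpy as np
import tensorflow as tf

# Synthetic binary classification data (assumption).
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype("float32")
y = (x[:, 0] > 0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Write loss/accuracy scalars and weight histograms (every epoch) under ./logs.
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir="logs", histogram_freq=1)
model.fit(x, y, epochs=3, validation_split=0.2, verbose=0, callbacks=[tensorboard_cb])
print(os.listdir("logs"))
```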

Note: As an alternative, you can also track and visualize model training runs and version your models in Neptune.

For instance, here is how you can log your Keras experiments using Neptune.

PARAMS = {'lr': 0.01, 'epochs': 10}
neptune.create_experiment('model-training-run', params=PARAMS)
model.fit(x_train, y_train, epochs=PARAMS['epochs'],
          callbacks=[NeptuneMonitor()])  # NeptuneMonitor: neptunecontrib.monitoring.keras (legacy client)


See Neptune TensorFlow/Keras integration


TensorFlow Agents

This library can be used for designing, implementing, and testing reinforcement learning algorithms. It provides modular components that are extensively tested. Components can be modified and extended.

This notebook shows how to train a DQN (Deep Q-Network) agent on the CartPole environment. The initialization code looks like this:

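import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_network

# Wrap the Gym CartPole environment so TF-Agents can consume it.
env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
# Q-network: maps an observation to one Q-value per action.
q_net = q_network.QNetwork(env.observation_spec(), env.action_spec(),
                           fc_layer_params=(100,))
agent = dqn_agent.DqnAgent(env.time_step_spec(), env.action_spec(),
                           q_network=q_net,
                           optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
agent.initialize()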



Final thoughts

In this article, we explored several libraries that can be used to extend TensorFlow’s functionalities. Try using the code snippets I provided to familiarize yourself with the tools.

We talked about:

  • using pre-trained models from TensorFlow Hub,
  • optimizing your models using TensorFlow Model Optimization Toolkit,
  • building recommenders using TensorFlow Recommenders,
  • training models on decentralized data using TensorFlow Federated,
  • training models with differential privacy using TensorFlow Privacy.

And that’s quite a lot, so choose one of these to start with, and go through the list to see if any tools fit your machine learning workflow.

Bio: Derrick Mwiti is a data scientist who has a great passion for sharing knowledge. He is an avid contributor to the data science community via blogs such as Heartbeat, Towards Data Science, Datacamp, Neptune AI, KDnuggets just to mention a few. His content has been viewed over a million times on the internet. Derrick is also an author and online instructor. He also trains and works with various institutions to implement data science solutions as well as to upskill their staff. You might want to check his Complete Data Science & Machine Learning Bootcamp in Python course.

Original. Reposted with permission.