Advanced PyTorch Lightning with TorchMetrics and Lightning Flash



Following on from our last post, Getting Started with PyTorch Lightning, this tutorial dives deeper into two additional tools you should be using: TorchMetrics and Lightning Flash.

TorchMetrics unsurprisingly provides a modular approach to define and track useful metrics across batches and devices, while Lightning Flash offers a suite of functionality facilitating more efficient transfer learning and data handling, and a recipe book of state-of-the-art approaches to typical deep learning problems.

We’ll start by adding a few useful classification metrics to the MNIST example we started with earlier. We’ll also swap out the PyTorch Lightning Trainer object with a Flash Trainer object, which will make it easier to perform transfer learning on a new classification problem. We’ll then train our classifier on a new dataset, CIFAR10, which we’ll use as the basis for a transfer learning example to CIFAR100.



First things first, and that’s ensuring that we have all needed packages installed. If you already followed the install instructions from the "Getting Started" tutorial and now check your virtual environment contents with pip freeze, you’ll notice that you probably already have TorchMetrics installed. If not, install both TorchMetrics and Lightning Flash with the following:

pip install torchmetrics
pip install lightning-flash
pip install "lightning-flash[image]"


Next we’ll modify our training and validation loops to log the F1 score and the Area Under the Receiver Operating Characteristic Curve (AUROC) as well as accuracy. We’ll remove the (deprecated) accuracy from pytorch_lightning.metrics and the similar sklearn function from the validation_epoch_end method in our model, but first let’s make sure to add the necessary imports at the top.

# ...
import pytorch_lightning as pl

# replace: from pytorch_lightning.metrics import functional as FM
# with the one below
import torchmetrics

# import lightning_flash, which we’ll use later
import flash
from flash.image import ImageClassifier, ImageClassificationData                
# ...


Next, remove the lines we used previously to calculate accuracy:

# ...
# in training_step
y_pred = output.argmax(-1).cpu().numpy()
y_tgt = y.cpu().numpy()

# remove the line below:
# accuracy = sklearn.metrics.accuracy_score(y_tgt, y_pred)
self.log("train loss", loss)

# and this one: self.log("train accuracy", accuracy)
return loss
# ...



# ...
# in validation_epoch_end
y_preds = preds.cpu().numpy()
y_tgts = tgts.cpu().numpy()
# remove the lines below:
# fm_accuracy = FM.accuracy(outputs, tgts)
# accuracy = sklearn.metrics.accuracy_score(y_tgts, y_preds)
# self.log("val_accuracy", accuracy)
self.log("val_loss", loss)
# ...


Now, we could just replace what we removed with the equivalent TorchMetrics functional implementation for calculating accuracy and leave it at that:

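# ...
# in training_step
# TorchMetrics works on tensors, so pass class probabilities and targets directly
y_pred = output.softmax(dim=-1)
y_tgt = y

accuracy = torchmetrics.functional.accuracy(y_pred, y_tgt)
f1_score = torchmetrics.functional.f1(y_pred, y_tgt,
    num_classes=10, average="micro")
# AUROC is macro-averaged here; multiclass inputs may not support "micro"
auroc = torchmetrics.functional.auroc(y_pred, y_tgt,
    num_classes=10, average="macro")
self.log("train_loss", loss)
self.log("train_accuracy", accuracy)
self.log("train_f1", f1_score)
self.log("train_auroc", auroc)
return loss
# ...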



# ...
# in validation_epoch_end
accuracy = torchmetrics.functional.accuracy(outputs, tgts)
f1_score = torchmetrics.functional.f1(outputs, tgts,
    num_classes=10, average="micro")
# AUROC is macro-averaged here; multiclass inputs may not support "micro"
auroc = torchmetrics.functional.auroc(outputs, tgts,
    num_classes=10, average="macro")
self.log("val_accuracy", accuracy)
self.log("val_f1_score", f1_score)
self.log("val_auroc", auroc)
self.log("val_loss", loss)
# ...
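
A note on the average="micro" setting used for F1 above: micro-averaging pools true positives, false positives and false negatives over all classes before computing the score. For single-label multiclass problems such as MNIST, that makes micro-averaged F1 numerically equal to accuracy. The plain-Python sketch below illustrates this on made-up labels; micro_f1, accuracy and the sample data are illustrative helpers, not part of TorchMetrics.

```python
def micro_f1(preds, targets, num_classes):
    """Micro-averaged F1: pool TP/FP/FN over all classes, then compute F1."""
    tp = fp = fn = 0
    for c in range(num_classes):
        for p, t in zip(preds, targets):
            if p == c and t == c:
                tp += 1
            elif p == c and t != c:
                fp += 1
            elif p != c and t == c:
                fn += 1
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def accuracy(preds, targets):
    """Fraction of samples whose predicted label matches the target."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# hypothetical predicted and true class labels for six samples
preds = [0, 1, 2, 2, 1, 0]
targets = [0, 1, 1, 2, 1, 2]

f1 = micro_f1(preds, targets, num_classes=3)
acc = accuracy(preds, targets)
print(f"micro F1: {f1:.4f}, accuracy: {acc:.4f}")
```

Both values come out the same, so if you want F1 to tell you something accuracy does not, a macro or per-class average is worth considering.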


However, there are additional advantages to using the class-based, modular versions of metrics.

With class-based metrics, we can continuously accumulate data while running training and validation, and compute the result at the end. This is convenient and efficient on a single device, but it really becomes useful with multiple devices as the metrics modules can automatically synchronize between multiple devices.
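
To see why accumulating state matters, consider batches of unequal size: the mean of per-batch accuracies is not the accuracy over all samples. A stateful metric avoids this by keeping running counts, and the same counts can be summed across devices. The sketch below is plain Python with made-up batch results; merge_states only imitates the idea of cross-device synchronization and is not how TorchMetrics implements it.

```python
# Each tuple is (number correct, batch size) for one batch; values are made up.
device_0_batches = [(9, 10), (1, 2)]
device_1_batches = [(5, 10), (2, 2)]


def accumulate(batches):
    """Sum the (correct, total) counts over batches, like a metric's internal state."""
    correct = sum(c for c, _ in batches)
    total = sum(n for _, n in batches)
    return correct, total


def merge_states(states):
    """Combine per-device states by summing their counts."""
    return sum(c for c, _ in states), sum(n for _, n in states)


# naive: average the per-batch accuracies on one device
naive = sum(c / n for c, n in device_0_batches) / len(device_0_batches)

# accumulated: divide the summed counts once at the end
correct_0, total_0 = accumulate(device_0_batches)
accumulated = correct_0 / total_0

# "synchronized": merge the states from both devices before computing
correct, total = merge_states(
    [accumulate(device_0_batches), accumulate(device_1_batches)]
)
global_accuracy = correct / total

print(naive, accumulated, global_accuracy)
```

On the first device the naive average is 0.7 while the accumulated accuracy is 10/12, because the small batch is over-weighted in the naive version.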

We’ll initialize our metrics in the __init__ function, and add calls for each metric in the training and validation steps.

class MyClassifier(pl.LightningModule):
    def __init__(self, dim=28, activation=nn.ReLU()):
        super(MyClassifier, self).__init__()
        self.image_dim = dim
        self.hid_dim = 128
        self.num_classes = 10
        self.act = activation
        # add metrics
        self.train_acc = torchmetrics.Accuracy()
        self.train_f1 = torchmetrics.F1(num_classes=10,
            average="micro")
        self.train_auroc = torchmetrics.AUROC(num_classes=10,
            average="macro")
        self.val_acc = torchmetrics.Accuracy()
        self.val_f1 = torchmetrics.F1(num_classes=10,
            average="micro")
        self.val_auroc = torchmetrics.AUROC(num_classes=10,
            average="macro")

        # __init__ function continues
        # ...


The metrics modules defined in __init__ will be called during training_step and validation_step, and we’ll compute them at the end of each training and validation epoch.

In the step function, we’ll call our metrics objects to accumulate metrics data throughout training and validation epochs. We can either call the “forward” method for each metrics object to accumulate data while also returning the value for the current batch, or we can call the “update” method to silently accumulate metrics data.

def training_step(self, batch, batch_index):
    x, y = batch
    output = self.forward(x)
    loss = F.nll_loss(F.log_softmax(output, dim = -1), y)
    y_pred = output.softmax(dim=-1)
    y_tgt = y
    # accumulate and return metrics for logging
    acc = self.train_acc(y_pred, y_tgt)
    f1 = self.train_f1(y_pred, y_tgt)
    # just accumulate
    self.train_auroc.update(y_pred, y_tgt)
    self.log("train_loss", loss)
    self.log("train_accuracy", acc)
    self.log("train_f1", f1)
    return loss

def validation_step(self, batch, batch_idx):
    x, y = batch
    output = self.forward(x)
    loss = F.cross_entropy(output, y)
    pred = output.softmax(dim=-1)
    self.val_acc.update(pred, y)
    self.val_f1.update(pred, y)
    self.val_auroc.update(pred, y)
    return loss
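
The difference between the two calling styles used in the step functions above can be sketched with a toy metric in plain Python. RunningAccuracy is a hypothetical stand-in, not the TorchMetrics implementation: calling the object returns the accuracy of the current batch and also accumulates it, whereas update only accumulates.

```python
class RunningAccuracy:
    """Toy stateful metric mimicking the update / forward / compute / reset pattern."""

    def __init__(self):
        self.reset()

    def reset(self):
        # clear the accumulated state, e.g. at the end of an epoch
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        # silently accumulate counts; returns nothing
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def __call__(self, preds, targets):
        # accumulate, and also return the value for this batch only
        self.update(preds, targets)
        return sum(p == t for p, t in zip(preds, targets)) / len(targets)

    def compute(self):
        # value over everything seen since the last reset
        return self.correct / self.total


metric = RunningAccuracy()
batch_value = metric([1, 0, 1, 1], [1, 0, 0, 1])  # forward-style call
metric.update([0, 0], [0, 1])                     # update-style call
epoch_value = metric.compute()
metric.reset()
print(batch_value, epoch_value)
```

The forward-style call is handy when you want a per-step value to log, while update is enough when only the epoch-level result matters.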


We’ll re-write validation_epoch_end and overload training_epoch_end to compute and report metrics for the entire epoch at once.

def training_epoch_end(self, training_step_outputs):
    # compute metrics
    train_accuracy = self.train_acc.compute()
    train_f1 = self.train_f1.compute()
    train_auroc = self.train_auroc.compute()
    # log metrics
    self.log("epoch_train_accuracy", train_accuracy)
    self.log("epoch_train_f1", train_f1)
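    # reset all metrics so the next epoch starts from a clean state
    self.train_acc.reset()
    self.train_f1.reset()
    self.train_auroc.reset()
    print(f"\ntraining accuracy: {train_accuracy:.4}, "\
    f"f1: {train_f1:.4}, auroc: {train_auroc:.4}")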

def validation_epoch_end(self, validation_step_outputs):
    # compute metrics
    val_loss = torch.stack(validation_step_outputs).mean()
    val_accuracy = self.val_acc.compute()
    val_f1 = self.val_f1.compute()
    val_auroc = self.val_auroc.compute()
    # log metrics
    self.log("val_accuracy", val_accuracy)
    self.log("val_loss", val_loss)
    self.log("val_f1", val_f1)
    self.log("val_auroc", val_auroc)
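    # reset all metrics so the next epoch starts from a clean state
    self.val_acc.reset()
    self.val_f1.reset()
    self.val_auroc.reset()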
    print(f"\nvalidation accuracy: {val_accuracy:.4} "\
    f"f1: {val_f1:.4}, auroc: {val_auroc:.4}")
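
AUROC is a ranking statistic: for a binary problem it is the probability that a randomly chosen positive example is scored above a randomly chosen negative one. Per-batch values therefore do not average to the epoch value, which is one reason accumulating with update and calling compute once per epoch is a natural fit for this metric. The plain-Python sketch below uses a binary case with made-up scores; pairwise_auroc is an illustrative helper, not the TorchMetrics implementation.

```python
def pairwise_auroc(scores, labels):
    """Binary AUROC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# two made-up batches of (score for the positive class, true label)
batch_a = ([0.9, 0.8], [1, 0])
batch_b = ([0.3, 0.2, 0.85], [1, 0, 0])

auroc_a = pairwise_auroc(*batch_a)
auroc_b = pairwise_auroc(*batch_b)
mean_of_batches = (auroc_a + auroc_b) / 2

# pool both batches, as an accumulating metric would
all_scores = batch_a[0] + batch_b[0]
all_labels = batch_a[1] + batch_b[1]
epoch_auroc = pairwise_auroc(all_scores, all_labels)

print(mean_of_batches, epoch_auroc)
```

Here the mean of the batch values is 0.75, while the pooled value is 4/6, because pooling also compares positives from one batch with negatives from the other.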


With those few changes, we can take advantage of more than 25 different metrics implemented in TorchMetrics, or sub-class the torchmetrics.Metric class and implement our own. Keep in mind though that there are simpler ways to implement training for common tasks like image classification than sub-classing the LightningModule class.


Lightning Flash

Like a set of Russian nesting dolls of deep learning abstraction libraries, Lightning Flash adds further abstractions and simplification on top of PyTorch Lightning. In fact we can train an image classification task in only 8 lines. We’ll use the CIFAR10 dataset and a classification model based on the ResNet18 backbone built into Lightning Flash. Then we’ll show how the model backbone can be repurposed for classifying a new dataset, CIFAR100.

While Lightning Flash is very much still under active development and has plenty of sharp edges, you can already put together certain workflows with very little code, and there’s even a “no-code” capability they call Flash Zero. For our purposes, we can put together a transfer learning workflow with fewer than 20 lines.

First, we’ll conduct training on the CIFAR10 dataset with 8 lines of code. We take advantage of the ImageClassifier class and its built-in backbone architectures, as well as the ImageClassificationData class to replace both training and validation dataloaders.

import os
from torchvision.datasets import CIFAR10

metrics_10 = [torchmetrics.Accuracy(),
    torchmetrics.F1(num_classes=10, average="micro")]
validation_interval = 1.0
train_dataset = CIFAR10(os.getcwd(), download=True,
    train=True) #, transform=transforms.ToTensor())
val_dataset = CIFAR10(os.getcwd(), download=True,
    train=False) #, transform=transforms.ToTensor())
# wrap the two datasets in a Flash datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=train_dataset,
    val_dataset=val_dataset)
model = ImageClassifier(backbone="resnet18",
    num_classes=10, metrics=metrics_10)
trainer = flash.Trainer(max_epochs=25,
    val_check_interval=validation_interval, gpus=1)
trainer.fit(model, datamodule=datamodule)


After that we can train on a new image classification task, the CIFAR100 dataset, which has fewer examples per class, by re-using the feature extraction backbone of our previously trained model and transfer learning using the “freeze” method.

This strategy only updates the parameters on the new classification head, while leaving the backbone parameters unchanged.

from torchvision.datasets import CIFAR100

train_dataset = CIFAR100(os.getcwd(), download=True,
    train=True) #, transform=transforms.ToTensor())
val_dataset = CIFAR100(os.getcwd(), download=True,
    train=False) #, transform=transforms.ToTensor())
metrics_100 = [torchmetrics.Accuracy(),
    torchmetrics.F1(num_classes=100, average="micro")]
datamodule = ImageClassificationData.from_datasets(
    train_dataset=train_dataset,
    val_dataset=val_dataset)
# re-use the trained backbone; 512 is its output feature size
model_2 = ImageClassifier(backbone=(model.backbone, 512),
    num_classes=100, metrics=metrics_100)
trainer_2 = flash.Trainer(max_epochs=15,
    val_check_interval=validation_interval, gpus=1)
# "freeze" keeps the backbone fixed and trains only the new head
trainer_2.finetune(model_2, datamodule=datamodule,
    strategy="freeze")
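
Conceptually, freezing means the optimizer step skips the backbone parameters and only the head is updated. The framework-free sketch below illustrates the idea with a two-parameter model. The parameter names, learning rate, and data are made up for illustration, and this is not how Lightning Flash implements its freeze strategy.

```python
# Toy model: prediction = head * (backbone * x). All values are illustrative.
params = {"backbone": 2.0, "head": 0.1}
trainable = {"backbone": False, "head": True}  # the backbone is frozen

data = [(1.0, 6.0), (2.0, 12.0), (3.0, 18.0)]  # targets follow y = 6 * x
learning_rate = 0.01


def mean_squared_error(p):
    return sum((p["head"] * p["backbone"] * x - y) ** 2 for x, y in data) / len(data)


def gradients(p):
    """Analytic gradients of the mean squared error for both parameters."""
    g = {"backbone": 0.0, "head": 0.0}
    for x, y in data:
        err = p["head"] * p["backbone"] * x - y
        g["backbone"] += 2 * err * p["head"] * x / len(data)
        g["head"] += 2 * err * p["backbone"] * x / len(data)
    return g


initial_loss = mean_squared_error(params)
for _ in range(200):
    grads = gradients(params)
    for name in params:
        if trainable[name]:  # frozen parameters are skipped
            params[name] -= learning_rate * grads[name]
final_loss = mean_squared_error(params)

print(params, initial_loss, final_loss)
```

The backbone value never changes, yet the loss still falls because the head alone adapts to the new targets, which is the same trade-off the transfer learning example relies on.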


This type of parameter re-application to new tasks is at the core of transfer learning and saves time and compute, and the costs associated with both. Given that developer time is even more valuable than compute time, the concise programming style of Lightning Flash can be well worth the investment of learning a few new API patterns to use it.

Some of the most practical deep learning advice can be boiled down to “don’t be a hero,” i.e. don’t reinvent the wheel, and don’t ignore the convenient tools like Flash that can make your life easier.

Speaking of easier, there’s one more way to train models with Flash that we’d be remiss not to mention. With Flash Zero, you can call Lightning Flash directly from the command line to train common deep learning tasks with built-in SOTA models. Flash Zero also has plenty of sharp edges and if you want to adapt it to your needs, be ready to work on a few pull request contributions to the PyTorch Lightning project.

For example, the following is a modified example from the Flash Zero documentation. If you look at the original version (as of this writing), you’ll likely notice right away that there is a typo in the command line argument for downloading the hymenoptera dataset: the download output filename is missing its extension. The fixed version below downloads the hymenoptera dataset and then trains a classifier with the ResNet18 backbone for 10 epochs:

curl \

flash image_classification --trainer.max_epochs 10 --model.backbone \
    resnet18 from_folders --train_folder \


A documentation typo is a pretty minor error (and also a welcome opportunity for you to open your first pull request to the project!), but it is a good sign that things are changing quickly at the PyTorch Lightning and Lightning Flash projects.

Expect development to continue at a rapid pace as the project scales. That means it’s probably a good idea to pin exact version numbers when setting up your dependencies on a new project, to avoid breaking changes as Lightning code is updated. At the same time, this presents an opportunity to shape the future of the project to meet your specific R&D needs, either by pull requests, contributing comments, or opening issues on the project’s GitHub repository.
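
One way to fix version numbers is to record what is currently installed and paste the result into a requirements file. The sketch below uses only the standard library’s importlib.metadata (Python 3.8+). The helper name pinned_requirements is illustrative, and the package names assume the PyPI distribution names used in the install commands earlier in this tutorial.

```python
from importlib import metadata


def pinned_requirements(packages):
    """Return 'name==version' lines for installed packages, comments for missing ones."""
    lines = []
    for name in packages:
        try:
            lines.append(f"{name}=={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            lines.append(f"# {name} is not installed in this environment")
    return lines


for line in pinned_requirements(["pytorch-lightning", "torchmetrics", "lightning-flash"]):
    print(line)
```

Running pip freeze gives the same information for every package in the environment; the helper above just narrows it to the libraries this tutorial depends on.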

In these PyTorch Lightning tutorial posts we’ve seen how PyTorch Lightning can be used to simplify training of common deep learning tasks at multiple levels of complexity. By sub-classing the LightningModule, we were able to define an effective image classifier with a model that takes care of training, validation, metrics, and logging, greatly simplifying any need to write an external training loop. The model also used a PyTorch Lightning Trainer object that made switching the entire training flow over to the GPU a breeze. Building models from Lightning Modules is a great way to gain utility without sacrificing control.

By using Lightning Flash, we then built a transfer learning workflow in just 15 lines of code, excepting imports. For problems with known solutions and an established state-of-the-art, you can save a lot of time by taking advantage of built-in architectures and training infrastructure with Flash!

Finally, we had a glimpse at Flash Zero for no-code training from the command line. No-code is an increasingly popular approach to machine learning, and although begrudged by engineers, no-code has a lot of promise. Currently developing rapidly, Flash Zero is set to become a powerful way to apply the best-engineered solutions out-of-the-box, so that machine learning and data scientists can focus on the science part of their job title.

Bio: Kevin Vu manages Exxact Corp blog and works with many of its talented authors who write about different aspects of Deep Learning.

Original. Reposted with permission.