Kornia and PyTorch Lightning GPU data augmentation

Basic
Data augmentation
PyTorch Lightning
kornia.augmentation

In this tutorial we show how to combine Kornia and PyTorch Lightning to perform batched data augmentation while training a model, on either CPU or GPU, without additional effort.

Install Kornia and PyTorch Lightning

We first install Kornia and PyTorch Lightning:

%%capture
!pip install kornia
!pip install kornia-rs
!pip install pytorch_lightning torchmetrics
Import the needed libraries
import os
import kornia as K
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

Define Data Augmentations module
class DataAugmentation(nn.Module):
"""Module to perform data augmentation using Kornia on torch tensors."""
def __init__(self, apply_color_jitter: bool = False) -> None:
super().__init__()
self._apply_color_jitter = apply_color_jitter
self._max_val: float = 255.0
self.transforms = nn.Sequential(K.enhance.Normalize(0.0, self._max_val), K.augmentation.RandomHorizontalFlip(p=0.5))
self.jitter = K.augmentation.ColorJitter(0.5, 0.5, 0.5, 0.5)
    @torch.no_grad()  # disable gradients for efficiency
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_out = self.transforms(x) # BxCxHxW
if self._apply_color_jitter:
x_out = self.jitter(x_out)
        return x_out
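As a quick sanity check (a hedged sketch, not part of the original tutorial; aug and dummy_batch are illustrative names), the module can be applied directly to a batched tensor, so the same call works wherever the batch lives, CPU or GPU:

# Sketch (not from the original tutorial): apply the augmentation to a random batch.
aug = DataAugmentation()
dummy_batch = torch.randint(0, 256, (8, 3, 32, 32)).float()  # BxCxHxW, values in [0, 255]
augmented = aug(dummy_batch)  # normalization + random horizontal flip, no gradients
print(augmented.shape)  # torch.Size([8, 3, 32, 32])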
Define a Pre-processing model

class PreProcess(nn.Module):
    """Module to perform pre-processing using Kornia on torch tensors."""
def __init__(self) -> None:
super().__init__()
    @torch.no_grad()  # disable gradients for efficiency
    def forward(self, x: Image.Image) -> torch.Tensor:
x_tmp: np.ndarray = np.array(x) # HxWxC
x_out: torch.Tensor = K.image_to_tensor(x_tmp, keepdim=True) # CxHxW
        return x_out.float()
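For reference, a minimal usage sketch (assumed, not part of the original tutorial; pil_img and tensor_img are illustrative names): PreProcess converts a PIL image into the CxHxW float tensor that the data loader will later batch:

# Sketch (not from the original tutorial): convert a dummy PIL image to a tensor.
preprocess = PreProcess()
pil_img = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))  # dummy 32x32 RGB image
tensor_img = preprocess(pil_img)
print(tensor_img.shape, tensor_img.dtype)  # torch.Size([3, 32, 32]) torch.float32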
Define PyTorch Lightning model

class CoolSystem(pl.LightningModule):
def __init__(self):
super().__init__()
# not the best model...
self.l1 = torch.nn.Linear(3 * 32 * 32, 10)
self.preprocess = PreProcess()
self.transform = DataAugmentation()
self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_idx):
# REQUIRED
x, y = batch
x_aug = self.transform(x) # => we perform GPU/Batched data augmentation
logits = self.forward(x_aug)
loss = F.cross_entropy(logits, y)
self.log("train_acc_step", self.accuracy(logits.argmax(1), y))
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
# OPTIONAL
x, y = batch
logits = self.forward(x)
self.log("val_acc_step", self.accuracy(logits.argmax(1), y))
return F.cross_entropy(logits, y)
def test_step(self, batch, batch_idx):
# OPTIONAL
x, y = batch
logits = self.forward(x)
acc = self.accuracy(logits.argmax(1), y)
self.log("test_acc_step", acc)
return acc
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
        # (LBFGS is automatically supported, no need for a closure function)
return torch.optim.Adam(self.parameters(), lr=0.0004)
def prepare_data(self):
CIFAR10(os.getcwd(), train=True, download=True, transform=self.preprocess)
CIFAR10(os.getcwd(), train=False, download=True, transform=self.preprocess)
def train_dataloader(self):
# REQUIRED
dataset = CIFAR10(os.getcwd(), train=True, download=False, transform=self.preprocess)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
return loader
def val_dataloader(self):
        dataset = CIFAR10(os.getcwd(), train=True, download=False, transform=self.preprocess)  # note: the tutorial reuses the training split for validation
loader = DataLoader(dataset, batch_size=32, num_workers=1)
return loader
def test_dataloader(self):
dataset = CIFAR10(os.getcwd(), train=False, download=False, transform=self.preprocess)
loader = DataLoader(dataset, batch_size=16, num_workers=1)
        return loader

Run training
from pytorch_lightning import Trainer
# init model
model = CoolSystem()
# Initialize a trainer
accelerator = "cpu" # can be 'gpu'
trainer = Trainer(accelerator=accelerator, max_epochs=1, enable_progress_bar=False)
# Train the model ⚡
trainer.fit(model)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
| Name | Type | Params
--------------------------------------------------
0 | l1 | Linear | 30.7 K
1 | preprocess | PreProcess | 0
2 | transform | DataAugmentation | 0
3 | accuracy | MulticlassAccuracy | 0
--------------------------------------------------
30.7 K Trainable params
0 Non-trainable params
30.7 K Total params
0.123 Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=1` reached.
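To run the same training with the batched augmentation executed on the GPU, only the Trainer configuration needs to change. A hedged sketch, assuming a CUDA device is available (trainer_gpu is an illustrative name):

# Sketch (not from the original tutorial): train on GPU so the Kornia augmentation
# inside training_step also runs on the GPU.
trainer_gpu = Trainer(accelerator="gpu", devices=1, max_epochs=1, enable_progress_bar=False)
trainer_gpu.fit(CoolSystem())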
Test the model
trainer.test(model)
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_acc_step 0.10000000149011612
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'test_acc_step': 0.10000000149011612}]
Visualize
# # Start tensorboard.
# %load_ext tensorboard
# %tensorboard --logdir lightning_logs/