%%capture
%matplotlib inline
!pip install kornia
!pip install kornia-rs
!pip install opencv-python
Data Augmentation Semantic Segmentation
Basic
2D
Segmentation
Data augmentation
kornia.augmentation
In this tutorial we will show how we can quickly perform data augmentation for semantic segmentation using the
kornia.augmentation API.
Install and get data
We install Kornia and a few dependencies, then download a simple data sample.
import io
import requests
def download_image(url: str, filename: str = "", timeout: float = 30.0) -> str:
    """Download ``url`` to a local file and return the local file name.

    Args:
        url: HTTP(S) address of the file to fetch.
        filename: local destination. When empty, the last component of the
            URL *path* is used, so any query string or fragment is ignored.
        timeout: seconds to wait for the server before giving up, so a
            stalled connection cannot hang the notebook indefinitely.

    Returns:
        The name of the file that was written.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status,
            instead of silently saving the error page under an image name.
    """
    from urllib.parse import urlparse

    if not filename:
        filename = urlparse(url).path.rsplit("/", 1)[-1]
    # Download, failing loudly on HTTP errors
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Save file (raw bytes, no intermediate buffer needed)
    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename
# fetch the sample: the PNG holds the image and its label map side by side
url = "https://github.com/kornia/data/raw/main/causevic16semseg3.png"
download_image(url)  # returns 'causevic16semseg3.png'
# import the libraries
import kornia as K
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
Define Augmentation Pipeline
We wrap our augmentation pipeline in a class that subclasses nn.Module.
class MyAugmentation(nn.Module):
    """Augment an image and its segmentation mask consistently.

    Colour jitter is applied to the image only. A random affine warp is
    sampled once per call and applied with the same parameters to both the
    image and the mask.
    """

    def __init__(self):
        super().__init__()
        # we define and cache our operators as class members
        self.k1 = K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25)
        self.k2 = K.augmentation.RandomAffine([-45.0, 45.0], [0.0, 0.15], [0.5, 1.5], [0.0, 0.15])
        # Same geometric transform for the mask, but with nearest-neighbour
        # resampling so label colours are copied rather than blended at region
        # boundaries. forward() always feeds it the parameters sampled by
        # self.k2, so it never draws its own.
        self.k2_mask = K.augmentation.RandomAffine(
            [-45.0, 45.0], [0.0, 0.15], [0.5, 1.5], [0.0, 0.15], resample="nearest"
        )

    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(img_out, mask_out)`` with a shared geometric transform.

        Args:
            img: image batch, presumably (B, C, H, W) as returned by
                ``load_data`` — TODO confirm for other callers.
            mask: label batch with the same spatial size as ``img``.

        Returns:
            The colour-jittered and warped image, and the mask warped with the
            same affine parameters.
        """
        # 1. apply color only in image
        # 2. apply geometric transform
        img_out = self.k2(self.k1(img))
        # 3. replay the affine parameters k2 just sampled on the mask
        # NOTE(review): relies on kornia's private ``_params`` attribute;
        # kornia.augmentation.AugmentationSequential with
        # data_keys=["input", "mask"] is the public alternative.
        mask_out = self.k2_mask(mask, self.k2._params)
        return img_out, mask_out


# Load the data and apply the transforms
def load_data(data_path: str, split_col: int = 571) -> tuple[torch.Tensor, torch.Tensor]:
    """Load a side-by-side PNG and split it into an image and its label map.

    Columns ``[0, split_col)`` become the image and columns
    ``[split_col + 1, W)`` become the labels. Column ``split_col`` itself is
    dropped (presumably a one-pixel separator between the two panels — TODO
    confirm).

    Args:
        data_path: path to the PNG to load.
        split_col: column where the image panel ends; the default of 571
            reproduces the split used for causevic16semseg3.png.

    Returns:
        ``(img, labels)``, each a batched tensor of shape (1, C, H, w).
    """
    # load as RGB32 and add a leading batch dimension: BxCxHxW
    data_t: torch.Tensor = K.io.load_image(data_path, K.io.ImageLoadType.RGB32)[None, ...]
    img, labels = data_t[..., :split_col], data_t[..., split_col + 1 :]
    return img, labels
# load data (B, C, H, W)
img, labels = load_data("causevic16semseg3.png")

# create augmentation instance
aug = MyAugmentation()

# apply the augmentation pipeline to our batch of data
img_aug, labels_aug = aug(img, labels)

# visualize: top row is the original image | labels, bottom row is the
# augmented image | labels produced by the call above
original_row = torch.cat([img, labels], dim=-1)
augmented_row = torch.cat([img_aug, labels_aug], dim=-1)
img_out = torch.cat([original_row, augmented_row], dim=-2)
plt.imshow(K.tensor_to_image(img_out))
plt.axis("off")
# generate several augmented samples, each plotted in its own figure
num_samples: int = 10
save_figures: bool = False  # set True to also write each figure to img_<id>.png
for img_id in range(num_samples):
    # each call samples new random augmentation parameters
    img_aug, labels_aug = aug(img, labels)
    img_out = torch.cat([img_aug, labels_aug], dim=-1)
    # plot (and optionally save) the augmented image | labels pair
    plt.figure()
    plt.imshow(K.tensor_to_image(img_out))
    plt.axis("off")
    if save_figures:
        plt.savefig(f"img_{img_id}.png", bbox_inches="tight")
plt.show()









