!pip install kornia
!pip install kornia-rsAugmentation Sequential
Basic
2D
Data augmentation
kornia.augmentation
In this tutorial we will show how we can quickly perform data augmentation for various tasks (segmentation, detection, regression) using the features provided by the
kornia.augmentation.AugmentationSequential API.
Install and get data
We install Kornia and some dependencies, and download a simple data sample
import io
import requests
def download_image(url: str, filename: str = "") -> str:
filename = url.split("/")[-1] if len(filename) == 0 else filename
# Download
bytesio = io.BytesIO(requests.get(url).content)
# Save file
with open(filename, "wb") as outfile:
outfile.write(bytesio.getbuffer())
return filename
url = "https://raw.githubusercontent.com/kornia/data/main/panda.jpg"
download_image(url)import cv2
import kornia as K
import numpy as np
import torch
from kornia.augmentation import AugmentationSequential
from kornia.geometry import bbox_to_mask
from matplotlib import pyplot as pltdef plot_resulting_image(img, bbox, keypoints, mask):
img = img * mask
img_array = K.tensor_to_image(img.mul(255).byte()).copy()
img_draw = cv2.polylines(img_array, bbox.numpy(), isClosed=True, color=(255, 0, 0))
for k in keypoints[0]:
img_draw = cv2.circle(img_draw, tuple(k.numpy()[:2]), radius=6, color=(255, 0, 0), thickness=-1)
return img_draw
img_tensor = K.io.load_image("panda.jpg", K.io.ImageLoadType.RGB32)[None, ...] # BxCxHxW
h, w = img_tensor.shape[-2:]
plt.axis("off")
plt.imshow(K.tensor_to_image(img_tensor))
plt.show()
Define Augmentation Sequential and Different Labels
aug_list = AugmentationSequential(
K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
K.augmentation.RandomAffine(360, [0.1, 0.1], [0.7, 1.2], [30.0, 50.0], p=1.0),
K.augmentation.RandomPerspective(0.5, p=1.0),
data_keys=["input", "bbox", "keypoints", "mask"],
same_on_batch=False,
)
bbox = torch.tensor([[[[355, 10], [660, 10], [660, 250], [355, 250]]]])
keypoints = torch.tensor([[[465, 115], [545, 116]]])
mask = bbox_to_mask(torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]]), w, h)[None].float()
img_out = plot_resulting_image(img_tensor, bbox[0], keypoints, mask[0])
plt.axis("off")
plt.imshow(img_out)
plt.show()
Forward Computations
out_tensor = aug_list(img_tensor, bbox.float(), keypoints.float(), mask)
img_out = plot_resulting_image(
out_tensor[0][0],
out_tensor[1].int(),
out_tensor[2].int(),
out_tensor[3][0],
)
plt.axis("off")
plt.imshow(img_out)
plt.show()
Inverse Transformations
out_tensor_inv = aug_list.inverse(*out_tensor)
img_out = plot_resulting_image(
out_tensor_inv[0][0],
out_tensor_inv[1].int(),
out_tensor_inv[2].int(),
out_tensor_inv[3][0],
)
plt.axis("off")
plt.imshow(img_out)
plt.show()