%%capture
!pip install kornia
!pip install kornia-rsData Augmentation 2D
Basic
2D
Data augmentation
kornia.augmentation
A showcase of the data augmentation operations available in Kornia for images.
Just some simple examples showing the augmentations available in Kornia.
For more information check the docs: https://kornia.readthedocs.io/en/latest/augmentation.module.html
import kornia
import kornia.utils
import matplotlib.pyplot as plt
import numpy as np
import torch
from kornia.augmentation import (
    CenterCrop,
    ColorJiggle,
    ColorJitter,
    PadTo,
    RandomAffine,
    RandomBoxBlur,
    RandomBrightness,
    RandomChannelShuffle,
    RandomContrast,
    RandomCrop,
    RandomCutMixV2,
    RandomElasticTransform,
    RandomEqualize,
    RandomErasing,
    RandomFisheye,
    RandomGamma,
    RandomGaussianBlur,
    RandomGaussianNoise,
    RandomGrayscale,
    RandomHorizontalFlip,
    RandomHue,
    RandomInvert,
    RandomJigsaw,
    RandomMixUpV2,
    RandomMosaic,
    RandomMotionBlur,
    RandomPerspective,
    RandomPlanckianJitter,
    RandomPlasmaBrightness,
    RandomPlasmaContrast,
    RandomPlasmaShadow,
    RandomPosterize,
    RandomResizedCrop,
    RandomRGBShift,
    RandomRotation,
    RandomSaturation,
    RandomSharpness,
    RandomSolarize,
    RandomThinPlateSpline,
    RandomVerticalFlip,
)
from PIL import Image

# Load an Image
The augmentations expect an image tensor of shape BxCxHxW (batch, channels, height, width).
import io
import requests
def download_image(url: str, filename: str = "", timeout: float = 30.0) -> str:
    """Download ``url`` to a local file and return the file name.

    Args:
        url: address of the file to fetch.
        filename: local path to write to. When empty (or otherwise falsy), the
            last ``/``-separated component of ``url`` is used.
        timeout: seconds to wait for the server, forwarded to ``requests.get``
            so a stalled connection cannot block forever.

    Returns:
        The path the downloaded bytes were written to.

    Raises:
        requests.HTTPError: if the server answers with an error status. This
            prevents an HTML error page from being saved under an image name.
    """
    if not filename:
        filename = url.split("/")[-1]

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    # The body is already bytes; write it directly (no BytesIO round-trip).
    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename
url = "https://raw.githubusercontent.com/kornia/data/main/panda.jpg"
# Use the path download_image actually wrote instead of repeating the file name.
image_path = download_image(url)

# Convert to RGB, then to a float tensor scaled by 1/255, with a leading batch
# dimension added by [None, ...]. The context manager closes the file handle.
with Image.open(image_path) as pil_image:
    img = (kornia.image_to_tensor(np.array(pil_image.convert("RGB"))).float() / 255.0)[None, ...]


def plot_tensor(data, title=""):
    """Show each image of a batch side by side in a single figure.

    Args:
        data: 4-D image batch, unpacked as ``(b, c, h, w)``.
        title: title placed above the whole figure.
    """
    b, _, h, w = data.shape
    # squeeze=False always yields a 2-D grid of Axes, so a batch of one needs
    # no special case.
    fig, axes = plt.subplots(1, b, dpi=150, subplot_kw={"aspect": "equal"}, squeeze=False)
    for idx, ax in enumerate(axes[0]):
        # Some augmentations (e.g. additive Gaussian noise with a non-zero mean)
        # can leave values outside [0, 1]; clamp for display so imshow does not
        # warn about clipping float RGB data.
        ax.imshow(kornia.utils.tensor_to_image(data[idx, ...].clamp(0.0, 1.0)))
        # Image-style axes: row 0 at the top, tick labels along the top edge.
        ax.set_ylim(h, 0)
        ax.set_xlim(0, w)
        ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
    fig.suptitle(title)
    plt.show()


plot_tensor(img, "panda")

# 2D transforms
Sometimes you may want to apply exactly the same transformation to every element of a batch. For this, all random augmentations accept a `same_on_batch` keyword argument: instead of drawing parameters independently for each element, they then draw a single set of parameters and use it for the whole batch.
# Create a batched input: repeat `img` `num_samples` times along its first
# (batch) dimension, leaving the other dimensions unchanged.
num_samples = 2
inpt = img.repeat(num_samples, 1, 1, 1)

# Intensity
Random Planckian Jitter
# Planckian jitter in "blackbody" mode, applied to every sample (p=1.0).
randomplanckianjitter = RandomPlanckianJitter("blackbody", same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomplanckianjitter(inpt), "Planckian Jitter")

# Random Plasma Shadow
# Plasma-based shadow with explicit roughness, shade-intensity and
# shade-quantity ranges, applied to every sample (p=1.0).
randomplasmashadow = RandomPlasmaShadow(
    roughness=(0.1, 0.7),
    shade_intensity=(-1.0, 0.0),
    shade_quantity=(0.0, 1.0),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomplasmashadow(inpt), "Plasma Shadow")

# Random Plasma Brightness
# Plasma-based brightness change with explicit roughness and intensity ranges.
randomplasmabrightness = RandomPlasmaBrightness(
    roughness=(0.1, 0.7),
    intensity=(0.0, 1.0),
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomplasmabrightness(inpt), "Plasma Brightness")

# Random Plasma Contrast
# Plasma-based contrast change with an explicit roughness range.
randomplasmacontrast = RandomPlasmaContrast(roughness=(0.1, 0.7), same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomplasmacontrast(inpt), "Plasma Contrast")

# Color Jiggle
# ColorJiggle with its four positional jitter arguments all set to 0.3.
colorjiggle = ColorJiggle(0.3, 0.3, 0.3, 0.3, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(colorjiggle(inpt), "Color Jiggle")

# Color Jitter
# ColorJitter with its four positional jitter arguments all set to 0.3.
colorjitter = ColorJitter(0.3, 0.3, 0.3, 0.3, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(colorjitter(inpt), "Color Jitter")

# Random Box Blur
# Random box blur with a (21, 5) kernel and "reflect" border handling.
randomboxblur = RandomBoxBlur((21, 5), "reflect", same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomboxblur(inpt), "Box Blur")

# Random Brightness
# Random brightness with factors drawn from (0.8, 1.2) and clip_output=True.
randombrightness = RandomBrightness(brightness=(0.8, 1.2), clip_output=True, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randombrightness(inpt), "Random Brightness")

# Random Channel Shuffle
# Randomly permute the channels of every sample (p=1.0).
randomchannelshuffle = RandomChannelShuffle(same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomchannelshuffle(inpt), "Random Channel Shuffle")

# Random Contrast
# Random contrast with factors drawn from (0.8, 1.2) and clip_output=True.
randomcontrast = RandomContrast(contrast=(0.8, 1.2), clip_output=True, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomcontrast(inpt), "Random Contrast")

# Random Equalize
# Random equalization, applied to every sample (p=1.0).
randomequalize = RandomEqualize(same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomequalize(inpt), "Random Equalize")

# Random Gamma
# Random gamma correction: gamma range (0.2, 1.3), gain range (1.0, 1.5).
randomgamma = RandomGamma((0.2, 1.3), (1.0, 1.5), same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomgamma(inpt), "Random Gamma")

# Random Grayscale
# Random conversion to grayscale, applied to every sample (p=1.0).
randomgrayscale = RandomGrayscale(same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomgrayscale(inpt), "Random Grayscale")

# Random Gaussian Blur
# Random Gaussian blur: 21x21 kernel, sigma range (0.2, 1.3), "reflect" borders.
randomgaussianblur = RandomGaussianBlur((21, 21), (0.2, 1.3), "reflect", same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomgaussianblur(inpt), "Random Gaussian Blur")

# Random Gaussian Noise
# Additive Gaussian noise (mean=0.2, std=0.7); results can leave the [0, 1] range.
randomgaussiannoise = RandomGaussianNoise(mean=0.2, std=0.7, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomgaussiannoise(inpt), "Random Gaussian Noise")

# Random Hue
# Random hue shift drawn from (-0.2, 0.4).
randomhue = RandomHue((-0.2, 0.4), same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomhue(inpt), "Random Hue")

# Random Motion Blur
# Random motion blur: kernel size (7, 7), angle 35.0, direction 0.5,
# "reflect" borders and "nearest" resampling.
randommotionblur = RandomMotionBlur((7, 7), 35.0, 0.5, "reflect", "nearest", same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randommotionblur(inpt), "Random Motion Blur")

# Random Posterize
# Random posterization with bits=3.
randomposterize = RandomPosterize(bits=3, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomposterize(inpt), "Random Posterize")

# Random RGB Shift
# Random per-channel RGB shift with a limit of 0.5 for each of R, G and B.
randomrgbshift = RandomRGBShift(
    r_shift_limit=0.5,
    g_shift_limit=0.5,
    b_shift_limit=0.5,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomrgbshift(inpt), "Random RGB Shift")

# Random Saturation
# Random saturation change. The factor range must not be (1.0, 1.0): a factor
# of 1.0 leaves saturation unchanged, so that range would show no effect.
randomsaturation = RandomSaturation((0.5, 1.5), same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomsaturation(inpt), "Random Saturation")

# Random Sharpness
# Random sharpness adjustment with a (0.5, 1.0) range.
randomsharpness = RandomSharpness((0.5, 1.0), same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomsharpness(inpt), "Random Sharpness")

# Random Solarize
# Random solarization; thresholds=0.3 and additions=0.1 are passed positionally.
randomsolarize = RandomSolarize(0.3, 0.1, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomsolarize(inpt), "Random Solarize")

# Geometric
Center Crop
# Center crop of size 150 using the "resample" cropping mode.
centercrop = CenterCrop(
    150,
    resample="nearest",
    cropping_mode="resample",
    align_corners=True,
    keepdim=False,
    p=1.0,
)
plot_tensor(centercrop(inpt), "Center Crop")

# Pad To
# Pad every image to 500x500 in "constant" mode with value 1.
padto = PadTo((500, 500), "constant", 1, keepdim=False)
plot_tensor(padto(inpt), "Pad To")

# Random Affine
# Random affine warp: degrees (-15, 5), translate (0.3, 1.0), scale (0.4, 1.3),
# shear 0.5, nearest resampling with reflection padding.
randomaffine = RandomAffine(
    (-15.0, 5.0),
    (0.3, 1.0),
    (0.4, 1.3),
    0.5,
    resample="nearest",
    padding_mode="reflection",
    align_corners=True,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomaffine(inpt), "Random Affine")

# Random Crop
# Random 150x150 crop: padding 10, pad_if_needed=True, fill 1, "constant"
# padding mode, nearest resampling.
randomcrop = RandomCrop(
    (150, 150),
    10,
    True,
    1,
    "constant",
    "nearest",
    cropping_mode="resample",
    same_on_batch=False,
    align_corners=True,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomcrop(inpt), "Random Crop")

# Random Erasing
# Random erasing with scale (0.02, 0.33), ratio (0.3, 3.3), fill value 1.
randomerasing = RandomErasing(
    scale=(0.02, 0.33),
    ratio=(0.3, 3.3),
    value=1,
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomerasing(inpt), "Random Erasing")

# Random Elastic Transform
# Random elastic deformation: kernel (27, 27), sigma (33, 31), alpha (0.5, 1.5),
# with reflection padding.
randomelastictransform = RandomElasticTransform(
    (27, 27),
    (33, 31),
    (0.5, 1.5),
    align_corners=True,
    padding_mode="reflection",
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomelastictransform(inpt), "Random Elastic Transform")

# Random Fish Eye
# Random fisheye distortion: `c` is passed for both the x and y center
# arguments and `g` for gamma.
c = torch.tensor([-0.3, 0.3])
g = torch.tensor([0.9, 1.0])
randomfisheye = RandomFisheye(c, c, g, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomfisheye(inpt), "Random Fish Eye")

# Random Horizontal Flip
# Horizontal flip with p=0.7, so some samples may be left unflipped.
randomhorizontalflip = RandomHorizontalFlip(same_on_batch=False, keepdim=False, p=0.7)
plot_tensor(randomhorizontalflip(inpt), "Random Horizontal Flip")

# Random Invert
# Random color inversion, applied to every sample (p=1.0).
randominvert = RandomInvert(same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randominvert(inpt), "Random Invert")

# Random Perspective
# Random perspective warp with distortion scale 0.5 and nearest resampling.
randomperspective = RandomPerspective(0.5, "nearest", align_corners=True, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomperspective(inpt), "Random Perspective")

# Random Resized Crop
# Random resized crop to 200x200: scale (0.4, 1.0), fixed ratio (2.0, 2.0),
# nearest resampling.
randomresizedcrop = RandomResizedCrop(
    (200, 200),
    (0.4, 1.0),
    (2.0, 2.0),
    "nearest",
    align_corners=True,
    cropping_mode="resample",
    same_on_batch=False,
    keepdim=False,
    p=1.0,
)
plot_tensor(randomresizedcrop(inpt), "Random Resized Crop")

# Random Rotation
# Random rotation with degrees=15.0 and nearest resampling.
randomrotation = RandomRotation(15.0, "nearest", align_corners=True, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomrotation(inpt), "Random Rotation")

# Random Vertical Flip
# Vertical flip with per-sample p=0.6 and batch-level p_batch=1.0.
randomverticalflip = RandomVerticalFlip(same_on_batch=False, keepdim=False, p=0.6, p_batch=1.0)
plot_tensor(randomverticalflip(inpt), "Random Vertical Flip")

# Random Thin Plate Spline
# Random thin-plate-spline warp with scale 0.6, applied to every sample (p=1.0).
randomthinplatespline = RandomThinPlateSpline(0.6, align_corners=True, same_on_batch=False, keepdim=False, p=1.0)
# Plot the output of the TPS augmentation built on the line above.
plot_tensor(randomthinplatespline(inpt), "Random Thin Plate Spline")

# Mix
Random Cut Mix
# CutMix across the batch: num_mix=4, cut_size (0.2, 0.9), beta 0.1.
randomcutmixv2 = RandomCutMixV2(4, (0.2, 0.9), 0.1, same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randomcutmixv2(inpt), "Random Cut Mix")

# Random Mix Up
# MixUp across the batch with the mixing coefficient drawn from (0.1, 0.9).
randommixupv2 = RandomMixUpV2((0.1, 0.9), same_on_batch=False, keepdim=False, p=1.0)
plot_tensor(randommixupv2(inpt), "Random Mix Up")

# Random Mosaic
# Random mosaic: output size (250, 125), 4x4 grid, start ratio range (0.3, 0.7),
# "reflect" padding and nearest resampling.
randommosaic = RandomMosaic(
    (250, 125),
    (4, 4),
    (0.3, 0.7),
    align_corners=True,
    cropping_mode="resample",
    padding_mode="reflect",
    resample="nearest",
    keepdim=False,
    p=1.0,
)
plot_tensor(randommosaic(inpt), "Random Mosaic")

# Random Jigsaw
# Random Jigsaw demo, currently disabled (both lines below are commented out).
# NOTE(review): the reason it was disabled is not shown here, while RandomJigsaw
# is still imported at the top of the file — confirm whether this cell can be
# re-enabled (it may depend on the image size suiting the (2, 2) grid) or
# whether the import should be dropped.
# randomjigsaw = RandomJigsaw((2, 2), ensure_perm=False, same_on_batch=False, keepdim=False, p=1.0)
# plot_tensor(randomjigsaw(inpt), "Random Jigsaw")