%%capture
!pip install git+https://github.com/kornia/kornia.git
Adaptive Discriminator Augmentation Preset
Theory
The core idea of the Adaptive Discriminator Augmentation technique (ADA for short) is to apply a series of image augmentations to both real and generated images before passing them to the Discriminator, making its task harder.
Augmentations are applied with a (global) probability p (i.e. whether we apply the augmentations or not). The series of augmentations should be chosen so that the underlying distribution remains learnable.
We restrict the Discriminator's capacity to identify real images by dynamically updating p: the probability of applying augmentations increases when the Discriminator performs too well (dominating the Generator), and vice versa.
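To make the update rule concrete, here is a minimal sketch of the adaptive heuristic (an illustration only, not kornia's actual implementation; the names update_p, target_real_acc and adjustment_speed are ours):

def update_p(p, real_acc, target_real_acc=0.6, adjustment_speed=0.01):
    if real_acc > target_real_acc:
        p += adjustment_speed  # Discriminator too strong -> augment more
    else:
        p -= adjustment_speed  # Discriminator struggling -> augment less
    return min(max(p, 0.0), 1.0)  # keep p a valid probability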
Downloading the images we will use to explore ADA, along with a batch of 16 images for a dummy training run of a simple GAN:
import io
import zipfile

import requests


def download_images():
    # kornia panda illustration
    jpg_url = "https://raw.githubusercontent.com/kornia/data/main/panda.jpg"
    jpg_path = "panda.jpg"
    r = requests.get(jpg_url)
    if r.status_code == 200:
        with open(jpg_path, "wb") as f:
            f.write(r.content)
    else:
        print("failed to download panda image")

    # a batch of 16 panda images
    zip_url = "https://github.com/kornia/data/raw/main/presets/ada_pandas.zip"
    r = requests.get(zip_url)
    if r.status_code == 200:
        with zipfile.ZipFile(io.BytesIO(r.content)) as zip_ref:
            zip_ref.extractall(".")
    else:
        print("failed to download pandas batch of images")


download_images()
In this tutorial, we will primarily use kornia and pytorch.
import os
import kornia
import kornia.augmentation as K
import matplotlib.pyplot as plt
import numpy as np
import torch
from kornia.augmentation.presets.ada import AdaptiveDiscriminatorAugmentation
from torch import nn, optim
from tqdm import tqdm
seed = 12
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

panda_image = kornia.io.load_image("panda.jpg", kornia.io.ImageLoadType.RGB32)  # C, H, W
panda_images = panda_image.unsqueeze(0).repeat(8, 1, 1, 1)  # 8, C, H, W
def plot_images(images_tensor, n_rows=4, figscale=1):
    """Utility that expects a B, C, H, W tensor and plots the images as a grid."""
    b, c, h, w = images_tensor.shape
    assert len(images_tensor) % n_rows == 0
    n_cols = len(images_tensor) // n_rows
    images_tensor = images_tensor.view(n_rows, n_cols, c, h, w).permute(2, 0, 3, 1, 4).reshape(c, h * n_rows, w * n_cols)
    plt.figure(figsize=(h // figscale, w // figscale))
    plt.imshow(kornia.tensor_to_image(images_tensor))
    plt.axis("off")
    plt.show()


plot_images(panda_images[:4], n_rows=1, figscale=10)
ADA forward pass results
We can use the AdaptiveDiscriminatorAugmentation API with the default list of augmentations or with a custom list; we start with the default ones in this example.
default_ada = AdaptiveDiscriminatorAugmentation()
default_ada.p = 0.5
augmented_panda_images = default_ada(panda_images)
plot_images(augmented_panda_images, n_rows=2, figscale=20)
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.49608037..1.4770097].
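The warning is expected: some of the augmentations (e.g. brightness and contrast jitter) push pixel values outside the [0, 1] range that imshow expects. If you want to silence it, a possible fix for the visualization only (not part of the training pipeline) is to clamp before plotting:

# clamp only for display; training should see the unclamped tensors
plot_images(augmented_panda_images.clamp(0, 1), n_rows=2, figscale=20)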
custom_ada = AdaptiveDiscriminatorAugmentation(
    K.RandomChannelShuffle(p=1.0, keepdim=True),
    K.RandomEqualize(p=1.0, keepdim=True),
    K.RandomHorizontalFlip(p=0.5),
    K.ColorJitter(0.15, 0.25, 0.25, 0.25),
)
custom_ada.p = 0.5
augmented_panda_images = custom_ada(panda_images)
plot_images(augmented_panda_images, n_rows=2, figscale=20)
Training a simple Generative Adversarial Network (GAN) with ADA
Note: this implementation isn't meant or optimized to generate high-quality images; it is tuned to demonstrate the ADA updates during training with minimal compute, memory, and time.
Though the following training loop may run on a CPU in a reasonable time, we highly recommend running it on a GPU if available.
class Generator(nn.Module):
    def __init__(self, lat_dim=64, init_channels=64):
        super().__init__()
        self.lat_dim = lat_dim
        self.init_channels = init_channels
        self.fc = nn.Linear(lat_dim, init_channels * 8 * 1 * 2)
        block1 = self._block(init_channels * 8, init_channels * 4)  # 1x2 -> 4x8
        block2 = self._block(init_channels * 4, init_channels * 2)  # 4x8 -> 16x32
        block3 = self._block(init_channels * 2, init_channels)  # 16x32 -> 64x128

        self.blocks = nn.Sequential(block1, block2, block3)
        self.final_conv = nn.Conv2d(init_channels, 3, kernel_size=3, padding=1)
        self.activation = nn.Tanh()

    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, z):
        x = self.fc(z).view(z.size(0), self.init_channels * 8, 1, 2)
        x = self.blocks(x)
        x = self.final_conv(x)
        return self.activation(x)
class Discriminator(nn.Module):
    def __init__(self, init_channels=64):
        super().__init__()
        self.init_channels = init_channels
        self.initial_conv = nn.Conv2d(3, init_channels, kernel_size=3, padding=1)
        block1 = self._block(init_channels, init_channels * 2)  # 64x128 -> 32x64
        block2 = self._block(init_channels * 2, init_channels * 4)  # 32x64 -> 16x32

        self.blocks = nn.Sequential(block1, block2)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),
            nn.Linear(init_channels * 4 * 4 * 4, 1),
        )

    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, img):
        x = self.initial_conv(img)
        x = self.blocks(x)
        out = self.classifier(x)
        return out
def create_batch(images_folder):
    images = []
    for image_file in os.listdir(images_folder):
        image = kornia.io.load_image(os.path.join(images_folder, image_file), kornia.io.ImageLoadType.RGB32)
        image = (image.float() / 0.5) - 1  # map [0, 1] to (-1, 1), matching the generator's tanh output
        images.append(image)
    batch = torch.stack(images, dim=0)
    return batch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

G = Generator(lat_dim=64).to(device)
D = Discriminator().to(device)

opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=36e-7, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()

real_images = create_batch("pandas")
real_images = real_images.to(device)

real_labels = torch.ones(real_images.size(0), 1).to(device)
fake_labels = torch.zeros(real_images.size(0), 1).to(device)

z = torch.randn(4, 64).to(device)
fake_images = G(z)
print(fake_images.shape)
output = D(fake_images)
print(output.shape)
torch.Size([4, 3, 64, 128])
torch.Size([4, 1])
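As an optional sanity check (reusing the plot_images helper defined above), we can undo the (-1, 1) normalization and look at a few of the real training images:

# map the (-1, 1)-normalized batch back to [0, 1] for display
plot_images(real_images[:4].cpu() * 0.5 + 0.5, n_rows=1, figscale=10)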
Given the nature of the experiment we're running (looping on a single batch of size 16), we pick the following settings so that the evolution of p is observable over relatively few steps:
- the adjustment speed is set to a slightly high value
- p is updated at each step
- the target real accuracy for the Discriminator to maintain is 0.6
- the exponential moving average lambda is set to 0.5
ada = AdaptiveDiscriminatorAugmentation(
    # you can pass the custom augmentations you want to use here,
    adjustment_speed=0.05,
    update_every=1,
    target_real_acc=0.6,
    ema_lambda=0.5,
)
In the training loop we apply ADA to both generated and real images, yet we only pass the current real_acc value once per step: p is updated once per step (with either the real or the generated images).
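To illustrate the feedback mechanism in isolation, here is a throwaway sketch (a separate instance with dummy tensors, so the ada state used in training below stays untouched):

demo_ada = AdaptiveDiscriminatorAugmentation(
    adjustment_speed=0.05, update_every=1, target_real_acc=0.6, ema_lambda=0.5
)
dummy_real = torch.rand(4, 3, 64, 128)
dummy_fake = torch.rand(4, 3, 64, 128)
_ = demo_ada(dummy_real)  # no real_acc passed -> no update of p on this call
_ = demo_ada(dummy_fake, real_acc=0.9)  # feedback passed once -> p is updated
print(demo_ada.p)  # should have moved upward, since 0.9 > the 0.6 target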
lat_dim = 64
epochs = 480

real_acc = None
real_acc_history = []
ada_p_history = []

pbar = tqdm(range(1, epochs + 1))
for epoch in pbar:
    opt_D.zero_grad()
    opt_G.zero_grad()

    perm = torch.randperm(real_images.size(0))
    real_imgs = real_images[perm].to(device)
    real_imgs = ada(real_imgs)

    real_logits = D(real_imgs)
    loss_real = criterion(real_logits, real_labels)

    # Discriminator step
    noise = torch.randn(real_imgs.size(0), lat_dim).to(device)
    fake_imgs = G(noise)
    fake_imgs = ada(fake_imgs, real_acc=real_acc)
    fake_logits = D(fake_imgs.detach())
    loss_fake = criterion(fake_logits, fake_labels)

    D_loss = loss_real + loss_fake

    D_loss.backward()
    opt_D.step()

    fake_acc = (fake_logits < 0).float().mean().item()
    real_acc = (real_logits > 0).float().mean().item()

    # Generator step
    pred_fake = D(fake_imgs)
    G_loss = criterion(pred_fake, real_labels)

    G_loss.backward()
    opt_G.step()

    real_acc_history.append(real_acc)
    ada_p_history.append(ada.p)

    pbar.set_postfix(
        D_loss=f"{D_loss.item():.4f}",
        G_loss=f"{G_loss.item():.4f}",
        real_acc=f"{real_acc:.4f}",
        fake_acc=f"{fake_acc:.4f}",
        p=ada.p,
    )
100%|██████████| 480/480 [00:48<00:00, 9.95it/s, D_loss=1.3424, G_loss=0.7023, fake_acc=0.6875, p=0.45, real_acc=0.8750]
Plotting a smoothed line plot of the real_acc and p values over the training steps, with a horizontal line at the target real accuracy:
window = 32
steps = np.arange(len(real_acc_history))
smoothed_steps = steps[window - 1 :]

smoothed_real_acc = np.convolve(real_acc_history, np.ones(window) / window, mode="valid")
smoothed_ada_p = np.convolve(ada_p_history, np.ones(window) / window, mode="valid")
target_acc_line = np.ones_like(steps) * ada.target_real_acc

plt.figure(figsize=(16, 8))
plt.plot(steps, ada_p_history, color="blue", alpha=0.1)
plt.plot(steps, real_acc_history, color="red", alpha=0.1)

plt.plot(steps, target_acc_line, color="black", label="target real accuracy", linewidth=0.5)
plt.plot(smoothed_steps, smoothed_real_acc, color="red", label="real accuracy", linewidth=2)
plt.plot(smoothed_steps, smoothed_ada_p, color="blue", label="ADA p", linewidth=2)

plt.legend(
    loc="lower center",
    ncol=3,
    bbox_to_anchor=(0.5, 0.01),
)
Notice how the Discriminator's accuracy on real images hovers around the target value we set, dynamically enforced by the image augmentations governed by p.
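For completeness, we can also peek at what the (deliberately under-trained) Generator produces, mapping its tanh output from (-1, 1) back to [0, 1] for display with the plot_images helper:

with torch.no_grad():
    samples = G(torch.randn(8, lat_dim).to(device))
plot_images((samples.cpu() * 0.5 + 0.5).clamp(0, 1), n_rows=2, figscale=10)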