%%capture
%matplotlib inline
# Install latest kornia
!pip install kornia
!pip install kornia-rs
Extracting and Combining Tensor Patches
Basic
Patches
kornia.contrib
In this tutorial we will show how you can extract and combine tensor patches using kornia
Install and get data
import io
import requests
def download_image(url: str, filename: str = "") -> str:
    """Download the file at ``url`` and save it to local disk.

    Args:
        url: Address of the file to fetch.
        filename: Destination path. When empty, the last path component of
            ``url`` is used.

    Returns:
        The path the downloaded content was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    filename = filename or url.split("/")[-1]
    # Download. The timeout keeps a stalled server from hanging the notebook,
    # and raise_for_status() prevents an HTTP error page from being saved as
    # if it were the image.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Save file
    with open(filename, "wb") as outfile:
        outfile.write(response.content)
    return filename
url = "https://raw.githubusercontent.com/kornia/data/main/panda.jpg"
download_image(url)
'panda.jpg'
import kornia as K
import matplotlib.pyplot as plt
import torch
from kornia.contrib import (
CombineTensorPatches,
ExtractTensorPatches,
combine_tensor_patches,
compute_padding,
extract_tensor_patches,
)
Using Modules
# Round trip with the module API: tile a random batch, then stitch it back.
h, w = 8, 8
win = 4
pad = 2
image = torch.randn(2, 3, h, w)
print(image.shape)
# stride == window_size, so consecutive windows do not overlap
tiler = ExtractTensorPatches(window_size=win, stride=win, padding=pad)
# the merger receives the same window, stride and pad values as the tiler
merger = CombineTensorPatches(original_size=(h, w), window_size=win, stride=win, unpadding=pad)
image_tiles = tiler(image)
print(image_tiles.shape)
new_image = merger(image_tiles)
print(new_image.shape)
assert torch.allclose(image, new_image)
torch.Size([2, 3, 8, 8])
torch.Size([2, 9, 3, 4, 4])
torch.Size([2, 3, 8, 8])
Using Functions
# Same round trip with the functional API, on a single one-channel image.
h, w = 8, 8
win = 4
pad = 2
image = torch.randn(1, 1, h, w)
print(image.shape)
# stride == window_size, so consecutive windows do not overlap
patches = extract_tensor_patches(image, window_size=win, stride=win, padding=pad)
print(patches.shape)
# combine with the same window, stride and pad values used for extraction
restored_img = combine_tensor_patches(patches, original_size=(h, w), window_size=win, stride=win, unpadding=pad)
print(restored_img.shape)
assert torch.allclose(image, restored_img)
torch.Size([1, 1, 8, 8])
torch.Size([1, 9, 1, 4, 4])
torch.Size([1, 1, 8, 8])
Padding
All parameters of extract and combine functions accept a single int or tuple of two ints. Since padding is an integral part of these functions, it’s important to note the following:
- If padding is p -> it means both height and width are padded by 2*p
- If padding is (ph, pw) -> it means height is padded by 2*ph and width is padded by 2*pw
It is recommended to use the existing function compute_padding to ensure the required padding is added.
Examples
def extract_and_combine(image, window_size, padding):
    """Tile ``image`` into patches, merge them back and check the round trip.

    The stride is set equal to ``window_size``, and ``padding`` is passed both
    as the extractor's ``padding`` and as the combiner's ``unpadding``. The
    shapes of the patch tensor and of the merged image are printed.

    Args:
        image: Tensor whose last two dimensions are height and width.
        window_size: Patch size, forwarded to both modules.
        padding: Padding forwarded to the extractor and, as ``unpadding``,
            to the combiner.

    Returns:
        The image rebuilt from the patches.

    Raises:
        AssertionError: If the rebuilt image is not close to ``image``.
    """
    h, w = image.shape[-2:]
    tiler = ExtractTensorPatches(window_size=window_size, stride=window_size, padding=padding)
    merger = CombineTensorPatches(original_size=(h, w), window_size=window_size, stride=window_size, unpadding=padding)
    image_tiles = tiler(image)
    print(f"Shape of tensor patches = {image_tiles.shape}")
    merged_image = merger(image_tiles)
    print(f"Shape of merged image = {merged_image.shape}")
    assert torch.allclose(image, merged_image)
    return merged_image


image = torch.randn(2, 3, 9, 9)
_ = extract_and_combine(image, window_size=(4, 4), padding=(2, 2))
Shape of tensor patches = torch.Size([2, 9, 3, 4, 4])
Shape of merged image = torch.Size([2, 3, 9, 9])
These functions also work with rectangular images
rect_image = torch.randn(1, 1, 8, 6)
print(rect_image.shape)
torch.Size([1, 1, 8, 6])
restored_image = extract_and_combine(rect_image, window_size=(4, 4), padding=compute_padding((8, 6), 4))
Shape of tensor patches = torch.Size([1, 4, 1, 4, 4])
Shape of merged image = torch.Size([1, 1, 8, 6])
Recall that when padding is a tuple of ints (ph, pw), the height and width are padded by 2*ph and 2*pw respectively.
# Confirm that the original image and restored image are the same
assert (restored_image == rect_image).all()
Let’s now visualize how extraction and combining works.
# Load sample image; indexing with [None, ...] prepends a batch dimension
img_tensor = K.io.load_image("panda.jpg", K.io.ImageLoadType.RGB32)[None, ...] # BxCxHxW
# Keep the spatial size; it is passed as original_size when merging the tiles
h, w = img_tensor.shape[-2:]
print(f"Shape of image = {img_tensor.shape}")
# Hide the axes before drawing the image
plt.axis("off")
plt.imshow(K.tensor_to_image(img_tensor))
plt.show()
Shape of image = torch.Size([1, 3, 510, 1020])

We will use window_size = (400, 400) with stride = 200 to extract 15 overlapping tiles of shape (400, 400) and visualize them.
# Set window size
win = 400
# Set stride (smaller than the window, so neighbouring tiles overlap)
stride = 200
# Calculate required padding; (510, 1020) is the height and width of the loaded image
pad = compute_padding(original_size=(510, 1020), window_size=win)
# Extract the tiles from the batched image tensor
tiler = ExtractTensorPatches(window_size=win, stride=stride, padding=pad)
image_tiles = tiler(img_tensor)
print(f"Shape of image tiles = {image_tiles.shape}")
Shape of image tiles = torch.Size([1, 15, 3, 400, 400])
# Create the plot: one subplot per tile of the first (and only) batch item
fig, axs = plt.subplots(5, 3, figsize=(8, 8))
# Flatten the 5x3 grid of axes so it can be walked in step with the tiles
axs = axs.ravel()
for ax, tile in zip(axs, image_tiles[0]):
    ax.axis("off")
    ax.imshow(K.tensor_to_image(tile))
plt.show()
Finally, let’s combine the patches and visualize the resulting image
# Build the merger with the same window, stride and pad values used by the tiler
merger = CombineTensorPatches(original_size=(h, w), window_size=win, stride=stride, unpadding=pad)
merged_image = merger(image_tiles)
print(f"Shape of restored image = {merged_image.shape}")
# Show the first batch item of the restored image without axes
plt.imshow(K.tensor_to_image(merged_image[0]))
plt.axis("off")
plt.show()
Shape of restored image = torch.Size([1, 3, 510, 1020])
