Fixel Algorithms

Image Augmentation#

Notebook by:

Revision History#

Version

Date

User

Content / Changes

1.0.000

02/06/2024

Royi Avital

First version

Open In Colab

# Import Packages

# General Tools
import numpy as np
import scipy as sp
import pandas as pd

# Machine Learning

# Deep Learning
import torch
import torch.nn            as nn
from torch.utils.tensorboard import SummaryWriter
import torchinfo
import torchvision
from torchvision.transforms import v2 as TorchVisionTrns

# Image Processing & Computer Vision
import skimage as ski

# Miscellaneous
import math
import os
from platform import python_version
import random
import time

# Typing
from typing import Any, Callable, Dict, Generator, List, Optional, Self, Set, Tuple, Union

# Visualization
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Jupyter
from IPython import get_ipython
from IPython.display import HTML, Image
from IPython.display import display
from ipywidgets import Dropdown, FloatSlider, interact, IntSlider, Layout, SelectionSlider
from ipywidgets import interact

Notations#

  • (?) Question to answer interactively.

  • (!) Simple task to add code for the notebook.

  • (@) Optional / Extra self practice.

  • (#) Note / Useful resource / Food for thought.

Code Notations:

someVar    = 2; #<! Notation for a variable
vVector    = np.random.rand(4) #<! Notation for 1D array
mMatrix    = np.random.rand(4, 3) #<! Notation for 2D array
tTensor    = np.random.rand(4, 3, 2, 3) #<! Notation for nD array (Tensor)
tuTuple    = (1, 2, 3) #<! Notation for a tuple
lList      = [1, 2, 3] #<! Notation for a list
dDict      = {1: 3, 2: 2, 3: 1} #<! Notation for a dictionary
oObj       = MyClass() #<! Notation for an object
dfData     = pd.DataFrame() #<! Notation for a data frame
dsData     = pd.Series() #<! Notation for a series
hObj       = plt.Axes() #<! Notation for an object / handler / function handler

Code Exercise#

  • Single line fill

vallToFill = ???
  • Multi Line to Fill (At least one)

# You need to start writing
????
  • Section to Fill

#===========================Fill This===========================#
# 1. Explanation about what to do.
# !! Remarks to follow / take under consideration.
mX = ???

???
#===============================================================#
# Configuration
# %matplotlib inline

seedNum = 512
np.random.seed(seedNum)
random.seed(seedNum)

# Matplotlib default color palette
lMatPltLibclr = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
# sns.set_theme() #>! Apply SeaBorn theme

runInGoogleColab = 'google.colab' in str(get_ipython())

# Improve performance by benchmarking
torch.backends.cudnn.benchmark = True

# Reproducibility (Per PyTorch Version on the same device)
# torch.manual_seed(seedNum)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark     = False #<! Makes things slower
# Constants

FIG_SIZE_DEF    = (8, 8)
ELM_SIZE_DEF    = 50
CLASS_COLOR     = ('b', 'r')
EDGE_COLOR      = 'k'
MARKER_SIZE_DEF = 10
LINE_WIDTH_DEF  = 2

DATA_FOLDER_PATH    = 'Data'
TENSOR_BOARD_BASE   = 'TB'
# Download Auxiliary Modules for Google Colab
if runInGoogleColab:
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DataManipulation.py
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DataVisualization.py
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DeepLearningPyTorch.py
# Courses Packages
import sys
sys.path.append('../../utils')
from DataVisualization import PlotLabelsHistogram, PlotMnistImages
from DeepLearningPyTorch import NNMode
from DeepLearningPyTorch import RunEpoch
# General Auxiliary Functions

def PlotTransform( lImages: List[torchvision.tv_tensors._image.Image], titleStr: str, bAxis: bool = False ) -> plt.Figure:
    """
    Display a row of images side by side (e.g. an image and its transformed versions).

    Parameters
    ----------
    lImages : List[torchvision.tv_tensors._image.Image]
        Images in (C, H, W) layout (they are permuted to (H, W, C) for display).
        Single channel images are displayed in gray scale.
    titleStr : str
        Title of the figure.
    bAxis : bool
        If True the axes (ticks / labels) are shown, otherwise hidden.

    Returns
    -------
    hF : plt.Figure
        The created figure.

    Raises
    ------
    ValueError
        If `lImages` is empty.
    """

    numImg = len(lImages)
    if numImg == 0:
        raise ValueError('`lImages` must contain at least one image')

    axWidth = 3 #<! Figure width allocated per image

    # Width ratio per axes is the image width, so wider images get wider axes
    lWidth = [tImg.shape[-1] for tImg in lImages]
    hF, _  = plt.subplots(nrows = 1, ncols = numImg, figsize = (numImg * axWidth, 5), gridspec_kw = {'width_ratios': lWidth})
    for ii, (hA, tImg) in enumerate(zip(hF.axes, lImages)):
        # (C, H, W) -> (H, W, C); detach / move to CPU so Matplotlib can convert it
        mI = torch.permute(tImg, (1, 2, 0)).detach().cpu()
        if mI.shape[-1] == 1:
            # Matplotlib rejects (H, W, 1) arrays; drop the channel so `cmap` is applied
            mI = mI[:, :, 0]
        hA.imshow(mI, cmap = 'gray')
        hA.set_title(f'{ii}')
        hA.axis('on' if bAxis else 'off')

    hF.suptitle(titleStr)

    return hF


def PlotBeta( α: float ) -> None:
    """
    Plot the PDF of the symmetric Beta distribution Beta(α, α) over [0, 1].

    Parameters
    ----------
    α : float
        Shape parameter, used for both the α and β parameters of the distribution.
    """

    vGrid = np.linspace(0.0, 1.0, num = 1001) #<! Grid over the support, endpoints included
    oBeta = sp.stats.beta(α, α) #<! Frozen symmetric distribution
    vPdf  = oBeta.pdf(vGrid)

    _, hAx = plt.subplots(figsize = (8, 6))

    hAx.plot(vGrid, vPdf, 'b', lw = 2)
    hAx.set_title(f'Beta($\\alpha={α:0.3f}$, $\\beta={α:0.3f}$)')
    hAx.set_ylim([0, 5]) #<! Fixed range so different α values are comparable
    hAx.grid()

def PlotAug( λ: Union[float, torch.Tensor], mI1: torch.Tensor, mI2: torch.Tensor, augStr: str, λVal: float ) -> None:
    """
    Blend 2 images and plot the result next to the corresponding (soft) label.

    The blended image is `λ * mI1 + (1 - λ) * mI2`:
     - MixUp:  `λ` is a scalar.
     - CutMix: `λ` is a mask which broadcasts against the images (1 -> `mI1`, 0 -> `mI2`).

    Parameters
    ----------
    λ : float or torch.Tensor
        Blending weight (Scalar or mask).
    mI1 : torch.Tensor
        1st image in (C, H, W) layout (it is permuted to (H, W, C) for display).
    mI2 : torch.Tensor
        2nd image, same layout and size as `mI1`.
    augStr : str
        Name of the augmentation, used in the titles.
    λVal : float
        Weight of class 0 (The class of `mI1`) in the label; class 1 gets `1 - λVal`.
    """
    mI = λ * mI1 + (1 - λ) * mI2 #<! Supports λ as a mask
    _, vHa = plt.subplots(nrows = 1, ncols = 2, figsize = (8, 5))

    # Left: The blended image
    hA = vHa[0]
    hA.imshow(mI.permute(1, 2, 0))
    hA.set_title(f'{augStr} ($\\lambda = {λVal:0.3f}$)')

    # Right: The blended label as a distribution over the 2 classes
    hA = vHa[1]
    hA.stem([0, 1], [λVal, 1 - λVal])
    hA.set_title(f'{augStr} Label ($\\lambda = {λVal:0.3f}$)')
    hA.set_xlabel('Class')
    hA.set_ylabel('Probability')
    hA.set_ylim([0, 1.05])

Image Augmentation - CutOut, MixUp, CutMix#

Several image augmentation techniques have been developed to further assist the generalization of models.
Some of these techniques involve manipulating 2 images and their labels.


Credit: Leonie Monigatti - Cutout, Mixup, and Cutmix: Implementing Modern Image Augmentations in PyTorch.

This notebook presents:

  • Working with the torchvision.transforms module.

  • Applying: CutOut, MixUp, CutMix.

This notebook augments both the image data and the labels.


  • (#) PyTorch Tutorial: How to Use CutMix and MixUp.

  • (#) Augmentation can be thought of as the set of operations the model should be insensitive to.
    For instance, if the model should be insensitive to shifts, it should be trained on the same image with different shifts.

  • (#) PyTorch Vision is migrating its transforms module from v1 to v2.
    This notebook will focus on v2.

  • (#) While the notebook shows image augmentation in the context of Deep Learning for Computer Vision, the Data Augmentation concept can be utilized for other tasks as well.
    For instance, for Audio Processing one could apply noise addition, pitch change, filters, etc…

  • (#) There are packages which specialize in image data augmentation: Kornia, Albumentations (Considered to be the fastest), ImgAug (Deprecated), AugLy (Audio, image, text and video).

# Parameters

# Data
imgFile1Url = r'https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/DeepLearningMethods/09_TipsAndTricks/img1.jpg'
imgFile2Url = r'https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/DeepLearningMethods/09_TipsAndTricks/img2.jpg'

img1Label = 0
img2Label = 1

# Model

# Training

# Visualization

Generate / Load Data#

# Load Data

mI1 = ski.io.imread(imgFile1Url)

# Image Dimensions
print(f'Image Dimensions: {mI1.shape[:2]}')
print(f'Image Number of Channels: {mI1.shape[2]}')
print(f'Image Element Type: {mI1.dtype}')
Image Dimensions: (450, 300)
Image Number of Channels: 3
Image Element Type: uint8
# Load Data

mI2 = ski.io.imread(imgFile2Url)

# Image Dimensions
print(f'Image Dimensions: {mI2.shape[:2]}')
print(f'Image Number of Channels: {mI2.shape[2]}')
print(f'Image Element Type: {mI2.dtype}')
Image Dimensions: (183, 275)
Image Number of Channels: 3
Image Element Type: uint8
  • (#) The image is a NumPy array. The default image loader of torchvision datasets uses PIL (Through Pillow, its maintained fork), where the image is a PIL Image object.

Plot the Data#

# Plot the Data

hF, hA = plt.subplots(figsize = (4, 6))

hA.imshow(mI1)
hA.tick_params(axis = 'both', left = False, top = False, right = False, bottom = False, 
               labelleft = False, labeltop = False, labelright = False, labelbottom = False)
hA.grid(False)
hA.set_title('Input Image 1');
../../../_images/6bf5ebf8169147d15ed7ff4384b5103cf575a6d845142eaa373b8b20b4d6b419.png
# Plot the Data

hF, hA = plt.subplots(figsize = (4, 6))

hA.imshow(mI2)
hA.tick_params(axis = 'both', left = False, top = False, right = False, bottom = False, 
               labelleft = False, labeltop = False, labelright = False, labelbottom = False)
hA.grid(False)
hA.set_title('Input Image 2');
../../../_images/e61c25607375b32dad188c15efaf2822a57d5cfee786cb0f2a55a54b57bbc6fc.png
# Tensor Image (Scaled)

oToImg = TorchVisionTrns.Compose([
    TorchVisionTrns.ToImage(),
    TorchVisionTrns.ToDtype(dtype = torch.float32, scale = True),
])

tI1 = oToImg(mI1)
tI2 = oToImg(mI2)

print(f'Tensor Type: {type(tI1)}')
print(f'Tensor Dimensions: {tI1.shape}')
print(f'Image Element Type: {tI1.dtype}')
print(f'Image Minimum Value: {torch.min(tI1)}')
print(f'Image Maximum Value: {torch.max(tI1)}')
Tensor Type: <class 'torchvision.tv_tensors._image.Image'>
Tensor Dimensions: torch.Size([3, 450, 300])
Image Element Type: torch.float32
Image Minimum Value: 0.0
Image Maximum Value: 1.0

Cut Out (Random Erasing)#

Randomly erases a rectangle on the image.
It regularizes the model, the intuition being that it acts like a “Dropout” layer applied to the input.

# RandomErasing
oTran = TorchVisionTrns.RandomErasing(p = 1, value = 0)

lTrnImg = [tI1] + [oTran(tI1) for _ in range(6)]
hF = PlotTransform(lTrnImg, 'RandomErasing', True)
../../../_images/e9e382f94b7c222dcca6b02c1dd8d2729437ee390894b187187c6fa4b1a3f385.png
# RandomErasing
oTran = TorchVisionTrns.RandomErasing(p = 1, value = 'random')

lTrnImg = [tI2] + [oTran(tI2) for _ in range(6)]
hF = PlotTransform(lTrnImg, 'RandomErasing', True)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
../../../_images/ae0f57f1098530cb6b5ea410406420905f0adb151d2465df735b47319535d254.png

MixUp#

Samples a parameter \(\lambda\) from a Beta Distribution:

\[ \lambda \sim \text{Beta} \left( \alpha, \alpha \right), \qquad \alpha \in \left( 0, 1 \right]\]

Using this parameter, the image and the label are blended:

\[\begin{split} \begin{align*} \hat{\boldsymbol{x}} & = \lambda \boldsymbol{x}_{i} + \left( 1 - \lambda \right) \boldsymbol{x}_{j} \\ \hat{\boldsymbol{y}} & = \lambda \boldsymbol{y}_{i} + \left( 1 - \lambda \right) \boldsymbol{y}_{j} \\ \end{align*} \end{split}\]

Where \(\boldsymbol{x}_{i}, \boldsymbol{x}_{j}\) are 2 input vectors and \(\boldsymbol{y}_{i}, \boldsymbol{y}_{j}\) are their one-hot label encodings.

Beta Distribution#

# Beta Distribution

interact(PlotBeta, α = FloatSlider(min = 0.01, max = 0.99, step = 0.01, value = 0.5, layout = Layout(width = '30%')));
  • (#) For \(\alpha \to 0\) the distribution approaches a Bernoulli Distribution (Mass only at \(0\) and \(1\), each with probability \(0.5\)).

  • (#) For \(\alpha \to 1\) the distribution becomes \(\mathcal{U} \left[ 0 , 1 \right]\).

  • (#) Usually \(\alpha\) is chosen so that values of \(\lambda\) close to \(0\) or \(1\) are the most probable (One image dominates the mix). Hence \(\alpha\) is relatively small most of the time.

# MixUp
oTran = TorchVisionTrns.Compose([
    TorchVisionTrns.ToImage(),
    TorchVisionTrns.ToDtype(dtype = torch.float32, scale = True),
    TorchVisionTrns.Resize(224),
    TorchVisionTrns.CenterCrop(224),
])

tI1 = oTran(mI1)
tI2 = oTran(mI2)

hPlotMixUp = lambda λ: PlotAug(λ, tI1, tI2, 'MixUp', λ)

interact(hPlotMixUp, λ = FloatSlider(min = 0.0, max = 1.0, step = 0.05, value = 0.0, layout = Layout(width = '30%')));

CutMix#

Samples a parameter \(\lambda\) from a Beta Distribution:

\[ \lambda \sim \text{Beta} \left( \alpha, \alpha \right), \qquad \alpha \in \left( 0, 1 \right]\]

Using this parameter, the image and the label are blended:

\[\begin{split} \begin{align*} \hat{\boldsymbol{X}} & = \boldsymbol{M} \odot \boldsymbol{X}_{i} + \left( \boldsymbol{1} \boldsymbol{1}^{T} - \boldsymbol{M} \right) \odot \boldsymbol{X}_{j} \\ \hat{\boldsymbol{y}} & = \lambda \boldsymbol{y}_{i} + \left( 1 - \lambda \right) \boldsymbol{y}_{j} \\ \end{align*} \end{split}\]

Where \(\boldsymbol{X}_{i}, \boldsymbol{X}_{j}\) are 2 input images of size \(H \times W\), \(\boldsymbol{y}_{i}, \boldsymbol{y}_{j}\) are their one-hot label encodings and \(\odot\) is the element-wise (Hadamard) product.
The binary data mask, \(\boldsymbol{M} \in {\left\{ 0, 1 \right\}}^{H \times W}\), is zero inside the bounding box \(\boldsymbol{b} = {\left[ x, y, w, h \right]}^{T}\) (Centered at \(\left( x, y \right)\)) and one elsewhere:

\[ x \sim \mathcal{U} \left[ 0, W \right], y \sim \mathcal{U} \left[ 0, H \right], w = \sqrt{1 - \lambda} W, h = \sqrt{1 - \lambda} H \]
# Generate Random Box

def RandBox( imgW: int, imgH: int, λ: float ) -> Tuple[int, int, int, int]:
    """
    Draw a random CutMix box, clipped to the image.

    The box center is drawn uniformly over the pixel grid and the box size is
    `w = floor(sqrt(1 - λ) * imgW)`, `h = floor(sqrt(1 - λ) * imgH)`, so the unclipped
    box covers about `1 - λ` of the image area.

    Parameters
    ----------
    imgW : int
        Image width (Pixels).
    imgH : int
        Image height (Pixels).
    λ : float
        Mixing parameter in [0, 1]. Larger λ -> smaller box.

    Returns
    -------
    (xLow, yLow, xHigh, yHigh) : Tuple[int, int, int, int]
        Half open bounds, to be used as `[yLow:yHigh, xLow:xHigh]`.
        As the box is clipped to the image, its area may be smaller than `w * h`.

    Raises
    ------
    ValueError
        If `λ` is not in [0, 1].
    """

    if not (0.0 <= λ <= 1.0):
        raise ValueError(f'λ must be in [0, 1], got {λ}')

    xCenter = int(np.random.randint(imgW))
    yCenter = int(np.random.randint(imgH))

    ratio = np.sqrt(1.0 - λ)
    w     = int(imgW * ratio)
    h     = int(imgH * ratio)

    # Start half a box before the center and span the full `w` / `h`.
    # Using `center ± w // 2` would give a box of only `w - 1` pixels for odd `w`.
    xStart = xCenter - (w // 2)
    yStart = yCenter - (h // 2)

    xLow  = max(xStart, 0)
    yLow  = max(yStart, 0)
    xHigh = min(xStart + w, imgW)
    yHigh = min(yStart + h, imgH)

    return xLow, yLow, xHigh, yHigh
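  • (#) In practice, the rectangle may be clipped by the image boundaries, making its area smaller than \(\left( 1 - \lambda \right) H W\). In that case \(\lambda\) should be recomputed from the actual box area (For instance, as the fraction of ones in \(\boldsymbol{M}\)).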

# Generate Mask

def GenMask( imgW: int, imgH: int, λ: float ) -> torch.Tensor:
    """
    Build a CutMix mask for an `imgH x imgW` image.

    Parameters
    ----------
    imgW : int
        Mask width.
    imgH : int
        Mask height.
    λ : float
        Mixing parameter, forwarded to `RandBox()` (Controls the box size).

    Returns
    -------
    tMask : torch.Tensor
        Float mask of shape (imgH, imgW): zeros inside the random box, ones elsewhere.
    """

    # Random box bounds, used as half open slices
    x0, y0, x1, y1 = RandBox(imgW, imgH, λ)

    tMask = torch.ones(imgH, imgW)
    tMask[y0:y1, x0:x1] = 0

    return tMask
# CutMix

def PlotCutMix( λ: float ) -> None:
    """
    Plot a CutMix sample of `tI1` and `tI2` for a given λ.

    The label weight is taken from the actual mask: the random box may be clipped by
    the image borders, so the fraction of pixels kept from `tI1` can exceed the nominal λ.
    """

    imgH, imgW = tI1.shape[-2:] #<! Images are (C, H, W)
    mM   = GenMask(imgW, imgH, λ) #<! (H, W): 1 -> `tI1`, 0 -> `tI2`
    λEff = mM.mean().item() #<! Actual fraction of pixels taken from `tI1`
    # Add a leading singleton (Channel) axis so the mask broadcasts over (C, H, W)
    PlotAug(mM[None, :, :], tI1, tI2, 'CutMix', λEff)

interact(PlotCutMix, λ = FloatSlider(min = 0.0, max = 1.0, step = 0.05, value = 0.0, layout = Layout(width = '30%')));