Image Modality Guide¤
This guide covers working with image data in Workshop, including image representations, datasets, preprocessing, and best practices for image-based generative models.
Overview¤
Workshop's image modality provides a unified interface for working with different image formats and resolutions. It supports RGB, RGBA, and grayscale images with configurable preprocessing and augmentation.
- Multiple Representations: Support for RGB, RGBA, and grayscale images with automatic channel handling
- Flexible Resolutions: Work with any image size from 28x28 to 512x512 and beyond
- Preprocessing Pipeline: Built-in normalization, resizing, and validation
- Synthetic Datasets: Ready-to-use synthetic datasets for testing and development
- Augmentation: Common image augmentation techniques (flip, rotate, brightness, contrast)
- JAX-Native: Full JAX compatibility with JIT compilation and GPU acceleration
Image Representations¤
Supported Formats¤
Workshop supports three image representations:
from workshop.generative_models.modalities.image.base import ImageRepresentation
# RGB images (3 channels)
ImageRepresentation.RGB
# RGBA images (4 channels with alpha)
ImageRepresentation.RGBA
# Grayscale images (1 channel)
ImageRepresentation.GRAYSCALE
Configuring Image Modality¤
from workshop.generative_models.modalities import ImageModality
from workshop.generative_models.modalities.image.base import (
ImageModalityConfig,
ImageRepresentation
)
from flax import nnx
# Initialize RNG
rngs = nnx.Rngs(0)
# RGB configuration (64x64)
rgb_config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=64,
width=64,
channels=3, # Auto-determined from representation if None
normalize=True, # Normalize to [0, 1]
augmentation=False,
resize_method="bilinear"
)
# Create modality
rgb_modality = ImageModality(config=rgb_config, rngs=rngs)
# RGBA configuration
rgba_config = ImageModalityConfig(
representation=ImageRepresentation.RGBA,
height=128,
width=128,
channels=4 # Alpha channel included
)
rgba_modality = ImageModality(config=rgba_config, rngs=rngs)
# Grayscale configuration
grayscale_config = ImageModalityConfig(
representation=ImageRepresentation.GRAYSCALE,
height=28,
width=28,
channels=1
)
grayscale_modality = ImageModality(config=grayscale_config, rngs=rngs)
Image Shape Properties¤
# Access image dimensions
print(f"Image shape: {rgb_modality.image_shape}") # (64, 64, 3)
print(f"Output shape: {rgb_modality.output_shape}") # (64, 64, 3)
# For MNIST-like
print(f"Grayscale shape: {grayscale_modality.image_shape}") # (28, 28, 1)
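The configuration comments above say that `channels` is determined from the representation when it is left as `None`. The sketch below shows one way such a lookup could work. `Representation`, `resolve_channels`, and `image_shape` are hypothetical stand-ins written for this illustration, not Workshop's classes, and raising on a mismatch is an assumption rather than documented behavior.

```python
from enum import Enum


class Representation(Enum):
    """Hypothetical stand-in for ImageRepresentation (illustration only)."""
    RGB = "rgb"
    RGBA = "rgba"
    GRAYSCALE = "grayscale"


# Channel count implied by each representation
CHANNELS = {
    Representation.RGB: 3,
    Representation.RGBA: 4,
    Representation.GRAYSCALE: 1,
}


def resolve_channels(representation, channels=None):
    """Return the channel count, deriving it when `channels` is None."""
    expected = CHANNELS[representation]
    if channels is None:
        return expected
    if channels != expected:
        # Assumption: a mismatch is treated as a configuration error
        raise ValueError(
            f"{representation.name} expects {expected} channels, got {channels}"
        )
    return channels


def image_shape(representation, height, width, channels=None):
    """Build the (H, W, C) shape for a configuration."""
    return (height, width, resolve_channels(representation, channels))


print(image_shape(Representation.RGB, 64, 64))        # (64, 64, 3)
print(image_shape(Representation.GRAYSCALE, 28, 28))  # (28, 28, 1)
```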
Image Datasets¤
Synthetic Image Datasets¤
Workshop provides several synthetic dataset types for testing and development:
Random Patterns¤
from workshop.generative_models.modalities.image.datasets import SyntheticImageDataset
# Random noise patterns
random_dataset = SyntheticImageDataset(
config=rgb_config,
dataset_size=10000,
pattern_type="random",
split="train",
rngs=rngs
)
# Get batch
batch = random_dataset.get_batch(batch_size=32)
print(batch["images"].shape) # (32, 64, 64, 3)
# Each image is filled with uniform random noise
Gradient Patterns¤
# Linear gradients with varying directions
gradient_dataset = SyntheticImageDataset(
config=rgb_config,
dataset_size=10000,
pattern_type="gradient",
split="train",
rngs=rngs
)
# Gradients have:
# - Random directions
# - Smooth color transitions (for RGB)
# - Sinusoidal variations for visual interest
Checkerboard Patterns¤
# Checkerboard patterns with random sizes
checkerboard_dataset = SyntheticImageDataset(
config=rgb_config,
dataset_size=10000,
pattern_type="checkerboard",
split="train",
rngs=rngs
)
# Checkerboards have:
# - Random tile sizes (4-16 pixels)
# - Binary black/white pattern
# - Repeated across color channels
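To make the checkerboard description concrete, here is a minimal sketch of how a binary tile pattern can be built and repeated across channels in JAX. `make_checkerboard` is a hypothetical helper for illustration; how `SyntheticImageDataset` generates its patterns internally (including which tile is black) is not shown in this guide.

```python
import jax.numpy as jnp


def make_checkerboard(height, width, tile_size, channels=3):
    """Build a binary checkerboard of shape (height, width, channels)."""
    # Tile index of every row and column
    rows = jnp.arange(height)[:, None] // tile_size
    cols = jnp.arange(width)[None, :] // tile_size
    # Tiles alternate whenever the row and column tile indices sum to an odd number
    board = ((rows + cols) % 2).astype(jnp.float32)  # (H, W)
    # Repeat the same pattern across all channels
    return jnp.repeat(board[:, :, None], channels, axis=-1)


board = make_checkerboard(64, 64, tile_size=8)
print(board.shape)  # (64, 64, 3)
```

Sampling `tile_size` from the 4-16 pixel range per image would reproduce the "random tile sizes" property listed above.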
Circular Patterns¤
# Circular patterns with random positions and radii
circles_dataset = SyntheticImageDataset(
config=rgb_config,
dataset_size=10000,
pattern_type="circles",
split="train",
rngs=rngs
)
# Circles have:
# - Random center positions
# - Random radii
# - Gaussian noise for variation
MNIST-Like Datasets¤
For digit-like pattern recognition:
from workshop.generative_models.modalities.image.datasets import MNISTLikeDataset
# Configure for MNIST-like images (28x28 grayscale)
mnist_config = ImageModalityConfig(
representation=ImageRepresentation.GRAYSCALE,
height=28,
width=28,
channels=1,
normalize=True
)
# Create MNIST-like dataset
mnist_dataset = MNISTLikeDataset(
config=mnist_config,
dataset_size=60000,
num_classes=10,
split="train",
rngs=rngs
)
# Get labeled batch
batch = mnist_dataset.get_batch(batch_size=128)
print(batch["images"].shape) # (128, 28, 28, 1)
print(batch["labels"].shape) # (128,)
# Iterate with labels
for sample in mnist_dataset:
    image = sample["images"]  # (28, 28, 1)
    label = sample["labels"]  # Scalar label
    print(f"Label: {label}, Image shape: {image.shape}")
    break
Generated patterns:
- Class 0: Circle (hollow)
- Class 1: Vertical line
- Class 2: Horizontal line
- Additional classes follow similar geometric patterns
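The first three class patterns listed above can be sketched with simple coordinate masks. The helpers below are hypothetical illustrations; the radius, line width, and exact placement used by `MNISTLikeDataset` are not documented here, so the defaults are assumptions.

```python
import jax.numpy as jnp


def hollow_circle(size=28, radius=8.0, thickness=1.5):
    """Ring of ones around the image center, shape (size, size, 1)."""
    ys, xs = jnp.mgrid[0:size, 0:size]
    center = (size - 1) / 2.0
    dist = jnp.sqrt((ys - center) ** 2 + (xs - center) ** 2)
    # Keep only pixels whose distance from the center is close to the radius
    ring = jnp.abs(dist - radius) <= thickness
    return ring.astype(jnp.float32)[..., None]


def vertical_line(size=28, half_width=1):
    """Vertical bar through the middle column, shape (size, size, 1)."""
    cols = jnp.abs(jnp.arange(size) - size // 2) <= half_width
    bar = jnp.broadcast_to(cols[None, :], (size, size))
    return bar.astype(jnp.float32)[..., None]


def horizontal_line(size=28, half_width=1):
    """Horizontal bar: the vertical bar with rows and columns swapped."""
    return jnp.transpose(vertical_line(size, half_width), (1, 0, 2))


print(hollow_circle().shape)  # (28, 28, 1)
```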
Factory Function¤
from workshop.generative_models.modalities.image.datasets import create_image_dataset
# Create dataset using factory
dataset = create_image_dataset(
dataset_type="synthetic", # or "mnist_like"
config=rgb_config,
pattern_type="gradient",
dataset_size=5000,
rngs=rngs
)
# MNIST-like via factory
mnist = create_image_dataset(
dataset_type="mnist_like",
config=mnist_config,
dataset_size=60000,
num_classes=10,
rngs=rngs
)
Image Preprocessing¤
Normalization¤
import jax.numpy as jnp
# Images in [0, 255] → [0, 1]
def normalize_uint8_images(images):
    """Normalize uint8 images to [0, 1]."""
    return images.astype(jnp.float32) / 255.0

# Images in [0, 1] → [-1, 1]
def normalize_to_symmetric(images):
    """Normalize to [-1, 1] range."""
    return images * 2.0 - 1.0

# Standardization (mean=0, std=1), computed per image over H, W, C
def standardize_images(images):
    """Standardize images to zero mean, unit variance."""
    mean = jnp.mean(images, axis=(1, 2, 3), keepdims=True)
    std = jnp.std(images, axis=(1, 2, 3), keepdims=True)
    return (images - mean) / (std + 1e-8)
# Usage
raw_images = jnp.array([...]) # Raw pixel values
normalized = normalize_uint8_images(raw_images)
standardized = standardize_images(normalized)
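When inspecting model outputs it is often useful to invert these mappings. The sketch below maps values from [-1, 1] back to [0, 1] and then to `uint8`; `denormalize_from_symmetric` and `to_uint8` are hypothetical helper names introduced here for illustration.

```python
import jax.numpy as jnp


def denormalize_from_symmetric(images):
    """Map values from [-1, 1] back to [0, 1]."""
    return (images + 1.0) / 2.0


def to_uint8(images):
    """Map [0, 1] floats to uint8 in [0, 255], rounding and clipping."""
    return jnp.clip(jnp.round(images * 255.0), 0, 255).astype(jnp.uint8)


# Round trip: uint8 -> [0, 1] -> [-1, 1] -> [0, 1] -> uint8
pixels = jnp.array([[0, 64, 128, 255]], dtype=jnp.uint8)
symmetric = (pixels.astype(jnp.float32) / 255.0) * 2.0 - 1.0
restored = to_uint8(denormalize_from_symmetric(symmetric))
```

Rounding before the cast matters: a plain `astype(jnp.uint8)` truncates, which can shift values down by one after floating-point round-off.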
Resizing¤
import jax
from jax import image as jax_image
def resize_images(images, target_height, target_width, method="bilinear"):
    """Resize images to target dimensions.
    Args:
        images: Input images (N, H, W, C)
        target_height: Target height
        target_width: Target width
        method: Resize method ("bilinear" or "nearest")
    Returns:
        Resized images (N, target_height, target_width, C)
    """
    if method not in ("bilinear", "nearest"):
        raise ValueError(f"Unknown resize method: {method}")
    batch_size = images.shape[0]
    channels = images.shape[3]
    # Use JAX's resize function; batch and channel dimensions are left unchanged
    return jax_image.resize(
        images,
        shape=(batch_size, target_height, target_width, channels),
        method=method
    )
# Usage
images = jnp.array([...]) # (N, 32, 32, 3)
resized = resize_images(images, 64, 64, method="bilinear")
print(resized.shape) # (N, 64, 64, 3)
Using Modality Processor¤
# Create modality with preprocessing
config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=64,
width=64,
normalize=True
)
modality = ImageModality(config=config, rngs=rngs)
# Process raw images
raw_images = jnp.array([...]) # Any shape
processed = modality.process(raw_images)
# Processed images:
# - Resized to (64, 64)
# - Normalized to [0, 1]
# - Batch dimension handled automatically
Image Augmentation¤
Basic Augmentations¤
import jax
import jax.numpy as jnp
def random_horizontal_flip(image, key, prob=0.5):
    """Randomly flip image horizontally.
    Args:
        image: Input image (H, W, C)
        key: Random key
        prob: Probability of flipping
    Returns:
        Flipped or original image
    """
    flip = jax.random.bernoulli(key, prob)
    return jax.lax.cond(
        flip,
        lambda img: jnp.flip(img, axis=1),
        lambda img: img,
        image
    )

def random_vertical_flip(image, key, prob=0.5):
    """Randomly flip image vertically."""
    flip = jax.random.bernoulli(key, prob)
    return jax.lax.cond(
        flip,
        lambda img: jnp.flip(img, axis=0),
        lambda img: img,
        image
    )

def random_rotation(image, key):
    """Randomly rotate image by 0, 90, 180, or 270 degrees.
    Args:
        image: Input image (H, W, C); must be square (H == W) so that
            every rotation produces the same output shape
        key: Random key
    Returns:
        Rotated image
    """
    k = jax.random.randint(key, (), 0, 4)
    # jnp.rot90 needs a static k, and int(k) fails on a traced value under
    # jit or vmap, so pick one of the four fixed rotations with lax.switch.
    branches = [
        lambda img, i=i: jnp.rot90(img, k=i, axes=(0, 1)) for i in range(4)
    ]
    return jax.lax.switch(k, branches, image)
# Usage
key = jax.random.key(0)
keys = jax.random.split(key, 3)
image = jnp.array([...]) # (H, W, C)
image = random_horizontal_flip(image, keys[0])
image = random_vertical_flip(image, keys[1])
image = random_rotation(image, keys[2])
Color Augmentations¤
def random_brightness(image, key, delta=0.2):
    """Randomly adjust brightness.
    Args:
        image: Input image (H, W, C)
        key: Random key
        delta: Maximum brightness change
    Returns:
        Brightness-adjusted image
    """
    factor = jax.random.uniform(key, minval=1-delta, maxval=1+delta)
    return jnp.clip(image * factor, 0, 1)

def random_contrast(image, key, delta=0.2):
    """Randomly adjust contrast.
    Args:
        image: Input image (H, W, C)
        key: Random key
        delta: Maximum contrast change
    Returns:
        Contrast-adjusted image
    """
    factor = jax.random.uniform(key, minval=1-delta, maxval=1+delta)
    mean = jnp.mean(image)
    return jnp.clip((image - mean) * factor + mean, 0, 1)

def random_saturation(image, key, delta=0.2):
    """Randomly adjust saturation (RGB only).
    Args:
        image: Input RGB image (H, W, 3)
        key: Random key
        delta: Maximum saturation change
    Returns:
        Saturation-adjusted image
    """
    factor = jax.random.uniform(key, minval=1-delta, maxval=1+delta)
    # Convert to grayscale (unweighted channel mean)
    gray = jnp.mean(image, axis=-1, keepdims=True)
    # Interpolate between gray and original
    adjusted = gray + factor * (image - gray)
    return jnp.clip(adjusted, 0, 1)

def random_hue(image, key, delta=0.1):
    """Randomly adjust hue (RGB only).
    Args:
        image: Input RGB image (H, W, 3)
        key: Random key
        delta: Maximum hue change
    Returns:
        Hue-adjusted image
    """
    factor = jax.random.uniform(key, minval=-delta, maxval=delta)
    # Simple approximation of a hue shift: blend each channel toward the next
    r, g, b = image[..., 0], image[..., 1], image[..., 2]
    # Rotate through channels
    shifted = jnp.stack([
        r + factor * (g - r),
        g + factor * (b - g),
        b + factor * (r - b)
    ], axis=-1)
    return jnp.clip(shifted, 0, 1)
# Usage
key = jax.random.key(0)
keys = jax.random.split(key, 4)
rgb_image = jnp.array([...]) # (H, W, 3)
rgb_image = random_brightness(rgb_image, keys[0])
rgb_image = random_contrast(rgb_image, keys[1])
rgb_image = random_saturation(rgb_image, keys[2])
rgb_image = random_hue(rgb_image, keys[3])
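`random_saturation` above uses the unweighted channel mean as its grayscale reference. A common alternative is a perceptually weighted luma, for example the Rec. 601 weights (0.299, 0.587, 0.114). The sketch below shows that variant with a fixed factor; `rgb_to_luma` and `adjust_saturation` are hypothetical names for illustration, and whether Workshop's own augmentation uses weighted luma is not stated in this guide.

```python
import jax.numpy as jnp

# Rec. 601 luma weights for R, G, B
LUMA_WEIGHTS = jnp.array([0.299, 0.587, 0.114])


def rgb_to_luma(image):
    """Weighted grayscale of an RGB image, shape (..., 1)."""
    return jnp.sum(image * LUMA_WEIGHTS, axis=-1, keepdims=True)


def adjust_saturation(image, factor):
    """Scale color away from (factor > 1) or toward (factor < 1) the luma."""
    gray = rgb_to_luma(image)
    return jnp.clip(gray + factor * (image - gray), 0.0, 1.0)


green = jnp.array([[[0.0, 1.0, 0.0]]])        # one pure-green pixel, (1, 1, 3)
desaturated = adjust_saturation(green, 0.0)   # every channel equals the luma
unchanged = adjust_saturation(green, 1.0)     # factor 1 leaves the image as-is
```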
Noise Augmentations¤
def add_gaussian_noise(image, key, std=0.05):
    """Add Gaussian noise to image.
    Args:
        image: Input image (H, W, C)
        key: Random key
        std: Standard deviation of noise
    Returns:
        Noisy image
    """
    noise = std * jax.random.normal(key, image.shape)
    return jnp.clip(image + noise, 0, 1)

def add_salt_pepper_noise(image, key, prob=0.01):
    """Add salt and pepper noise.
    Args:
        image: Input image (H, W, C)
        key: Random key
        prob: Probability of noise per pixel
    Returns:
        Noisy image
    """
    keys = jax.random.split(key, 2)
    # Salt (white pixels)
    salt_mask = jax.random.bernoulli(keys[0], prob, image.shape)
    image = jnp.where(salt_mask, 1.0, image)
    # Pepper (black pixels)
    pepper_mask = jax.random.bernoulli(keys[1], prob, image.shape)
    image = jnp.where(pepper_mask, 0.0, image)
    return image

def add_speckle_noise(image, key, std=0.1):
    """Add multiplicative speckle noise.
    Args:
        image: Input image (H, W, C)
        key: Random key
        std: Standard deviation of noise
    Returns:
        Noisy image
    """
    noise = 1 + std * jax.random.normal(key, image.shape)
    return jnp.clip(image * noise, 0, 1)
# Usage
key = jax.random.key(0)
keys = jax.random.split(key, 3)
image = jnp.array([...]) # (H, W, C)
noisy1 = add_gaussian_noise(image, keys[0], std=0.05)
noisy2 = add_salt_pepper_noise(image, keys[1], prob=0.01)
noisy3 = add_speckle_noise(image, keys[2], std=0.1)
Geometric Augmentations¤
def random_crop(image, key, crop_height, crop_width):
    """Randomly crop image.
    Args:
        image: Input image (H, W, C)
        key: Random key
        crop_height: Height of crop
        crop_width: Width of crop
    Returns:
        Cropped image
    """
    h, w, c = image.shape
    # Random starting position, using independent keys for each axis
    key_top, key_left = jax.random.split(key)
    top = jax.random.randint(key_top, (), 0, h - crop_height + 1)
    left = jax.random.randint(key_left, (), 0, w - crop_width + 1)
    # dynamic_slice accepts traced start indices, so this also works under jit;
    # Python slicing with a traced `top` or `left` would not
    return jax.lax.dynamic_slice(
        image, (top, left, 0), (crop_height, crop_width, c)
    )

def center_crop(image, crop_height, crop_width):
    """Center crop image.
    Args:
        image: Input image (H, W, C)
        crop_height: Height of crop
        crop_width: Width of crop
    Returns:
        Center-cropped image
    """
    h, w = image.shape[:2]
    top = (h - crop_height) // 2
    left = (w - crop_width) // 2
    return image[top:top+crop_height, left:left+crop_width]

def random_zoom(image, key, zoom_range=(0.8, 1.2)):
    """Randomly zoom image.
    Not jit-compatible: the intermediate size depends on the sampled factor.
    Args:
        image: Input image (H, W, C)
        key: Random key
        zoom_range: (min_zoom, max_zoom)
    Returns:
        Zoomed image
    """
    h, w, c = image.shape
    zoom_factor = float(
        jax.random.uniform(key, minval=zoom_range[0], maxval=zoom_range[1])
    )
    # Calculate new size
    new_h = int(h * zoom_factor)
    new_w = int(w * zoom_factor)
    # Resize
    zoomed = jax.image.resize(image, shape=(new_h, new_w, c), method="bilinear")
    # Crop or pad to original size
    if zoom_factor > 1.0:
        # Crop
        zoomed = center_crop(zoomed, h, w)
    else:
        # Pad
        pad_h = (h - new_h) // 2
        pad_w = (w - new_w) // 2
        zoomed = jnp.pad(
            zoomed,
            ((pad_h, h - new_h - pad_h), (pad_w, w - new_w - pad_w), (0, 0)),
            mode='constant'
        )
    return zoomed
# Usage
key = jax.random.key(0)
keys = jax.random.split(key, 2)
image = jnp.array([...]) # (64, 64, 3)
cropped = random_crop(image, keys[0], 48, 48)
zoomed = random_zoom(image, keys[1], zoom_range=(0.9, 1.1))
Complete Augmentation Pipeline¤
@jax.jit
def augment_image(image, key):
    """Apply comprehensive augmentation pipeline.
    Args:
        image: Input image (H, W, C)
        key: Random key
    Returns:
        Augmented image
    """
    keys = jax.random.split(key, 8)
    # Geometric augmentations
    image = random_horizontal_flip(image, keys[0], prob=0.5)
    image = random_rotation(image, keys[1])
    # Color augmentations
    image = random_brightness(image, keys[2], delta=0.2)
    image = random_contrast(image, keys[3], delta=0.2)
    # RGB-specific (the shape is static, so this check is resolved at trace time)
    if image.shape[-1] == 3:
        image = random_saturation(image, keys[4], delta=0.2)
        image = random_hue(image, keys[5], delta=0.1)
    # Noise
    image = add_gaussian_noise(image, keys[6], std=0.02)
    return image

# Batch augmentation
def augment_batch(images, key):
    """Augment batch of images.
    Args:
        images: Batch of images (N, H, W, C)
        key: Random key
    Returns:
        Augmented batch
    """
    batch_size = images.shape[0]
    # One key per image so each image gets its own random augmentation
    keys = jax.random.split(key, batch_size)
    # Vectorize over batch
    augmented = jax.vmap(augment_image)(images, keys)
    return augmented

# Usage in training
key = jax.random.key(0)
for batch in data_loader:
    key, subkey = jax.random.split(key)
    augmented_batch = augment_batch(batch["images"], subkey)
    # Use augmented_batch for training
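`augment_batch` splits its key into one key per image before calling `jax.vmap`. The self-contained sketch below illustrates why that matters, using a single flip as the augmentation: with a shared key every image receives the same decision, while per-image keys give independent decisions. The function names are hypothetical and exist only for this comparison.

```python
import jax
import jax.numpy as jnp


def flip_with_key(image, key):
    """Flip one image horizontally with probability 0.5."""
    flip = jax.random.bernoulli(key, 0.5)
    return jnp.where(flip, jnp.flip(image, axis=1), image)


def flip_batch_shared(images, key):
    """Same key for every image: the whole batch is flipped, or none of it."""
    return jax.vmap(flip_with_key, in_axes=(0, None))(images, key)


def flip_batch_independent(images, key):
    """One key per image: each image gets its own flip decision."""
    keys = jax.random.split(key, images.shape[0])
    return jax.vmap(flip_with_key)(images, keys)


images = jnp.arange(8 * 4 * 4, dtype=jnp.float32).reshape(8, 4, 4, 1)
shared = flip_batch_shared(images, jax.random.key(0))
independent = flip_batch_independent(images, jax.random.key(0))
```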
Working with Different Image Sizes¤
Common Image Sizes¤
# MNIST-like (28x28 grayscale)
mnist_config = ImageModalityConfig(
representation=ImageRepresentation.GRAYSCALE,
height=28,
width=28
)
# CIFAR-like (32x32 RGB)
cifar_config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=32,
width=32
)
# Standard (64x64 RGB)
standard_config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=64,
width=64
)
# High-res (128x128 RGB)
highres_config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=128,
width=128
)
# Very high-res (256x256 RGB)
veryhighres_config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=256,
width=256
)
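Resolution has a direct effect on memory use. As a rough guide, the sketch below computes the size of one float32 input batch for each configuration above. `batch_bytes` is a hypothetical helper, and the figures cover only the input tensor, not activations, parameters, or optimizer state.

```python
def batch_bytes(batch_size, height, width, channels, bytes_per_value=4):
    """Bytes needed to hold one batch (float32 by default)."""
    return batch_size * height * width * channels * bytes_per_value


sizes = {
    "MNIST-like": (28, 28, 1),
    "CIFAR-like": (32, 32, 3),
    "Standard": (64, 64, 3),
    "High-res": (128, 128, 3),
    "Very high-res": (256, 256, 3),
}

for name, (h, w, c) in sizes.items():
    mib = batch_bytes(128, h, w, c) / 2**20
    print(f"{name}: {mib:.2f} MiB per batch of 128")
```

With a batch of 128, the input grows from about 0.38 MiB at 28x28 grayscale to 96 MiB at 256x256 RGB. Passing `bytes_per_value=2` gives the float16 figure.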
Handling Non-Square Images¤
# Wide images (16:9 aspect ratio)
wide_config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=360,
width=640
)
# Portrait images (9:16 aspect ratio)
portrait_config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=640,
width=360
)
# Custom aspect ratio
custom_config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=224,
width=448 # 2:1 aspect ratio
)
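When source images do not share the configured aspect ratio, resizing directly to the target shape stretches them. One common alternative is to scale the image to fit and pad the remainder. The sketch below shows that approach with `jax.image.resize`; `resize_with_padding` is a hypothetical helper, and this guide does not say whether `ImageModality.process` preserves aspect ratio.

```python
import jax
import jax.numpy as jnp


def resize_with_padding(image, target_height, target_width, pad_value=0.0):
    """Resize an (H, W, C) image to fit the target, padding to keep its aspect ratio."""
    h, w, c = image.shape
    # Largest scale at which the whole image still fits inside the target
    scale = min(target_height / h, target_width / w)
    new_h = min(target_height, max(1, round(h * scale)))
    new_w = min(target_width, max(1, round(w * scale)))
    resized = jax.image.resize(image, (new_h, new_w, c), method="bilinear")
    # Center the resized image and fill the borders with pad_value
    pad_top = (target_height - new_h) // 2
    pad_left = (target_width - new_w) // 2
    return jnp.pad(
        resized,
        (
            (pad_top, target_height - new_h - pad_top),
            (pad_left, target_width - new_w - pad_left),
            (0, 0),
        ),
        constant_values=pad_value,
    )


wide = jnp.ones((360, 640, 3))            # 16:9 source image
fitted = resize_with_padding(wide, 224, 448)
print(fitted.shape)  # (224, 448, 3)
```

The target sizes are Python integers, so the function is shape-static and can be wrapped in `jax.jit` with `static_argnums=(1, 2)`.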
Complete Examples¤
Example 1: Training with Augmentation¤
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.modalities import ImageModality
from workshop.generative_models.modalities.image import (
ImageModalityConfig,
ImageRepresentation,
SyntheticImageDataset
)
# Setup
rngs = nnx.Rngs(0)
config = ImageModalityConfig(
representation=ImageRepresentation.RGB,
height=64,
width=64,
normalize=True
)
modality = ImageModality(config=config, rngs=rngs)
# Create datasets
train_dataset = SyntheticImageDataset(
config=config,
dataset_size=10000,
pattern_type="gradient",
split="train",
rngs=rngs
)
val_dataset = SyntheticImageDataset(
config=config,
dataset_size=1000,
pattern_type="gradient",
split="val",
rngs=rngs
)
# Training loop with augmentation
batch_size = 128
num_epochs = 10
key = jax.random.key(42)
for epoch in range(num_epochs):
    # Training
    num_batches = len(train_dataset) // batch_size
    for i in range(num_batches):
        # Get batch
        batch = train_dataset.get_batch(batch_size)
        # Apply augmentation
        key, subkey = jax.random.split(key)
        augmented = augment_batch(batch["images"], subkey)
        # Training step (placeholder)
        # loss = train_step(model, augmented)
    # Validation (no augmentation)
    val_batches = len(val_dataset) // batch_size
    for i in range(val_batches):
        val_batch = val_dataset.get_batch(batch_size)
        # Validation step
        # val_loss = validate_step(model, val_batch["images"])
    print(f"Epoch {epoch + 1}/{num_epochs} complete")
Example 2: Multi-Resolution Training¤
# Create datasets at multiple resolutions
resolutions = [32, 64, 128]
datasets = {}
for res in resolutions:
    config = ImageModalityConfig(
        representation=ImageRepresentation.RGB,
        height=res,
        width=res
    )
    datasets[res] = SyntheticImageDataset(
        config=config,
        dataset_size=5000,
        pattern_type="random",
        rngs=rngs
    )
# Progressive training
for resolution in resolutions:
    print(f"Training at {resolution}x{resolution}")
    dataset = datasets[resolution]
    for epoch in range(5):
        for i in range(len(dataset) // 32):
            batch = dataset.get_batch(32)
            # Train at this resolution
            # loss = train_step(model, batch["images"], resolution)
        print(f"  Epoch {epoch + 1}/5 at {resolution}x{resolution}")
Example 3: Custom Image Dataset¤
import zlib
from typing import Iterator
from workshop.generative_models.modalities.base import BaseDataset

class CustomImageDataset(BaseDataset):
    """Custom dataset loading images from file paths."""

    def __init__(
        self,
        config: ImageModalityConfig,
        image_paths: list[str],
        labels: list[int] | None = None,
        split: str = "train",
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__(config, split, rngs=rngs)
        self.image_paths = image_paths
        self.labels = labels
        # Load and preprocess images
        self.images = self._load_images()

    def _load_images(self):
        """Load images from paths."""
        images = []
        for path in self.image_paths:
            # In practice, use PIL, OpenCV, etc.
            # For demo, generate synthetic images. The seed comes from a CRC32
            # of the path because Python's built-in hash() of a string changes
            # between interpreter runs, which would make the demo irreproducible.
            seed = zlib.crc32(path.encode("utf-8"))
            img = jax.random.uniform(
                jax.random.key(seed),
                (self.config.height, self.config.width, self.config.channels)
            )
            images.append(img)
        return images

    def __len__(self) -> int:
        return len(self.images)

    def __iter__(self) -> Iterator[dict[str, jax.Array]]:
        for i, image in enumerate(self.images):
            sample = {"images": image, "index": jnp.array(i)}
            if self.labels is not None:
                sample["labels"] = jnp.array(self.labels[i])
            yield sample

    def get_batch(self, batch_size: int) -> dict[str, jax.Array]:
        key = self.rngs.sample() if "sample" in self.rngs else jax.random.key(0)
        # Sample indices with replacement
        indices = jax.random.randint(key, (batch_size,), 0, len(self))
        batch_images = [self.images[int(idx)] for idx in indices]
        batch = {"images": jnp.stack(batch_images), "indices": indices}
        if self.labels is not None:
            batch_labels = [self.labels[int(idx)] for idx in indices]
            batch["labels"] = jnp.array(batch_labels)
        return batch
# Usage
image_paths = ["/path/to/img1.jpg", "/path/to/img2.jpg", ...]
labels = [0, 1, 0, 2, ...] # Optional labels
custom_dataset = CustomImageDataset(
config=config,
image_paths=image_paths,
labels=labels,
rngs=rngs
)
Best Practices¤
DO¤
Image Loading
- Use appropriate image resolution for your task
- Normalize images to [0, 1] or [-1, 1] consistently
- Choose representation that matches your data (RGB vs grayscale)
- Validate image shapes before training
- Cache preprocessed images when possible
- Use synthetic datasets for testing pipelines
Augmentation
- Apply augmentation only during training, not validation
- Use JIT compilation for augmentation pipelines
- Balance augmentation strength with training stability
- Apply geometric augmentations before color augmentations
- Use vectorized operations for batch augmentation
- Test augmentations visually before training
Performance
- Resize images to target resolution once
- Use JAX's native image operations for GPU acceleration
- Batch operations when possible
- Clear image cache periodically for long runs
- Profile image loading to identify bottlenecks
- Consider mixed precision (float16) for memory savings
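Several of the points above (consistent normalization, validated shapes) can be checked with a small guard that runs before training starts. The sketch below is one possible form; `validate_batch` is a hypothetical helper, not part of Workshop, and the default [0, 1] range assumes images were normalized as in the configurations shown earlier.

```python
import jax.numpy as jnp


def validate_batch(images, height, width, channels, value_range=(0.0, 1.0)):
    """Raise ValueError if a batch does not match the expected layout."""
    if images.ndim != 4:
        raise ValueError(f"Expected (N, H, W, C), got {images.ndim} dimensions")
    if images.shape[1:] != (height, width, channels):
        raise ValueError(
            f"Expected images of shape {(height, width, channels)}, "
            f"got {images.shape[1:]}"
        )
    if not jnp.issubdtype(images.dtype, jnp.floating):
        raise ValueError(f"Expected a floating dtype, got {images.dtype}")
    if not bool(jnp.all(jnp.isfinite(images))):
        raise ValueError("Batch contains NaN or infinite values")
    low, high = value_range
    if float(images.min()) < low or float(images.max()) > high:
        raise ValueError(f"Values fall outside [{low}, {high}]")


batch = jnp.full((2, 8, 8, 3), 0.5, dtype=jnp.float32)
validate_batch(batch, 8, 8, 3)  # passes silently
```

The checks pull values back to the host, so they are meant for a one-off sanity pass rather than for use inside a jitted training step.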
DON'T¤
Common Mistakes
- Mix different image resolutions in same batch
- Forget to normalize images
- Apply augmentation during validation/testing
- Use non-JAX operations in data pipeline
- Load full-resolution images if working with downscaled versions
- Ignore color space (RGB vs BGR)
- Use excessive augmentation that destroys image structure
Performance Issues
- Load images from disk in training loop
- Use Python loops for image processing
- Apply expensive augmentations without JIT
- Keep multiple copies of images in memory
- Use very large batch sizes on limited GPU memory
Quality Issues
- Over-augment images (too much distortion)
- Use inappropriate resize methods (nearest for photos)
- Mix normalized and unnormalized images
- Ignore aspect ratio when resizing
- Apply same augmentation to all images in batch
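The color-space point deserves a concrete illustration: OpenCV's `cv2.imread` returns channels in BGR order, so such images need reordering before they are treated as RGB. A minimal sketch, where `bgr_to_rgb` is a hypothetical helper that assumes exactly three channels in the last axis:

```python
import jax.numpy as jnp


def bgr_to_rgb(image):
    """Reverse the channel order of a 3-channel image or batch."""
    if image.shape[-1] != 3:
        raise ValueError(f"Expected 3 channels, got {image.shape[-1]}")
    return image[..., ::-1]


bgr_pixel = jnp.array([[[0.1, 0.2, 0.3]]])  # B=0.1, G=0.2, R=0.3
rgb_pixel = bgr_to_rgb(bgr_pixel)           # R=0.3, G=0.2, B=0.1
```

Because the operation is its own inverse, applying it twice returns the original array, which makes an accidental double conversion easy to miss.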
Summary¤
This guide covered:
- Image representations - RGB, RGBA, and grayscale configurations
- Image datasets - Synthetic datasets with various patterns
- Preprocessing - Normalization, resizing, and validation
- Augmentation - Geometric, color, and noise augmentations
- Different sizes - Working with various image resolutions
- Complete examples - Training with augmentation, multi-resolution, custom datasets
- Best practices - DOs and DON'Ts for image data
Next Steps¤
- Learn about text tokenization, vocabulary management, and sequence handling
- Audio waveform processing, spectrograms, and audio augmentation
- Working with multiple modalities and aligned multi-modal datasets
- Complete API documentation for dataset classes and functions