Multi-Modal Guide¤

This guide covers working with multi-modal data in Workshop, including aligned datasets, modality fusion, cross-modal generation, and best practices for multi-modal generative models.

Overview¤

Workshop's multi-modal system enables working with multiple data modalities (image, text, audio) simultaneously, supporting alignment, fusion, and cross-modal generation tasks.

  • Modality Alignment: Create aligned multi-modal datasets with shared latent representations
  • Paired Datasets: Handle explicitly paired multi-modal data with alignment scores
  • Modality Fusion: Combine information from multiple modalities for joint representations
  • Cross-Modal Generation: Generate one modality from another (e.g., image from text)
  • Shared Latent Space: Learn unified representations across modalities
  • JAX-Native: Full JAX compatibility with efficient batch processing

Multi-Modal Datasets¤

Aligned Multi-Modal Dataset¤

Create datasets with aligned samples across modalities:

import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.modalities.multi_modal.datasets import (
    create_synthetic_multi_modal_dataset
)

# Initialize RNG
rngs = nnx.Rngs(0)

# Create aligned multi-modal dataset
dataset = create_synthetic_multi_modal_dataset(
    modalities=["image", "text", "audio"],
    num_samples=1000,
    alignment_strength=0.8,  # 0.0 = random, 1.0 = perfectly aligned
    image_shape=(32, 32, 3),
    text_vocab_size=1000,
    text_sequence_length=50,
    audio_sample_rate=16000,
    audio_duration=1.0,
    rngs=rngs
)

# Access multi-modal sample
sample = dataset[0]
print(sample.keys())
# dict_keys(['image', 'text', 'audio', 'alignment_score', 'latent'])

print(f"Image shape: {sample['image'].shape}")  # (32, 32, 3)
print(f"Text shape: {sample['text'].shape}")  # (50,)
print(f"Audio shape: {sample['audio'].shape}")  # (16000,)
print(f"Alignment: {sample['alignment_score']}")  # 0.8
print(f"Shared latent: {sample['latent'].shape}")  # (32,)

Understanding Alignment Strength¤

The alignment strength controls how strongly modalities are correlated:

# Weakly aligned (more random)
weak_dataset = create_synthetic_multi_modal_dataset(
    modalities=["image", "text"],
    num_samples=1000,
    alignment_strength=0.3,  # 30% alignment
    rngs=rngs
)

# Moderately aligned
moderate_dataset = create_synthetic_multi_modal_dataset(
    modalities=["image", "text"],
    num_samples=1000,
    alignment_strength=0.6,  # 60% alignment
    rngs=rngs
)

# Strongly aligned
strong_dataset = create_synthetic_multi_modal_dataset(
    modalities=["image", "text"],
    num_samples=1000,
    alignment_strength=0.9,  # 90% alignment
    rngs=rngs
)

# Perfect alignment
perfect_dataset = create_synthetic_multi_modal_dataset(
    modalities=["image", "text"],
    num_samples=1000,
    alignment_strength=1.0,  # 100% alignment
    rngs=rngs
)

Paired Multi-Modal Dataset¤

For explicitly paired data:

from workshop.generative_models.modalities.multi_modal.datasets import (
    MultiModalPairedDataset
)

# Prepare paired data
image_data = jnp.array([...])  # (N, H, W, C)
text_data = jnp.array([...])  # (N, max_length)
audio_data = jnp.array([...])  # (N, n_samples)

# Define modality pairs
pairs = [
    ("image", "text"),
    ("image", "audio"),
    ("text", "audio")
]

# Optional alignment scores for each pair
alignments = jnp.ones((len(image_data),))  # All perfectly aligned

# Create paired dataset
paired_dataset = MultiModalPairedDataset(
    pairs=pairs,
    data={
        "image": image_data,
        "text": text_data,
        "audio": audio_data
    },
    alignments=alignments
)

# Access paired sample
sample = paired_dataset[0]
print(sample["image"].shape)  # (H, W, C)
print(sample["text"].shape)  # (max_length,)
print(sample["audio"].shape)  # (n_samples,)
print(sample["alignment_scores"])  # 1.0
print(sample["pairs"])  # [('image', 'text'), ('image', 'audio'), ('text', 'audio')]

Batching Multi-Modal Data¤

# Get batch from aligned dataset
batch = dataset.get_batch(batch_size=32)

print(batch["image"].shape)  # (32, 32, 32, 3)
print(batch["text"].shape)  # (32, 50)
print(batch["audio"].shape)  # (32, 16000)
print(batch["latent"].shape)  # (32, 32)

# Iterate over paired dataset
for i, sample in enumerate(paired_dataset):
    if i >= 3:
        break
    print(f"Sample {i}:")
    for modality in ["image", "text", "audio"]:
        print(f"  {modality}: {sample[modality].shape}")

How Alignment Works¤

The synthetic multi-modal dataset creates aligned data through a shared latent representation:

# Simplified alignment process
def generate_aligned_sample(latent, alignment_strength):
    """Generate aligned multi-modal sample.

    Args:
        latent: Shared latent vector (32,)
        alignment_strength: Strength of alignment (0-1)

    Returns:
        Dictionary with aligned modalities
    """
    # Generate image from latent
    image = generate_image_from_latent(latent, alignment_strength)

    # Generate text from latent
    text = generate_text_from_latent(latent, alignment_strength)

    # Generate audio from latent
    audio = generate_audio_from_latent(latent, alignment_strength)

    return {
        "image": image,
        "text": text,
        "audio": audio,
        "latent": latent,
        "alignment_score": alignment_strength
    }

Image Generation from Latent¤

def generate_image_from_latent(latent, alignment_strength, image_shape=(32, 32, 3)):
    """Generate image from latent representation.

    The latent vector modulates spatial patterns:
    - Higher alignment → stronger influence from latent
    - Lower alignment → more random noise

    Args:
        latent: Shared latent (32,)
        alignment_strength: Alignment factor
        image_shape: Target image shape

    Returns:
        Generated image
    """
    h, w, c = image_shape

    # Create spatial coordinate grids
    x = jnp.linspace(-1, 1, w)
    y = jnp.linspace(-1, 1, h)
    xx, yy = jnp.meshgrid(x, y)

    # Create patterns from latent
    pattern = jnp.zeros((h, w))
    for i in range(min(len(latent), 8)):
        freq = 2 + i
        phase = latent[i] * jnp.pi
        amplitude = jnp.abs(latent[i])
        pattern += amplitude * jnp.sin(freq * xx + phase) * jnp.cos(freq * yy + phase)

    # Normalize pattern
    pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min() + 1e-8)
    pattern = jnp.stack([pattern] * c, axis=-1)

    # Mix with random noise based on alignment
    key = jax.random.key(0)
    noise = jax.random.normal(key, (h, w, c))

    image = alignment_strength * pattern + (1 - alignment_strength) * noise

    # Normalize to [0, 1]
    image = (image - image.min()) / (image.max() - image.min() + 1e-8)

    return image

Text Generation from Latent¤

def generate_text_from_latent(
    latent,
    alignment_strength,
    vocab_size=1000,
    seq_length=50
):
    """Generate text from latent representation.

    The latent vector biases token selection:
    - Higher alignment → stronger bias from latent
    - Lower alignment → more random tokens

    Args:
        latent: Shared latent (32,)
        alignment_strength: Alignment factor
        vocab_size: Vocabulary size
        seq_length: Sequence length

    Returns:
        Generated token sequence
    """
    # Expand latent to vocab size
    latent_expanded = jnp.tile(latent, (vocab_size // len(latent) + 1))
    latent_expanded = latent_expanded[:vocab_size]

    # Create token probabilities from latent
    token_logits = latent_expanded * alignment_strength
    token_probs = jax.nn.softmax(token_logits)

    # Sample tokens
    key = jax.random.key(0)
    tokens = []
    for i in range(seq_length):
        token_key = jax.random.fold_in(key, i)
        token = jax.random.choice(token_key, vocab_size, p=token_probs)
        tokens.append(token)

    return jnp.array(tokens)

Audio Generation from Latent¤

def generate_audio_from_latent(
    latent,
    alignment_strength,
    sample_rate=16000,
    duration=1.0
):
    """Generate audio from latent representation.

    The latent vector controls frequency content:
    - Higher alignment → stronger latent influence
    - Lower alignment → more random noise

    Args:
        latent: Shared latent (32,)
        alignment_strength: Alignment factor
        sample_rate: Sample rate in Hz
        duration: Duration in seconds

    Returns:
        Generated audio waveform
    """
    num_samples = int(sample_rate * duration)
    t = jnp.linspace(0, duration, num_samples)

    # Create audio as sum of sinusoids from latent
    waveform = jnp.zeros(num_samples)

    for i in range(min(len(latent), 10)):
        # Frequency from latent (100-2000 Hz)
        freq = 100 + 1900 * (jnp.abs(latent[i]) % 1)
        phase = latent[i] * 2 * jnp.pi
        amplitude = jnp.abs(latent[i]) * 0.1

        waveform += amplitude * jnp.sin(2 * jnp.pi * freq * t + phase)

    # Add noise based on alignment
    key = jax.random.key(0)
    noise = jax.random.normal(key, (num_samples,)) * 0.1

    waveform = alignment_strength * waveform + (1 - alignment_strength) * noise

    # Normalize
    waveform = waveform / (jnp.max(jnp.abs(waveform)) + 1e-8)

    return waveform

Modality Fusion¤

Combine information from multiple modalities:

Early Fusion¤

Concatenate raw features:

def early_fusion(image, text, audio):
    """Concatenate features from all modalities.

    Args:
        image: Image features (H, W, C)
        text: Text features (seq_length,)
        audio: Audio features (n_samples,)

    Returns:
        Fused features
    """
    # Flatten each modality
    image_flat = image.reshape(-1)
    text_flat = text.reshape(-1)
    audio_flat = audio.reshape(-1)

    # Concatenate
    fused = jnp.concatenate([image_flat, text_flat, audio_flat])

    return fused

# Usage
sample = dataset[0]
fused_features = early_fusion(
    sample["image"],
    sample["text"],
    sample["audio"]
)
print(f"Fused features shape: {fused_features.shape}")

Late Fusion¤

Combine high-level representations:

def late_fusion(image_embedding, text_embedding, audio_embedding):
    """Combine embeddings from separate encoders.

    Args:
        image_embedding: Image encoder output (d_model,)
        text_embedding: Text encoder output (d_model,)
        audio_embedding: Audio encoder output (d_model,)

    Returns:
        Fused embedding
    """
    # Option 1: Concatenation
    fused_concat = jnp.concatenate([
        image_embedding,
        text_embedding,
        audio_embedding
    ])

    # Option 2: Average pooling
    fused_avg = (image_embedding + text_embedding + audio_embedding) / 3

    # Option 3: Weighted sum
    weights = jnp.array([0.4, 0.4, 0.2])  # image, text, audio
    fused_weighted = (
        weights[0] * image_embedding +
        weights[1] * text_embedding +
        weights[2] * audio_embedding
    )

    # This example returns the weighted sum; the other options above are computed for comparison
    return fused_weighted

# Usage
# Assuming we have encoders for each modality
# image_emb = image_encoder(sample["image"])
# text_emb = text_encoder(sample["text"])
# audio_emb = audio_encoder(sample["audio"])
# fused = late_fusion(image_emb, text_emb, audio_emb)

Attention-Based Fusion¤

Use attention to weight modalities:

import jax.numpy as jnp
from flax import nnx

class MultiModalAttentionFusion(nnx.Module):
    """Attention-based multi-modal fusion."""

    def __init__(
        self,
        d_model: int,
        num_heads: int = 4,
        *,
        rngs: nnx.Rngs
    ):
        """Initialize attention fusion.

        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            rngs: Random number generators
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads  # Kept for reference; this simplified fusion computes a single attention map

        # Projection layers
        self.query_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.key_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.value_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(
        self,
        image_emb: jax.Array,
        text_emb: jax.Array,
        audio_emb: jax.Array,
        *,
        deterministic: bool = False
    ) -> jax.Array:
        """Fuse modality embeddings with attention.

        Args:
            image_emb: Image embedding (d_model,)
            text_emb: Text embedding (d_model,)
            audio_emb: Audio embedding (d_model,)
            deterministic: Whether in eval mode

        Returns:
            Fused embedding (d_model,)
        """
        # Stack embeddings (3, d_model)
        embeddings = jnp.stack([image_emb, text_emb, audio_emb])

        # Project to Q, K, V
        queries = self.query_proj(embeddings)
        keys = self.key_proj(embeddings)
        values = self.value_proj(embeddings)

        # Compute attention scores
        scores = jnp.matmul(queries, keys.T) / jnp.sqrt(self.d_model)
        attention_weights = jax.nn.softmax(scores, axis=-1)

        # Apply attention
        attended = jnp.matmul(attention_weights, values)

        # Pool across modalities (average)
        fused = jnp.mean(attended, axis=0)

        # Final projection
        output = self.out_proj(fused)

        return output

# Usage
rngs = nnx.Rngs(0)
fusion = MultiModalAttentionFusion(d_model=256, num_heads=4, rngs=rngs)

# image_emb = image_encoder(sample["image"])  # (256,)
# text_emb = text_encoder(sample["text"])  # (256,)
# audio_emb = audio_encoder(sample["audio"])  # (256,)

# fused = fusion(image_emb, text_emb, audio_emb)
# print(f"Fused embedding: {fused.shape}")  # (256,)

Cross-Modal Generation¤

Generate one modality from another:

Image from Text (Text-to-Image)¤

from flax import nnx
import jax.numpy as jnp

class TextToImageGenerator(nnx.Module):
    """Generate images from text."""

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        image_shape: tuple,
        *,
        rngs: nnx.Rngs
    ):
        """Initialize text-to-image generator.

        Args:
            vocab_size: Text vocabulary size
            embed_dim: Embedding dimension
            image_shape: Target image shape (H, W, C)
            rngs: Random number generators
        """
        super().__init__()
        self.image_shape = image_shape

        # Text encoder
        self.text_embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs)
        self.encoder = nnx.Linear(embed_dim, 512, rngs=rngs)

        # Image decoder
        self.decoder = nnx.Sequential(
            nnx.Linear(512, 1024, rngs=rngs),
            nnx.relu,
            nnx.Linear(1024, int(jnp.prod(jnp.array(image_shape))), rngs=rngs),
            nnx.sigmoid
        )

    def __call__(self, text_tokens: jax.Array) -> jax.Array:
        """Generate image from text.

        Args:
            text_tokens: Text token sequence (seq_length,)

        Returns:
            Generated image (H, W, C)
        """
        # Encode text
        text_emb = self.text_embed(text_tokens)  # (seq_length, embed_dim)
        text_feat = jnp.mean(text_emb, axis=0)  # Pool: (embed_dim,)
        encoded = self.encoder(text_feat)  # (512,)

        # Decode to image
        image_flat = self.decoder(encoded)
        image = image_flat.reshape(self.image_shape)

        return image

# Usage
# generator = TextToImageGenerator(
#     vocab_size=10000,
#     embed_dim=256,
#     image_shape=(32, 32, 3),
#     rngs=rngs
# )
# generated_image = generator(sample["text"])

Text from Image (Image-to-Text)¤

class ImageToTextGenerator(nnx.Module):
    """Generate text from images."""

    def __init__(
        self,
        image_shape: tuple,
        vocab_size: int,
        max_length: int,
        hidden_dim: int = 512,
        *,
        rngs: nnx.Rngs
    ):
        """Initialize image-to-text generator.

        Args:
            image_shape: Input image shape (H, W, C)
            vocab_size: Text vocabulary size
            max_length: Maximum text length
            hidden_dim: Hidden dimension
            rngs: Random number generators
        """
        super().__init__()
        self.max_length = max_length
        self.vocab_size = vocab_size

        # Image encoder
        image_size = int(jnp.prod(jnp.array(image_shape)))
        self.encoder = nnx.Sequential(
            nnx.Linear(image_size, hidden_dim, rngs=rngs),
            nnx.relu,
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        )

        # Text decoder
        self.decoder = nnx.Linear(hidden_dim, vocab_size, rngs=rngs)

    def __call__(self, image: jax.Array, *, rngs: nnx.Rngs | None = None) -> jax.Array:
        """Generate text from image.

        Args:
            image: Input image (H, W, C)
            rngs: Random number generators

        Returns:
            Generated text tokens (max_length,)
        """
        # Encode image
        image_flat = image.reshape(-1)
        encoded = self.encoder(image_flat)  # (hidden_dim,)

        # Decode to text (simplified - sample tokens)
        tokens = []
        for i in range(self.max_length):
            logits = self.decoder(encoded)  # (vocab_size,)

            # Sample token
            if rngs and "sample" in rngs:
                key = rngs.sample()
                token = jax.random.categorical(key, logits)
            else:
                token = jnp.argmax(logits)

            tokens.append(token)

        return jnp.array(tokens)

# Usage
# generator = ImageToTextGenerator(
#     image_shape=(32, 32, 3),
#     vocab_size=10000,
#     max_length=50,
#     rngs=rngs
# )
# generated_text = generator(sample["image"], rngs=rngs)

Audio from Text (Text-to-Speech)¤

class TextToAudioGenerator(nnx.Module):
    """Generate audio from text."""

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        audio_length: int,
        *,
        rngs: nnx.Rngs
    ):
        """Initialize text-to-audio generator.

        Args:
            vocab_size: Text vocabulary size
            embed_dim: Embedding dimension
            audio_length: Target audio length in samples
            rngs: Random number generators
        """
        super().__init__()
        self.audio_length = audio_length

        # Text encoder
        self.text_embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs)
        self.encoder = nnx.Linear(embed_dim, 512, rngs=rngs)

        # Audio decoder
        self.decoder = nnx.Sequential(
            nnx.Linear(512, 1024, rngs=rngs),
            nnx.relu,
            nnx.Linear(1024, audio_length, rngs=rngs),
            nnx.tanh  # Audio in [-1, 1]
        )

    def __call__(self, text_tokens: jax.Array) -> jax.Array:
        """Generate audio from text.

        Args:
            text_tokens: Text token sequence (seq_length,)

        Returns:
            Generated audio waveform (audio_length,)
        """
        # Encode text
        text_emb = self.text_embed(text_tokens)  # (seq_length, embed_dim)
        text_feat = jnp.mean(text_emb, axis=0)  # Pool
        encoded = self.encoder(text_feat)  # (512,)

        # Decode to audio
        audio = self.decoder(encoded)  # (audio_length,)

        return audio

# Usage
# generator = TextToAudioGenerator(
#     vocab_size=10000,
#     embed_dim=256,
#     audio_length=16000,
#     rngs=rngs
# )
# generated_audio = generator(sample["text"])

Complete Multi-Modal Training Example¤

import jax
import jax.numpy as jnp
from flax import nnx

# Setup
rngs = nnx.Rngs(0)

# Create multi-modal dataset
dataset = create_synthetic_multi_modal_dataset(
    modalities=["image", "text", "audio"],
    num_samples=10000,
    alignment_strength=0.8,
    image_shape=(32, 32, 3),
    text_vocab_size=1000,
    text_sequence_length=50,
    audio_sample_rate=16000,
    audio_duration=1.0,
    rngs=rngs
)

# Create validation dataset
val_dataset = create_synthetic_multi_modal_dataset(
    modalities=["image", "text", "audio"],
    num_samples=1000,
    alignment_strength=0.8,
    image_shape=(32, 32, 3),
    text_vocab_size=1000,
    text_sequence_length=50,
    audio_sample_rate=16000,
    audio_duration=1.0,
    rngs=rngs
)

# Training loop
batch_size = 32
num_epochs = 10
key = jax.random.key(42)

for epoch in range(num_epochs):
    num_batches = len(dataset) // batch_size

    for i in range(num_batches):
        # Get batch
        batch = dataset.get_batch(batch_size)

        # Extract modalities
        images = batch["image"]
        texts = batch["text"]
        audios = batch["audio"]
        latents = batch["latent"]

        # Training step (placeholder)
        # 1. Encode each modality
        # image_emb = image_encoder(images)
        # text_emb = text_encoder(texts)
        # audio_emb = audio_encoder(audios)

        # 2. Compute alignment loss
        # alignment_loss = contrastive_loss(image_emb, text_emb, audio_emb)

        # 3. Compute reconstruction losses
        # recon_loss_img = reconstruction_loss(images, reconstructed_images)
        # recon_loss_text = reconstruction_loss(texts, reconstructed_texts)
        # recon_loss_audio = reconstruction_loss(audios, reconstructed_audios)

        # 4. Total loss
        # loss = alignment_loss + recon_loss_img + recon_loss_text + recon_loss_audio

        # 5. Update parameters
        # params = optimizer.update(grads, params)

    # Validation
    val_batches = len(val_dataset) // batch_size
    for i in range(val_batches):
        val_batch = val_dataset.get_batch(batch_size)
        # Validation step
        # val_loss = validate_step(model, val_batch)

    print(f"Epoch {epoch + 1}/{num_epochs} complete")

Best Practices¤

DO¤

Multi-Modal Design

  • Use shared latent representations for alignment
  • Balance modality contributions in fusion
  • Normalize features before fusion (see the sketch after this list)
  • Use attention for dynamic modality weighting
  • Test alignment quality visually/qualitatively
  • Cache aligned datasets when possible
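
Normalizing embeddings before fusion can be as simple as the sketch below (a minimal example; the per-modality embeddings are assumed to come from your own encoders, and late_fusion is the helper defined earlier in this guide):

import jax.numpy as jnp

def l2_normalize(x, axis=-1, eps=1e-8):
    """Scale embeddings to unit L2 norm so no single modality dominates the fusion."""
    return x / (jnp.linalg.norm(x, axis=axis, keepdims=True) + eps)

# Usage (embeddings from hypothetical per-modality encoders):
# fused = late_fusion(
#     l2_normalize(image_emb),
#     l2_normalize(text_emb),
#     l2_normalize(audio_emb),
# )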

Cross-Modal Generation

  • Use separate encoders for each modality
  • Implement residual connections in decoders
  • Use appropriate loss functions per modality
  • Test generation quality separately per modality
  • Consider cycle consistency for bidirectional generation (sketched after this list)
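
One way to express cycle consistency with the generators sketched earlier is shown below. This is a hedged example: the discrete token step is not differentiable as written, so a real implementation would need soft token distributions or a straight-through estimator.

import jax.numpy as jnp

def image_cycle_loss(image, image_to_text, text_to_image):
    """Image -> text -> image reconstruction error.

    Args:
        image: Input image (H, W, C)
        image_to_text: e.g. the ImageToTextGenerator defined earlier
        text_to_image: e.g. the TextToImageGenerator defined earlier
    """
    tokens = image_to_text(image)          # Discrete tokens (not differentiable here)
    reconstructed = text_to_image(tokens)  # Map back to image space
    return jnp.mean((reconstructed - image) ** 2)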

Training

  • Use contrastive losses for alignment
  • Balance reconstruction and alignment losses (see the sketch after this list)
  • Apply modality-specific augmentation
  • Monitor per-modality metrics
  • Use curriculum learning for complex tasks
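
A small sketch of balancing the loss terms while keeping per-modality metrics visible (the weights and dictionary layout are illustrative choices, not a Workshop API):

def combine_losses(alignment_loss, recon_losses, alignment_weight=1.0, recon_weights=None):
    """Weighted total loss plus a metrics dict for per-modality monitoring.

    Args:
        alignment_loss: Scalar alignment (e.g. contrastive) loss
        recon_losses: Dict like {"image": ..., "text": ..., "audio": ...}
        alignment_weight: Weight on the alignment term
        recon_weights: Optional per-modality weights (defaults to 1.0 each)
    """
    recon_weights = recon_weights or {m: 1.0 for m in recon_losses}
    total = alignment_weight * alignment_loss
    metrics = {"alignment_loss": alignment_loss}
    for modality, loss in recon_losses.items():
        total = total + recon_weights[modality] * loss
        metrics[f"recon_loss/{modality}"] = loss
    return total, metrics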

DON'T¤

Common Mistakes

  • Mix different alignment strengths in same batch
  • Ignore modality-specific preprocessing
  • Use same architecture for all modalities
  • Apply same augmentation to all modalities
  • Forget to normalize embeddings before fusion
  • Ignore computational cost of attention

Alignment Issues

  • Use an alignment strength that is too low for supervised tasks
  • Mix aligned and unaligned samples
  • Ignore alignment scores during training
  • Use mismatched modality sizes
  • Forget to validate alignment quality

Performance

  • Concatenate raw, high-dimensional features from all modalities (early fusion scales poorly)
  • Use very deep fusion networks
  • Process all modalities even when not needed
  • Ignore modality-specific batch sizes
  • Use attention for simple fusion tasks

Summary¤

This guide covered:

  • Multi-modal datasets - Aligned and paired datasets
  • Alignment - Shared latent representations and alignment strength
  • Modality fusion - Early, late, and attention-based fusion
  • Cross-modal generation - Image↔Text, Text↔Audio
  • Complete example - Multi-modal training pipeline
  • Best practices - DOs and DON'Ts for multi-modal learning

Next Steps¤