
Autoregressive Models User Guide¤

Complete guide to building, training, and using Autoregressive Models with Workshop.

Overview¤

This guide covers practical usage of autoregressive models in Workshop, from basic setup to advanced generation techniques. You'll learn how to:

  • Configure AR Models: set up PixelCNN, WaveNet, and Transformer architectures
  • Train Models: train with teacher forcing and monitor perplexity
  • Generate Samples: run sequential generation with various sampling strategies
  • Optimize & Sample: tune generation quality with temperature, top-k, and nucleus sampling


Quick Start¤

Basic Transformer Example¤

import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.models.autoregressive import TransformerAR

# Initialize RNGs
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Create Transformer autoregressive model
model = TransformerAR(
    vocab_size=10000,
    sequence_length=512,
    hidden_dim=512,
    num_layers=6,
    num_heads=8,
    rngs=rngs
)

# Training
sequences = jnp.array([[1, 2, 3, 4, 5]])  # [batch, seq_len]
outputs = model(sequences, rngs=rngs)
logits = outputs['logits']  # [batch, seq_len, vocab_size]

print(f"Logits shape: {logits.shape}")

# Generation
samples = model.generate(
    n_samples=4,
    max_length=128,
    temperature=0.8,
    top_p=0.9,
    rngs=rngs
)
print(f"Generated samples shape: {samples.shape}")

Creating Autoregressive Models¤

1. PixelCNN (Image Generation)¤

For autoregressive image generation with masked convolutions:

from workshop.generative_models.models.autoregressive import PixelCNN

# Create PixelCNN for MNIST (28×28 grayscale)
model = PixelCNN(
    image_shape=(28, 28, 1),
    num_layers=7,
    hidden_channels=128,
    num_residual_blocks=5,
    rngs=rngs
)

# Training
images = jnp.zeros((16, 28, 28, 1), dtype=jnp.int32)  # Values in [0, 255]
outputs = model(images, rngs=rngs, training=True)

# Loss
batch = {"images": images}
loss_dict = model.loss_fn(batch, outputs, rngs=rngs)
print(f"Loss: {loss_dict['loss']:.4f}")
print(f"Bits per dim: {loss_dict['bits_per_dim']:.4f}")

# Generation (pixel by pixel)
generated = model.generate(
    n_samples=16,
    temperature=1.0,
    rngs=rngs
)

Key Parameters:

Parameter            Default  Description
image_shape          -        (height, width, channels)
num_layers           7        Number of masked conv layers
hidden_channels      128      Hidden layer channels
num_residual_blocks  5        Residual block count

Use Cases:

  • Density estimation on images
  • Lossless image compression
  • Inpainting with spatial conditioning

2. WaveNet (Audio Generation)¤

For raw audio waveform modeling:

from workshop.generative_models.models.autoregressive import WaveNet

# Create WaveNet
model = WaveNet(
    num_layers=30,
    num_stacks=3,
    residual_channels=128,
    dilation_channels=256,
    skip_channels=512,
    rngs=rngs
)

# Training
waveform = jnp.zeros((4, 16000), dtype=jnp.int32)  # 1 second at 16kHz
outputs = model(waveform, rngs=rngs)

# Loss
batch = {"waveform": waveform}
loss_dict = model.loss_fn(batch, outputs, rngs=rngs)

# Generation
generated_audio = model.generate(
    n_samples=1,
    max_length=16000,  # 1 second
    temperature=0.9,
    rngs=rngs
)

Key Parameters:

Parameter          Default  Description
num_layers         30       Total dilated conv layers
num_stacks         3        Number of dilation stacks
residual_channels  128      Residual connection channels
dilation_channels  256      Dilated conv channels

Dilation Pattern:

# WaveNet uses exponentially increasing dilations, repeated per stack
# Stack 1: dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
# Stack 2: dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
# Stack 3: dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
# Receptive field ≈ 1 + 3 × 1023 = 3070 time steps (assuming kernel size 2)
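
To sanity-check those numbers, here is a short standalone calculation of the dilation schedule and receptive field. It assumes a kernel size of 2 and is independent of the Workshop WaveNet implementation:

num_stacks = 3
layers_per_stack = 10
kernel_size = 2  # assumed; each layer widens the receptive field by dilation * (kernel_size - 1)

dilations = [2 ** i for i in range(layers_per_stack)] * num_stacks
receptive_field = 1 + (kernel_size - 1) * sum(dilations)

print(dilations[:layers_per_stack])  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
print(receptive_field)               # 3070 samples, roughly 0.19 s at 16 kHz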

3. Transformer (Sequence Modeling)¤

For text, code, and general sequences:

from workshop.generative_models.models.autoregressive import TransformerAR

# Standard Transformer configuration
model = TransformerAR(
    vocab_size=50000,         # Vocabulary size
    sequence_length=1024,     # Maximum sequence length
    hidden_dim=768,           # Model dimension
    num_layers=12,            # Transformer blocks
    num_heads=12,             # Attention heads
    feedforward_dim=3072,     # FFN hidden dimension
    dropout_rate=0.1,         # Dropout probability
    rngs=rngs
)

# Training
sequences = jnp.zeros((8, 512), dtype=jnp.int32)  # [batch, seq_len]
outputs = model(sequences, rngs=rngs, training=True)
logits = outputs['logits']

# Compute loss
batch = {"sequences": sequences}
loss_dict = model.loss_fn(batch, outputs, rngs=rngs)
print(f"NLL Loss: {loss_dict['nll_loss']:.4f}")
print(f"Perplexity: {loss_dict['perplexity']:.2f}")
print(f"Accuracy: {loss_dict['accuracy']:.4f}")

Architecture Scaling:

# Small (for experiments)
small_config = {
    "hidden_dim": 256,
    "num_layers": 4,
    "num_heads": 4,
    "feedforward_dim": 1024,
}

# Medium (GPT-2 small)
medium_config = {
    "hidden_dim": 768,
    "num_layers": 12,
    "num_heads": 12,
    "feedforward_dim": 3072,
}

# Large (GPT-2 medium)
large_config = {
    "hidden_dim": 1024,
    "num_layers": 24,
    "num_heads": 16,
    "feedforward_dim": 4096,
}
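
As a rough sanity check when choosing a configuration, the parameter count of a decoder-only Transformer is dominated by about 12 × num_layers × hidden_dim² weights in the blocks plus vocab_size × hidden_dim for the token embedding. The snippet below is a back-of-the-envelope estimate, not a count of the actual TransformerAR parameters:

def approx_params(cfg, vocab_size=50000):
    """Rough estimate: ~4d^2 (attention) + ~8d^2 (FFN) per layer, plus token embeddings."""
    d, n_layers = cfg["hidden_dim"], cfg["num_layers"]
    return 12 * n_layers * d * d + vocab_size * d

for name, cfg in [("small", small_config), ("medium", medium_config), ("large", large_config)]:
    print(f"{name}: ~{approx_params(cfg) / 1e6:.0f}M parameters")
# small: ~16M, medium: ~123M, large: ~353M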

Training Autoregressive Models¤

Teacher Forcing Training¤

Standard training uses ground truth previous tokens:

def train_step(model, batch, optimizer_state):
    """Standard teacher-forcing training step (optimizer and rngs come from the enclosing scope)."""

    def loss_fn(model):
        # Forward pass with ground truth input
        outputs = model(batch['sequences'], training=True, rngs=rngs)

        # Compute loss
        loss_dict = model.loss_fn(batch, outputs, rngs=rngs)

        return loss_dict['loss'], loss_dict

    # Compute gradients
    (loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

    # Update parameters
    optimizer_state = optimizer.update(grads, optimizer_state)

    return loss, metrics, optimizer_state

# Training loop
for epoch in range(num_epochs):
    for batch in train_loader:
        loss, metrics, optimizer_state = train_step(
            model, batch, optimizer_state
        )
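
The optimizer itself is not shown above. A minimal setup with plain optax might look like the following; the exact Workshop training utilities may differ, so treat this as a sketch:

import optax
from flax import nnx

# Gradient clipping + AdamW as a plain optax chain (values are illustrative)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=1e-4, weight_decay=0.01),
)

params = nnx.state(model, nnx.Param)      # trainable parameters as a pytree
optimizer_state = optimizer.init(params)

# A raw optax update then looks like:
#   updates, optimizer_state = optimizer.update(grads, optimizer_state, params)
#   params = optax.apply_updates(params, updates)
#   nnx.update(model, params)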

Monitoring Training¤

Track key metrics:

def train_with_monitoring(model, train_loader, val_loader, num_epochs):
    """Training with detailed monitoring."""

    for epoch in range(num_epochs):
        # Training
        train_losses = []
        train_perplexities = []

        for step, batch in enumerate(train_loader):
            outputs = model(batch['sequences'], training=True, rngs=rngs)
            loss_dict = model.loss_fn(batch, outputs, rngs=rngs)

            train_losses.append(loss_dict['loss'])
            train_perplexities.append(loss_dict['perplexity'])

            if step % 100 == 0:
                print(f"Epoch {epoch}, Step {step}:")
                print(f"  Loss: {loss_dict['loss']:.4f}")
                print(f"  Perplexity: {loss_dict['perplexity']:.2f}")
                print(f"  Accuracy: {loss_dict['accuracy']:.4f}")

        # Validation
        if val_loader is not None:
            val_loss, val_ppl = evaluate(model, val_loader)
            print(f"\nEpoch {epoch} Validation:")
            print(f"  Loss: {val_loss:.4f}")
            print(f"  Perplexity: {val_ppl:.2f}")

def evaluate(model, val_loader):
    """Evaluate on validation set."""
    total_loss = 0
    total_tokens = 0

    for batch in val_loader:
        outputs = model(batch['sequences'], training=False, rngs=rngs)
        loss_dict = model.loss_fn(batch, outputs, rngs=rngs)

        batch_size, seq_len = batch['sequences'].shape
        total_loss += loss_dict['loss'] * batch_size * seq_len
        total_tokens += batch_size * seq_len

    avg_loss = total_loss / total_tokens
    perplexity = jnp.exp(avg_loss)

    return avg_loss, perplexity

Learning Rate Scheduling¤

Transformers benefit from learning rate warmup:

def transformer_lr_schedule(step, warmup_steps=4000, d_model=512):
    """Transformer learning rate schedule with warmup."""
    step = jnp.maximum(step, 1)
    arg1 = step ** -0.5
    arg2 = step * (warmup_steps ** -1.5)
    return (d_model ** -0.5) * jnp.minimum(arg1, arg2)

# Apply schedule
lr = transformer_lr_schedule(current_step, warmup_steps=4000, d_model=768)
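
In practice you would usually hand the schedule to the optimizer instead of calling it manually; with optax, for example (illustrative, not Workshop-specific):

import optax

# optax accepts any callable mapping step -> learning rate as a schedule
schedule = lambda step: transformer_lr_schedule(step, warmup_steps=4000, d_model=768)
optimizer = optax.adamw(learning_rate=schedule)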

Generation and Sampling¤

1. Greedy Decoding¤

Most likely token at each step:

# Greedy generation
samples = model.generate(
    n_samples=4,
    max_length=128,
    temperature=1.0,  # No effect with greedy (argmax)
    rngs=rngs
)
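
If the generate API does not expose an explicit greedy mode, greedy decoding is simply an argmax over the next-token logits at each step; a minimal standalone sketch:

# logits: [batch, seq_len, vocab_size] from a forward pass
next_tokens = jnp.argmax(logits[:, -1, :], axis=-1)  # most likely next token per sequence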

2. Temperature Sampling¤

Control randomness:

# Low temperature (more deterministic)
deterministic_samples = model.generate(
    n_samples=4,
    max_length=128,
    temperature=0.5,  # More peaked distribution
    rngs=rngs
)

# High temperature (more random)
random_samples = model.generate(
    n_samples=4,
    max_length=128,
    temperature=1.5,  # Flatter distribution
    rngs=rngs
)

Temperature Guidelines:

  • 0.5: Very deterministic, repetitive
  • 0.7: Slightly creative, coherent
  • 1.0: Sample from true model distribution
  • 1.2: More diverse, less coherent
  • 1.5+: Very random, often incoherent
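
For intuition, here is how temperature reshapes a toy next-token distribution (a standalone example, not a Workshop API call):

import jax.numpy as jnp
from jax.nn import softmax

logits = jnp.array([2.0, 1.0, 0.0])
for t in (0.5, 1.0, 1.5):
    print(t, softmax(logits / t))
# 0.5 -> [0.87, 0.12, 0.02]  (peaked)
# 1.0 -> [0.67, 0.24, 0.09]  (model distribution)
# 1.5 -> [0.56, 0.29, 0.15]  (flatter)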

3. Top-k Sampling¤

Sample from k most likely tokens:

# Top-k sampling
samples = model.generate(
    n_samples=4,
    max_length=128,
    temperature=1.0,
    top_k=40,  # Only consider top 40 tokens
    rngs=rngs
)

Top-k Values:

  • k=1: Greedy (deterministic)
  • k=10: Very focused
  • k=40: Balanced (common for text)
  • k=100: More diverse
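
Under the hood, top-k filtering keeps only the k largest logits before sampling. A minimal standalone version (independent of the Workshop generate API):

import jax
import jax.numpy as jnp

def top_k_sample(key, logits, k=40, temperature=1.0):
    """Sample one token id from the k highest-probability candidates."""
    top_logits, top_indices = jax.lax.top_k(logits / temperature, k)
    choice = jax.random.categorical(key, top_logits)
    return top_indices[choice]

token = top_k_sample(jax.random.PRNGKey(0), jnp.zeros(1000), k=40)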

4. Top-p (Nucleus) Sampling¤

Sample from smallest set with cumulative probability ≥ p:

# Top-p (nucleus) sampling
samples = model.generate(
    n_samples=4,
    max_length=128,
    temperature=1.0,
    top_p=0.9,  # Nucleus with 90% probability mass
    rngs=rngs
)

Top-p Values:

  • p=0.5: Very focused
  • p=0.7: Focused but creative
  • p=0.9: Balanced (recommended)
  • p=0.95: More diverse
  • p=1.0: No filtering
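
The filtering step behind nucleus sampling sorts tokens by probability and keeps the smallest prefix whose cumulative mass reaches p. A standalone sketch for a single logits vector (not the Workshop implementation):

import jax
import jax.numpy as jnp

def nucleus_sample(key, logits, top_p=0.9, temperature=1.0):
    """Sample one token id with top-p (nucleus) filtering."""
    logits = logits / temperature
    order = jnp.argsort(logits)[::-1]            # token ids sorted by descending logit
    sorted_probs = jax.nn.softmax(logits[order])
    cumulative = jnp.cumsum(sorted_probs)
    keep = (cumulative - sorted_probs) < top_p   # smallest prefix reaching top_p mass
    filtered = jnp.where(keep, logits[order], -jnp.inf)
    choice = jax.random.categorical(key, filtered)
    return order[choice]

token = nucleus_sample(jax.random.PRNGKey(0), jnp.log(jnp.array([0.5, 0.3, 0.1, 0.05, 0.05])))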

5. Beam Search¤

Maintain multiple hypotheses:

# Beam search
samples = model.beam_search(
    beam_size=5,
    max_length=128,
    rngs=rngs
)

Beam Search Use Cases:

  • Machine translation
  • Summarization
  • When likelihood is more important than diversity

6. Combined Strategies¤

Combine multiple techniques:

# Recommended: temperature + top-p
samples = model.generate(
    n_samples=4,
    max_length=128,
    temperature=0.8,   # Slight sharpening
    top_p=0.9,         # Nucleus sampling
    rngs=rngs
)

# Alternative: temperature + top-k
samples = model.generate(
    n_samples=4,
    max_length=128,
    temperature=0.7,
    top_k=50,
    rngs=rngs
)

Conditional Generation¤

1. Prompt-Based Generation¤

Generate from a prefix:

# Start with a prompt
prompt = jnp.array([[1, 45, 23, 89]])  # Token IDs

# Continue from prompt
continuation = model.sample_with_conditioning(
    conditioning=prompt,
    n_samples=4,  # 4 completions for same prompt
    temperature=0.8,
    top_p=0.9,
    rngs=rngs
)

print(f"Prompt length: {prompt.shape[1]}")
print(f"Continuation shape: {continuation.shape}")

2. Class-Conditional Generation (PixelCNN)¤

For class-conditional image generation:

# Add class conditioning to PixelCNN (sketch: the class embedding must also be
# injected into the forward pass, which is omitted here)
class ConditionalPixelCNN(PixelCNN):
    def __init__(self, *args, num_classes=10, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        self.class_embedding = nnx.Embed(
            num_classes, self.hidden_channels, rngs=kwargs['rngs']
        )

# Generate samples of a specific class
class_label = 7
conditional_images = model.generate_conditional(
    class_label=class_label,
    n_samples=16,
    rngs=rngs
)

3. Inpainting (PixelCNN)¤

Spatial conditioning for inpainting:

# Conditioning image with mask
conditioning = jnp.zeros((28, 28, 1), dtype=jnp.int32)
mask = jnp.zeros((28, 28))  # 0 = generate, 1 = keep

# Set known pixels (known_values holds the observed intensities for the kept
# region; a constant placeholder is used here)
known_values = jnp.full((10, 10, 1), 128, dtype=jnp.int32)
conditioning = conditioning.at[0:10, 0:10, :].set(known_values)
mask = mask.at[0:10, 0:10].set(1)

# Inpaint
inpainted = model.inpaint(
    conditioning=conditioning,
    mask=mask,
    n_samples=4,
    temperature=1.0,
    rngs=rngs
)

Advanced Techniques¤

1. Caching for Faster Generation¤

Cache key-value pairs for Transformers:

class TransformerWithCache(TransformerAR):
    """Transformer with KV cache for faster generation."""

    def generate_with_cache(self, n_samples, max_length, rngs):
        """Generate using cached key-value pairs."""
        sequences = jnp.zeros((n_samples, max_length), dtype=jnp.int32)
        cache = None

        for pos in range(max_length):
            # Reuse cached keys/values so earlier positions are not recomputed
            # (the full prefix is passed here for simplicity)
            outputs, cache = self.forward_with_cache(
                sequences[:, :pos+1], cache=cache, rngs=rngs
            )
            logits = outputs['logits'][:, pos, :]

            # Sample next token
            next_tokens = jax.random.categorical(
                rngs.sample(), logits, axis=-1
            )
            sequences = sequences.at[:, pos].set(next_tokens)

        return sequences

2. Speculative Sampling¤

Speed up generation with a draft model:

def speculative_sampling(target_model, draft_model, n_samples, max_length, rngs):
    """Faster sampling using a smaller draft model (sketch)."""
    sequences = jnp.zeros((n_samples, max_length), dtype=jnp.int32)

    pos = 0
    while pos < max_length:
        # Draft model cheaply proposes the next k tokens for each sequence
        k = 5
        draft_tokens = draft_model.generate(
            conditioning=sequences[:, :pos],
            max_length=pos + k,
            rngs=rngs
        )

        # Target model scores the proposals in a single forward pass
        target_outputs = target_model(draft_tokens, rngs=rngs)
        target_probs = nnx.softmax(target_outputs['logits'], axis=-1)

        # Accept or reject each draft token based on probability ratios
        # ... acceptance logic; accepted_tokens = number accepted this round ...

        pos += accepted_tokens

    return sequences
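
The elided acceptance step usually follows the standard speculative-decoding rule: accept a draft token x with probability min(1, p_target(x) / p_draft(x)), and resample from the adjusted target distribution on the first rejection. A minimal sketch of the per-token test:

def accept_token(key, p_target_x, p_draft_x):
    """Accept the draft token with probability min(1, p_target / p_draft)."""
    accept_prob = jnp.minimum(1.0, p_target_x / p_draft_x)
    return jax.random.uniform(key) < accept_prob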

3. Prefix Tuning for Adaptation¤

Adapt to new tasks with prefix tuning:

class PrefixTunedTransformer(TransformerAR):
    """Transformer with learnable prefix for task adaptation."""

    def __init__(self, *args, prefix_length=10, **kwargs):
        super().__init__(*args, **kwargs)
        self.prefix_length = prefix_length

        # Learnable prefix embeddings
        self.prefix_embeddings = nnx.Param(
            jax.random.normal(
                kwargs['rngs'].params(),
                (prefix_length, self.hidden_dim)
            )
        )

    def forward_with_prefix(self, x, rngs):
        """Forward pass with prefix prepended."""
        batch_size = x.shape[0]

        # Expand prefix for batch
        prefix = jnp.tile(self.prefix_embeddings[None], (batch_size, 1, 1))

        # Embed input
        x_embedded = self.embedding(x)

        # Concatenate prefix and input
        x_with_prefix = jnp.concatenate([prefix, x_embedded], axis=1)

        # Forward through Transformer
        outputs = self.transformer(x_with_prefix, rngs=rngs)

        return outputs

Troubleshooting¤

Common Issues and Solutions¤

  • High Perplexity

    Symptoms: Perplexity stays high, poor generation

    Solutions:
    - Increase model capacity
    - More training epochs
    - Better data preprocessing
    - Check for label smoothing

    # Increase model size
    config["hidden_dim"] = 768  # from 512
    config["num_layers"] = 12  # from 6

  • Slow Generation

    Symptoms: Sequential generation takes too long

    Solutions:
    - Use KV caching (Transformers)
    - Reduce sequence length
    - Use a smaller model for drafting
    - JIT compile generation

    @nnx.jit  # nnx.jit handles functions that take nnx modules
    def fast_generate(model, n_samples, max_length):
        return model.generate(n_samples, max_length, rngs=rngs)

  • Repetitive Output

    Symptoms: Model generates the same tokens repeatedly

    Solutions:
    - Increase temperature
    - Use nucleus (top-p) sampling
    - Add a repetition penalty
    - Use more diverse training data

    # More diverse sampling
    samples = model.generate(
        n_samples=4,
        max_length=128,
        temperature=1.2,  # Higher than 1.0
        top_p=0.95,       # Wider nucleus
        rngs=rngs
    )

  • Training Instability

    Symptoms: Loss spikes, NaN gradients

    Solutions:
    - Lower the learning rate
    - Add gradient clipping
    - Use a warmup schedule
    - Check data preprocessing

    # Gradient value clipping (optax.clip_by_global_norm is a common alternative)
    grads = jax.tree_util.tree_map(
        lambda g: jnp.clip(g, -1.0, 1.0), grads
    )


Best Practices¤

1. Data Preprocessing¤

def preprocess_text(text, tokenizer):
    """Proper text preprocessing."""
    # Tokenize
    tokens = tokenizer.encode(text)

    # Add special tokens
    tokens = [tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]

    # Pad/truncate to fixed length
    max_length = 512
    if len(tokens) < max_length:
        tokens = tokens + [tokenizer.pad_token_id] * (max_length - len(tokens))
    else:
        tokens = tokens[:max_length]

    return jnp.array(tokens)

def preprocess_image(image):
    """Proper image preprocessing for PixelCNN."""
    # Ensure uint8 values [0, 255]
    image = jnp.clip(image, 0, 255).astype(jnp.uint8)
    return image

2. Start with Small Models¤

# Quick iteration with small model
small_model = TransformerAR(
    vocab_size=10000,
    sequence_length=128,  # Short sequences
    hidden_dim=256,       # Small dimension
    num_layers=4,         # Few layers
    num_heads=4,
    rngs=rngs
)

# Train quickly, verify everything works
# Then scale up

3. Monitor Generation Quality¤

def monitor_generation_quality(model, val_prompts, epoch):
    """Regularly check generation quality."""
    print(f"\nEpoch {epoch} - Generation Samples:")

    for i, prompt in enumerate(val_prompts[:3]):
        # Generate
        completion = model.sample_with_conditioning(
            conditioning=prompt,
            temperature=0.8,
            top_p=0.9,
            rngs=rngs
        )

        # Decode and display
        text = tokenizer.decode(completion[0])
        print(f"\nPrompt {i+1}: {tokenizer.decode(prompt[0])}")
        print(f"Completion: {text}")

4. Use Appropriate Metrics¤

# Track multiple metrics
metrics = {
    "nll_loss": [],     # Negative log-likelihood
    "perplexity": [],   # exp(nll_loss)
    "accuracy": [],     # Token-level accuracy
    "bpd": [],          # Bits per dimension (for images)
}

# For text generation
def evaluate_text_generation(model, generated_samples):
    """Evaluate generation quality."""
    return {
        "diversity": compute_diversity(generated_samples),
        "coherence": compute_coherence(generated_samples),
        "fluency": compute_fluency(generated_samples),
    }
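
compute_diversity, compute_coherence, and compute_fluency are placeholders for whatever metrics you use. As one concrete example, a simple distinct-n diversity score could look like this (hypothetical helper, not part of Workshop):

def distinct_n(token_sequences, n=2):
    """Fraction of unique n-grams among all generated n-grams (higher = more diverse)."""
    ngrams = [
        tuple(seq[i:i + n])
        for seq in token_sequences
        for i in range(len(seq) - n + 1)
    ]
    return len(set(ngrams)) / max(len(ngrams), 1)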

Example: Complete Text Generation¤

from workshop.generative_models.models.autoregressive import TransformerAR
import tensorflow_datasets as tfds

# Load dataset (e.g., WikiText)
train_ds = tfds.load('wiki40b/en', split='train')

# Create model
model = TransformerAR(
    vocab_size=50000,
    sequence_length=512,
    hidden_dim=768,
    num_layers=12,
    num_heads=12,
    feedforward_dim=3072,
    dropout_rate=0.1,
    rngs=rngs
)

# Training configuration
learning_rate = 1e-4
num_epochs = 10
batch_size = 32

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    for step, batch in enumerate(tfds.as_numpy(train_ds.batch(batch_size))):
        # Preprocess
        sequences = preprocess_batch(batch)

        # Forward pass
        outputs = model(sequences, training=True, rngs=rngs)

        # Compute loss
        loss_dict = model.loss_fn(
            {"sequences": sequences}, outputs, rngs=rngs
        )

        # Backward pass (via optimizer)
        # ... update parameters ...

        if step % 100 == 0:
            print(f"  Step {step}: Loss={loss_dict['nll_loss']:.4f}, "
                  f"PPL={loss_dict['perplexity']:.2f}")

    # Generate samples
    prompt = "The quick brown fox"
    prompt_tokens = tokenizer.encode(prompt)
    completion = model.sample_with_conditioning(
        conditioning=jnp.array([prompt_tokens]),
        temperature=0.8,
        top_p=0.9,
        rngs=rngs
    )
    print(f"\nGeneration: {tokenizer.decode(completion[0])}")

print("Training complete!")

Performance Optimization¤

GPU Utilization¤

# Move to GPU
from workshop.generative_models.core.device_manager import DeviceManager

device_manager = DeviceManager()
device = device_manager.get_device()

# Move model and data to GPU
model = jax.device_put(model, device)
batch = jax.device_put(batch, device)

Batch Size Tuning¤

# Larger batches for better GPU utilization
# But: limited by memory

# PixelCNN (memory intensive)
pixelcnn_batch_size = 32

# Transformer: suggested batch size keyed by sequence length
transformer_batch_sizes = {
    128: 256,   # Short sequences
    512: 64,    # Medium sequences
    1024: 16,   # Long sequences
}

# WaveNet (very memory intensive)
wavenet_batch_size = 4

Mixed Precision Training¤

# JAX defaults to float32; keep 64-bit mode disabled for speed
from jax import config
config.update("jax_enable_x64", False)

# bfloat16 compute is typically enabled per module via dtype arguments;
# on TPU, float32 matmuls already use bfloat16 multiplies internally
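
For explicit bfloat16 compute, Flax NNX layers accept dtype (computation) and param_dtype (storage) arguments; whether TransformerAR forwards such arguments depends on the Workshop implementation, so the example below uses a plain nnx.Linear:

from flax import nnx
import jax.numpy as jnp

# bfloat16 activations with float32 parameter storage (a common mixed-precision setup)
layer = nnx.Linear(
    in_features=768,
    out_features=768,
    dtype=jnp.bfloat16,       # computation dtype
    param_dtype=jnp.float32,  # parameter storage dtype
    rngs=nnx.Rngs(0),
)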

Summary¤

Key Takeaways:

  • Autoregressive models factorize probability via chain rule: \(p(x) = \prod_i p(x_i | x_{<i})\)
  • Training uses teacher forcing with cross-entropy loss
  • Generation is sequential, one token at a time
  • Sampling strategies (temperature, top-k, top-p) control diversity vs quality
  • PixelCNN for images, WaveNet for audio, Transformers for text

Recommended Workflow:

  1. Choose architecture based on data type (PixelCNN/WaveNet/Transformer)
  2. Start with small model for quick iteration
  3. Train with teacher forcing, monitor perplexity
  4. Generate samples with temperature=0.8, top_p=0.9
  5. Scale up model size for better quality
  6. Use caching and JIT for faster inference

For theoretical understanding, see the Autoregressive Explained guide.