Inference Optimization¤
Advanced techniques for optimizing generative model inference: JIT compilation, quantization, memory management, and hardware-specific optimizations.
Overview¤
Inference optimization is critical for deploying generative models in production. This guide covers techniques to maximize throughput, minimize latency, and reduce memory usage while maintaining generation quality.
Performance Gains
- JIT Compilation: 10-100x speedup once compiled (the first call pays a one-time compilation cost)
- Mixed Precision: 2-4x throughput improvement
- Quantization: 2-4x memory reduction
- Batching: Near-linear scaling with batch size
- JIT Compilation: Compile JAX functions for dramatic speedups
- Memory Optimization: Reduce memory footprint for larger models
- Quantization: Use lower precision for faster inference
- Hardware Tuning: Optimize for GPUs, TPUs, and multi-device setups
Prerequisites¤
```python
from flax import nnx
import jax
import jax.numpy as jnp

from workshop.generative_models.core import DeviceManager
```
JIT Compilation¤
Basic JIT Usage¤
JAX's Just-In-Time compilation dramatically improves performance.
```python
class JITOptimizedInference:
    """Inference with JIT compilation."""

    def __init__(self, model):
        self.model = model

        # Compile inference function
        self.generate_jit = jax.jit(self._generate_impl)

        # Warmup compilation
        self._warmup()

    def _generate_impl(self, z: jax.Array) -> jax.Array:
        """Implementation to be JIT-compiled."""
        return self.model.decode(z)

    def _warmup(self):
        """Warmup JIT compilation with dummy input."""
        dummy_z = jnp.zeros((1, self.model.latent_dim))
        _ = self.generate_jit(dummy_z)
        print("JIT compilation complete")

    def generate(self, z: jax.Array) -> jax.Array:
        """Generate samples (uses compiled function)."""
        return self.generate_jit(z)
```
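A minimal usage sketch, assuming a VAE-style `model` that exposes `decode` and a `latent_dim` attribute (both hypothetical names here):

```python
# Hypothetical usage: `model` is assumed to expose `decode` and `latent_dim`
inference = JITOptimizedInference(model)  # Compiles and warms up once

key = jax.random.key(0)
z = jax.random.normal(key, (8, model.latent_dim))

samples = inference.generate(z)  # Runs the pre-compiled function
```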
Static vs Dynamic Shapes¤
JIT compiles a separate executable for every distinct input shape, so varying batch sizes trigger repeated recompilation. One strategy is to pre-compile a function for each batch size you expect to serve.
```python
def create_shape_specific_functions(model, batch_sizes: list[int]):
    """Compile separate functions for each batch size.

    Args:
        model: Generative model
        batch_sizes: List of expected batch sizes

    Returns:
        Dictionary of compiled functions keyed by batch size
    """
    compiled_fns = {}

    for batch_size in batch_sizes:

        @jax.jit
        def generate_fn(z):
            return model.decode(z)

        # Warmup with this specific shape so compilation happens up front
        dummy_z = jnp.zeros((batch_size, model.latent_dim))
        _ = generate_fn(dummy_z)

        compiled_fns[batch_size] = generate_fn

    return compiled_fns
```
```python
# Usage
compiled_fns = create_shape_specific_functions(model, [1, 4, 16, 32])

# Use the function that matches the batch size
batch_size = 16
z = jax.random.normal(key, (batch_size, latent_dim))
samples = compiled_fns[batch_size](z)
```
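When request sizes vary at runtime, a common workaround (a sketch, not part of the guide's API) is to pad the batch up to the nearest pre-compiled size and slice the result back down:

```python
def generate_with_padding(compiled_fns, z: jax.Array) -> jax.Array:
    """Pad to the nearest pre-compiled batch size to avoid recompilation.

    Assumes the actual batch size is no larger than the biggest
    pre-compiled size.
    """
    actual = z.shape[0]
    # Smallest compiled batch size that fits the request
    target = min(b for b in compiled_fns if b >= actual)
    pad = target - actual

    z_padded = jnp.pad(z, ((0, pad), (0, 0)))  # Zero-pad extra rows
    samples = compiled_fns[target](z_padded)

    return samples[:actual]  # Drop outputs for the padded rows
```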
Compilation Cache Management¤
```python
import os

from jax.experimental.compilation_cache import compilation_cache


class CachedInference:
    """Inference with persistent compilation cache."""

    def __init__(self, model, cache_dir: str = "/tmp/jax_cache"):
        self.model = model

        # Enable compilation caching
        os.makedirs(cache_dir, exist_ok=True)
        compilation_cache.set_cache_dir(cache_dir)

        # Compile function (cache persists across runs)
        self.generate = jax.jit(model.decode)

        # First run compiles or loads from cache
        self._warmup()

    def _warmup(self):
        dummy_z = jnp.zeros((1, self.model.latent_dim))
        _ = self.generate(dummy_z)
```
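Recent JAX releases also expose the persistent cache directory as a config option; if your JAX version supports it, a sketch like the following should be equivalent (check the documentation for your version):

```python
import jax

# Alternative to the compilation_cache module on newer JAX releases
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```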
Quantization¤
Mixed Precision Inference¤
Use FP16 for faster inference on modern GPUs.
```python
class MixedPrecisionModel(nnx.Module):
    """Model with mixed precision inference."""

    def __init__(
        self,
        base_model,
        *,
        compute_dtype: jnp.dtype = jnp.float16,
        param_dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.base_model = base_model
        self.compute_dtype = compute_dtype
        self.param_dtype = param_dtype

    def __call__(self, x: jax.Array) -> jax.Array:
        """Forward pass with mixed precision."""
        # Convert input to compute dtype
        x = x.astype(self.compute_dtype)

        # Run model (parameters stay in param_dtype)
        output = self.base_model(x)

        # Convert back to float32 for output
        return output.astype(jnp.float32)
```
```python
# Create mixed precision model
mixed_model = MixedPrecisionModel(
    model,
    compute_dtype=jnp.float16,
    param_dtype=jnp.float32,
)
```
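Combined with JIT compilation, the wrapper is used like any other module. A short sketch, assuming 64x64 RGB inputs (hypothetical shape):

```python
@jax.jit
def mixed_precision_infer(x: jax.Array) -> jax.Array:
    # Casts to FP16 inside the wrapper, returns FP32 output
    return mixed_model(x)

x = jax.random.normal(jax.random.key(0), (16, 64, 64, 3))
y = mixed_precision_infer(x)
assert y.dtype == jnp.float32
```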
Dynamic Quantization¤
Quantize activations dynamically at inference time.
```python
def quantize_inference(
    model,
    x: jax.Array,
    num_bits: int = 8,
) -> jax.Array:
    """Run inference with dynamic quantization.

    Args:
        model: Model to run
        x: Input tensor
        num_bits: Quantization bits (e.g. 8 for 8-bit quantization)

    Returns:
        Model output
    """
    # Quantize input to unsigned integers in [0, 2**num_bits - 1]
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min) / (2**num_bits - 1)
    x_quant = jnp.round((x - x_min) / scale).astype(jnp.uint8)

    # Dequantize for computation
    x_dequant = x_quant.astype(jnp.float32) * scale + x_min

    # Run model
    output = model(x_dequant)

    return output
```
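Because quantization is lossy, measure the impact before committing to it. A small sketch (hypothetical helper) that compares quantized against full-precision outputs:

```python
def quantization_error(model, x: jax.Array, num_bits: int = 8) -> float:
    """Mean absolute difference between full-precision and quantized inference."""
    y_full = model(x)
    y_quant = quantize_inference(model, x, num_bits=num_bits)
    return float(jnp.mean(jnp.abs(y_full - y_quant)))
```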
INT8 Quantization¤
Store weights in INT8 to cut weight memory; the example below dequantizes them on the fly before the matmul.
```python
class QuantizedLinear(nnx.Module):
    """INT8 quantized linear layer."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        # Initialize weights in FP32
        kernel = jax.random.normal(
            rngs.params(),
            (in_features, out_features),
        ) * jnp.sqrt(2.0 / in_features)

        # Quantize to INT8 with a single per-tensor scale
        self.scale = jnp.max(jnp.abs(kernel))
        self.kernel_quant = jnp.round(
            kernel / self.scale * 127
        ).astype(jnp.int8)

        # Bias stays in FP32
        self.bias = nnx.Param(jnp.zeros(out_features))

    def __call__(self, x: jax.Array) -> jax.Array:
        """Quantized forward pass."""
        # Dequantize weights
        kernel = self.kernel_quant.astype(jnp.float32) * self.scale / 127

        # Compute
        output = jnp.dot(x, kernel) + self.bias.value
        return output
```
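A quick sanity check of the layer on random data (sketch):

```python
# Construct the layer and run a small batch through it
rngs = nnx.Rngs(params=0)
layer = QuantizedLinear(256, 128, rngs=rngs)

x = jax.random.normal(jax.random.key(1), (4, 256))
y = layer(x)  # Shape (4, 128), computed from dequantized INT8 weights
```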
Memory Optimization¤
Gradient Checkpointing for Large Models¤
Save memory by recomputing activations during backward pass.
```python
def memory_efficient_inference(model, x: jax.Array) -> jax.Array:
    """Inference with gradient checkpointing (saves memory).

    Note: checkpointing only pays off when gradients flow through the
    model (e.g. guided sampling or latent optimization); for a pure
    forward pass it adds recomputation without a memory benefit.

    Args:
        model: Large generative model
        x: Input

    Returns:
        Model output with reduced memory usage
    """
    # Use jax.checkpoint (rematerialization) to reduce memory
    @jax.checkpoint
    def checkpointed_forward(x_inner):
        return model(x_inner)

    return checkpointed_forward(x)
```
Batch Size Tuning¤
Find optimal batch size for your hardware.
```python
def find_optimal_batch_size(
    model,
    input_shape: tuple,
    max_batch_size: int = 128,
) -> int:
    """Find largest batch size that fits in memory.

    Args:
        model: Model to test
        input_shape: Shape of single input
        max_batch_size: Maximum batch size to try

    Returns:
        Largest batch size that ran without an out-of-memory error
    """
    device_manager = DeviceManager()
    device = device_manager.get_device()

    batch_size = 1
    while batch_size <= max_batch_size:
        try:
            # Try inference with this batch size
            x = jax.random.normal(
                jax.random.key(0),
                (batch_size, *input_shape),
            )

            # Move to device
            x = jax.device_put(x, device)

            # Run inference and force execution so OOM surfaces here
            jax.block_until_ready(model(x))

            # Success - try larger batch
            batch_size *= 2
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                # OOM - return previous (last successful) batch size
                return batch_size // 2
            raise

    # Loop finished without OOM: last successful size was batch_size // 2
    return batch_size // 2
```
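Run the search once at startup and reuse the result; for example, assuming 64x64 RGB inputs (hypothetical shape):

```python
# One-off search at startup; cache the result for the serving loop
optimal_batch = find_optimal_batch_size(model, input_shape=(64, 64, 3))
print(f"Optimal batch size: {optimal_batch}")
```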
Chunked Inference¤
Reduce peak memory by processing large batches in smaller chunks, so fewer intermediate activations are live at once.
```python
class ChunkedInferenceModel(nnx.Module):
    """Model that limits peak memory by processing the batch in chunks."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def __call__(self, x: jax.Array) -> jax.Array:
        """Forward pass over the batch in fixed-size chunks."""
        # Process in chunks to reduce peak memory
        chunk_size = 16
        batch_size = x.shape[0]

        outputs = []
        for i in range(0, batch_size, chunk_size):
            chunk = x[i:i + chunk_size]
            chunk_output = self.base_model(chunk)
            outputs.append(chunk_output)

        return jnp.concatenate(outputs, axis=0)
```
Batching Strategies¤
Dynamic Batching¤
Accumulate requests into batches for efficiency.
```python
import asyncio
from collections import deque


class DynamicBatcher:
    """Dynamic batching for inference requests."""

    def __init__(
        self,
        model,
        max_batch_size: int = 32,
        timeout_ms: int = 100,
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout_ms = timeout_ms

        # Request queue
        self.queue = deque()
        self.processing = False

    async def infer_async(self, x: jax.Array) -> jax.Array:
        """Submit inference request.

        Args:
            x: Input tensor (single sample)

        Returns:
            Model output
        """
        # Create future for result
        future = asyncio.get_running_loop().create_future()
        self.queue.append((x, future))

        # Start processing if not already running
        if not self.processing:
            asyncio.create_task(self._process_batch())

        # Wait for result
        return await future

    async def _process_batch(self):
        """Process batched requests."""
        self.processing = True

        # Wait for the timeout so more requests can accumulate
        await asyncio.sleep(self.timeout_ms / 1000)

        if not self.queue:
            self.processing = False
            return

        # Collect up to max_batch_size pending requests
        batch_inputs = []
        futures = []
        while self.queue and len(batch_inputs) < self.max_batch_size:
            x, future = self.queue.popleft()
            batch_inputs.append(x)
            futures.append(future)

        # Run batched inference
        batch_x = jnp.stack(batch_inputs)
        batch_output = self.model(batch_x)

        # Return each result to its caller
        for i, future in enumerate(futures):
            future.set_result(batch_output[i])

        self.processing = False

        # Process remaining requests
        if self.queue:
            asyncio.create_task(self._process_batch())
```
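A hypothetical usage sketch showing several concurrent requests being served from a single batched forward pass:

```python
async def serve_requests(batcher: DynamicBatcher, inputs: list[jax.Array]):
    # Concurrent awaits let the batcher group requests into one forward pass
    return await asyncio.gather(
        *(batcher.infer_async(x) for x in inputs)
    )

# batcher = DynamicBatcher(model, max_batch_size=32, timeout_ms=50)
# outputs = asyncio.run(serve_requests(batcher, single_samples))
```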
Pipeline Parallelism¤
Split large models across multiple devices.
```python
class PipelineParallelModel(nnx.Module):
    """Model with pipeline parallelism."""

    def __init__(
        self,
        encoder,
        decoder,
        devices: list,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device_encoder = devices[0]
        self.device_decoder = devices[1]

    def __call__(self, x: jax.Array) -> jax.Array:
        """Forward pass with pipeline parallelism."""
        # Encoder on device 0
        x = jax.device_put(x, self.device_encoder)
        latent = self.encoder(x)

        # Transfer to device 1
        latent = jax.device_put(latent, self.device_decoder)
        output = self.decoder(latent)

        return output
```
```python
# Create pipeline parallel model
device_manager = DeviceManager()
devices = device_manager.get_devices()

if len(devices) >= 2:
    pipeline_model = PipelineParallelModel(
        encoder=model.encoder,
        decoder=model.decoder,
        devices=devices[:2],
    )
```
Hardware-Specific Optimization¤
GPU Optimization¤
Optimize for NVIDIA GPUs.
```python
import os


def optimize_for_gpu(model):
    """Apply GPU-specific optimizations.

    Args:
        model: Model to optimize

    Returns:
        JIT-compiled forward function
    """
    # XLA flags are read when the backend initializes, so set them
    # before the first JAX computation (ideally before importing jax)
    os.environ['XLA_FLAGS'] = (
        '--xla_gpu_enable_fast_min_max=true '
        '--xla_gpu_enable_triton_gemm=true '
        '--xla_gpu_triton_gemm_any=true'
    )

    # Keep everything in 32-bit; x64 halves tensor-core throughput
    jax.config.update('jax_enable_x64', False)

    # JIT compile the forward pass
    @jax.jit
    def optimized_forward(x):
        return model(x)

    return optimized_forward
```
TPU Optimization¤
Optimize for Google TPUs.
```python
def optimize_for_tpu(model, num_devices: int = 8):
    """Apply TPU-specific optimizations.

    Args:
        model: Model to optimize
        num_devices: Number of TPU cores

    Returns:
        Inference function with data parallelism
    """
    # Replicate model across TPU cores
    @jax.pmap
    def pmap_forward(x):
        return model(x)

    # Shard inputs across devices
    def inference(x):
        # Reshape to (num_devices, batch_per_device, ...);
        # the batch size must be divisible by num_devices
        batch_size = x.shape[0]
        batch_per_device = batch_size // num_devices
        x_sharded = x.reshape(
            (num_devices, batch_per_device) + x.shape[1:]
        )

        # Run on all TPU cores
        output_sharded = pmap_forward(x_sharded)

        # Merge per-device outputs back into one batch
        return output_sharded.reshape((batch_size,) + output_sharded.shape[2:])

    return inference
```
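A usage sketch, assuming 8 TPU cores and 64x64 RGB inputs (hypothetical shapes); pad or group requests so the batch divides evenly across cores:

```python
tpu_inference = optimize_for_tpu(model, num_devices=jax.device_count())

x = jax.random.normal(jax.random.key(0), (64, 64, 64, 3))  # 64 divides across 8 cores
y = tpu_inference(x)
```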
Multi-Device Inference¤
Distribute inference across multiple devices.
```python
class MultiDeviceInference:
    """Inference distributed across multiple devices."""

    def __init__(self, model):
        self.model = model
        self.device_manager = DeviceManager()
        self.devices = self.device_manager.get_devices()
        self.num_devices = len(self.devices)

        # Replicate model on all devices
        self.replicated_model = jax.pmap(
            lambda x: model(x)
        )

    def infer_distributed(self, x: jax.Array) -> jax.Array:
        """Run inference across all devices.

        Args:
            x: Input batch (must be divisible by num_devices)

        Returns:
            Model output
        """
        batch_size = x.shape[0]
        assert batch_size % self.num_devices == 0

        # Reshape for pmap
        batch_per_device = batch_size // self.num_devices
        x_sharded = x.reshape(
            (self.num_devices, batch_per_device) + x.shape[1:]
        )

        # Run on all devices
        output_sharded = self.replicated_model(x_sharded)

        # Merge results
        return output_sharded.reshape((batch_size,) + output_sharded.shape[2:])
```
Performance Benchmarking¤
Throughput Measurement¤
```python
import time


def benchmark_throughput(
    model,
    input_shape: tuple,
    batch_sizes: list[int],
    num_iterations: int = 100,
):
    """Benchmark model throughput.

    Args:
        model: Model to benchmark
        input_shape: Shape of single input
        batch_sizes: Batch sizes to test
        num_iterations: Number of iterations per test

    Returns:
        Dictionary with throughput results
    """
    results = {}

    for batch_size in batch_sizes:
        # Create dummy input
        x = jax.random.normal(
            jax.random.key(0),
            (batch_size, *input_shape),
        )

        # Warmup (triggers compilation for this shape)
        for _ in range(10):
            _ = model(x)

        # Benchmark
        jax.block_until_ready(model(x))  # Ensure device work completes
        start = time.time()
        for _ in range(num_iterations):
            output = model(x)
            jax.block_until_ready(output)
        elapsed = time.time() - start

        # Compute metrics
        samples_per_sec = (batch_size * num_iterations) / elapsed
        latency_ms = (elapsed / num_iterations) * 1000

        results[batch_size] = {
            'throughput': samples_per_sec,
            'latency_ms': latency_ms,
            'batch_size': batch_size,
        }

    return results
```
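Example of running the benchmark on the compiled forward pass and printing a compact report (input shape is a hypothetical 64x64 RGB example):

```python
results = benchmark_throughput(
    jax.jit(lambda x: model(x)),  # Benchmark the compiled forward pass
    input_shape=(64, 64, 3),
    batch_sizes=[1, 8, 32],
)

for batch_size, metrics in results.items():
    print(
        f"batch={batch_size:>3}  "
        f"{metrics['throughput']:.1f} samples/s  "
        f"{metrics['latency_ms']:.2f} ms/iteration"
    )
```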
Best Practices¤
DO¤
Recommended Optimizations
- ✅ Always JIT compile production inference functions
- ✅ Use mixed precision (FP16) on modern GPUs for 2-4x speedup
- ✅ Batch requests to maximize GPU utilization
- ✅ Cache compiled functions across server restarts
- ✅ Profile memory usage to find optimal batch size
- ✅ Use pmap for multi-device inference when available
DON'T¤
Avoid These Mistakes
- ❌ Don't recompile on every request (cache JIT functions)
- ❌ Don't use FP64 unless absolutely necessary (2x slower)
- ❌ Don't ignore batch size (single samples waste resources)
- ❌ Don't mix device types (CPU/GPU transfers are slow)
- ❌ Don't skip warmup (first call is always slow)
- ❌ Don't quantize without testing quality impact
Troubleshooting¤
Common Performance Issues¤
| Issue | Symptom | Solution |
|---|---|---|
| Slow first inference | 10-60s delay on first call | Warmup JIT compilation during startup |
| OOM errors | Out of memory during inference | Reduce batch size or use gradient checkpointing |
| Low GPU utilization | GPU usage < 50% | Increase batch size or use pipelining |
| Recompilation on every call | Consistent slow performance | Use static shapes or cache per-shape functions |
| Slow multi-device | Linear speedup not achieved | Check data transfer overhead, use pmap |
| High latency | Individual requests take too long | Use dynamic batching or reduce model size |
Summary¤
Effective optimization can provide 10-100x speedups:
- JIT Compilation: Essential for production (10-100x faster)
- Mixed Precision: 2-4x throughput on modern GPUs
- Quantization: Reduce memory 2-4x with minimal quality loss
- Batching: Near-linear scaling with batch size
- Multi-Device: Distribute large workloads across GPUs/TPUs
Choose optimizations based on your deployment constraints and quality requirements.
Next Steps¤
- Sampling Methods: Learn advanced sampling techniques for quality generation
- Deployment: Deploy optimized models to production
- Benchmarking: Measure and compare model performance
- Distributed Training: Scale training across multiple devices