Model Deployment¤

This guide covers deploying trained Workshop models to production: exporting models, serving them for inference, optimizing inference performance, and deploying to cloud platforms.

  • Model Export: Export models for deployment and inference
  • Model Serving: Serve models with REST APIs and batch processing
  • Optimization: Optimize models for production performance
  • Cloud Deployment: Deploy to cloud platforms and containers

Model Export¤

Export trained models for deployment.

Export Model State¤

from workshop.generative_models.core.checkpointing import (
    setup_checkpoint_manager,
    save_checkpoint,
)

# Save the trained model to a production checkpoint directory
checkpoint_manager, checkpoint_dir = setup_checkpoint_manager(
    base_dir="./models/production/vae_v1"
)

save_checkpoint(checkpoint_manager, model, step=final_step)

print(f"Model exported to {checkpoint_dir}")

Export with Metadata¤

import json
import os

from flax import nnx
from workshop.generative_models.core.checkpointing import (
    setup_checkpoint_manager,
    save_checkpoint,
)

def export_model_with_metadata(
    model: nnx.Module,
    config: dict,
    metrics: dict,
    export_dir: str,
):
    """Export model with configuration and metrics."""
    # Ensure the export directory exists before writing metadata files
    os.makedirs(export_dir, exist_ok=True)

    # Save model checkpoint
    checkpoint_manager, _ = setup_checkpoint_manager(export_dir)
    save_checkpoint(checkpoint_manager, model, step=0)

    # Save configuration
    config_path = f"{export_dir}/config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    # Save metrics
    metrics_path = f"{export_dir}/metrics.json"
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)

    # Save model info
    info = {
        "model_type": config.get("model_type"),
        "input_shape": config.get("input_shape"),
        "latent_dim": config.get("latent_dim"),
        "trained_steps": metrics.get("total_steps"),
        "final_loss": float(metrics.get("final_loss", 0.0)),
    }

    info_path = f"{export_dir}/model_info.json"
    with open(info_path, "w") as f:
        json.dump(info, f, indent=2)

    return export_dir

# Export
export_dir = export_model_with_metadata(
    model=model,
    config={"model_type": "vae", "latent_dim": 20, "input_shape": [28, 28, 1]},
    metrics={"total_steps": 10000, "final_loss": 0.15},
    export_dir="./models/production/vae_v1",
)

Load Exported Model¤

import json

from flax import nnx
from workshop.generative_models.core.checkpointing import (
    setup_checkpoint_manager,
    load_checkpoint,
)

def load_exported_model(export_dir: str):
    """Load exported model with metadata."""
    # Load configuration
    with open(f"{export_dir}/config.json") as f:
        config = json.load(f)

    # Create model template
    from workshop.generative_models.models.vae import create_vae_model
    from workshop.generative_models.core.configuration import ModelConfiguration

    model_config = ModelConfiguration(**config)
    model_template = create_vae_model(model_config, rngs=nnx.Rngs(0))

    # Load checkpoint
    checkpoint_manager, _ = setup_checkpoint_manager(export_dir)
    model, step = load_checkpoint(checkpoint_manager, model_template)

    # Load metrics
    with open(f"{export_dir}/metrics.json") as f:
        metrics = json.load(f)

    return model, config, metrics

# Load
model, config, metrics = load_exported_model("./models/production/vae_v1")
print(f"Loaded model trained for {metrics['total_steps']} steps")

Model Serving¤

Serve models for inference requests.

Simple Inference Server¤

from flask import Flask, request, jsonify
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import io

app = Flask(__name__)

# Load model at startup
model, config, metrics = load_exported_model("./models/production/vae_v1")

@app.route("/health", methods=["GET"])
def health():
    """Health check endpoint."""
    return jsonify({"status": "healthy", "model_loaded": model is not None})

@app.route("/generate", methods=["POST"])
def generate():
    """Generate samples from the model."""
    data = request.get_json()

    # Get parameters
    num_samples = data.get("num_samples", 1)
    seed = data.get("seed", 0)

    # Generate samples
    key = jax.random.key(seed)
    z = jax.random.normal(key, (num_samples, config["latent_dim"]))

    samples = model.decode(z)
    samples = np.array(samples)

    # Convert to list for JSON
    samples_list = samples.tolist()

    return jsonify({
        "samples": samples_list,
        "num_samples": num_samples,
        "shape": list(samples.shape),
    })

@app.route("/encode", methods=["POST"])
def encode():
    """Encode image to latent representation."""
    # Get image from request
    file = request.files["image"]
    image = Image.open(io.BytesIO(file.read()))

    # Preprocess
    image_array = np.array(image).astype(np.float32) / 255.0
    image_array = image_array.reshape(1, *config["input_shape"])

    # Encode
    output = model.encode(jnp.array(image_array))
    latent = np.array(output["mean"][0])

    return jsonify({
        "latent": latent.tolist(),
        "latent_dim": len(latent),
    })

@app.route("/reconstruct", methods=["POST"])
def reconstruct():
    """Reconstruct image from input."""
    file = request.files["image"]
    image = Image.open(io.BytesIO(file.read()))

    # Preprocess
    image_array = np.array(image).astype(np.float32) / 255.0
    image_array = image_array.reshape(1, *config["input_shape"])

    # Reconstruct
    output = model(jnp.array(image_array))
    reconstruction = np.array(output["reconstruction"][0])

    return jsonify({
        "reconstruction": reconstruction.tolist(),
        "shape": list(reconstruction.shape),
    })

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)
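
Once the server is running, requests can be issued from any HTTP client. The snippet below is a minimal client sketch using the requests library against the endpoints defined above; example.png is a placeholder path for a local test image.

import requests

BASE_URL = "http://localhost:8000"

# Check server health before sending work
print(requests.get(f"{BASE_URL}/health").json())

# Request 5 generated samples with a fixed seed
response = requests.post(
    f"{BASE_URL}/generate",
    json={"num_samples": 5, "seed": 42},
)
response.raise_for_status()
result = response.json()
print(f"Received {result['num_samples']} samples with shape {result['shape']}")

# Send an image for reconstruction
with open("example.png", "rb") as f:
    response = requests.post(f"{BASE_URL}/reconstruct", files={"image": f})
print(response.json()["shape"])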

Batch Inference¤

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from typing import Iterator

def batch_inference(
    model: nnx.Module,
    data_iterator: Iterator,
    batch_size: int = 32,
) -> list:
    """Process data in batches for efficient inference."""
    results = []

    batch = []
    for sample in data_iterator:
        batch.append(sample)

        if len(batch) >= batch_size:
            # Process batch
            batch_array = jnp.array(batch)
            output = model(batch_array)

            # Store results
            results.extend(np.array(output["reconstruction"]))

            # Clear batch
            batch = []

    # Process remaining samples
    if batch:
        batch_array = jnp.array(batch)
        output = model(batch_array)
        results.extend(np.array(output["reconstruction"]))

    return results

# Usage
def data_generator():
    """Generator for inference data."""
    for i in range(1000):
        yield np.random.randn(28, 28, 1)

results = batch_inference(model, data_generator(), batch_size=64)
print(f"Processed {len(results)} samples")

Optimization¤

Optimize models for production performance.

JIT Compilation¤

import jax
from flax import nnx

# Split the model once into a static graph definition and a parameter state
model_graphdef, model_state = nnx.split(model)

# JIT-compile inference functions
@jax.jit
def generate_jit(model_state, z):
    """JIT-compiled generation."""
    model = nnx.merge(model_graphdef, model_state)
    return model.decode(z)

@jax.jit
def encode_jit(model_state, x):
    """JIT-compiled encoding."""
    model = nnx.merge(model_graphdef, model_state)
    return model.encode(x)

# Fast inference
z = jax.random.normal(jax.random.key(0), (10, 20))
samples = generate_jit(model_state, z)

# First call: compilation + execution (~slow)
# Subsequent calls: cached execution (~fast)
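
To see the effect of compilation caching, time the first call against a later call with the same input shapes. The sketch below uses jax.block_until_ready so the asynchronous computation has actually finished before the timer stops.

import time

z = jax.random.normal(jax.random.key(1), (10, 20))

# First call: traces, compiles, then executes
start = time.perf_counter()
jax.block_until_ready(generate_jit(model_state, z))
print(f"First call (with compilation): {time.perf_counter() - start:.3f}s")

# Second call with the same shapes: reuses the compiled executable
start = time.perf_counter()
jax.block_until_ready(generate_jit(model_state, z))
print(f"Cached call: {time.perf_counter() - start:.4f}s")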

Batched Generation¤

latent_dim = config["latent_dim"]  # latent size of the exported model, e.g. 20

@jax.jit
def batched_generate(model_state, keys):
    """Generate multiple samples in parallel."""
    # Vectorize over batch
    def generate_single(key):
        z = jax.random.normal(key, (latent_dim,))
        model = nnx.merge(model_graphdef, model_state)
        return model.decode(z[None, :])[0]

    # vmap over keys
    samples = jax.vmap(generate_single)(keys)
    return samples

# Generate 100 samples in parallel
keys = jax.random.split(jax.random.key(0), 100)
samples = batched_generate(model_state, keys)
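
If the decoder already accepts a batch dimension, as in the serving examples above, an equivalent batch of samples can be generated without vmap by decoding one batched latent tensor; the vmap version is mainly useful when per-sample keys or per-sample logic matter. A minimal sketch:

@jax.jit
def batched_generate_simple(model_state, z_batch):
    """Decode a whole batch of latents in one call."""
    model = nnx.merge(model_graphdef, model_state)
    return model.decode(z_batch)

# Draw 100 latent vectors at once and decode them in a single forward pass
z_batch = jax.random.normal(jax.random.key(0), (100, latent_dim))
samples = batched_generate_simple(model_state, z_batch)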

Mixed Precision¤

import jax.numpy as jnp

def convert_to_bfloat16(model_state):
    """Convert model to bfloat16 for faster inference."""
    def to_bf16(x):
        if x.dtype == jnp.float32:
            return x.astype(jnp.bfloat16)
        return x

    return jax.tree.map(to_bf16, model_state)

# Convert model
model_state_bf16 = convert_to_bfloat16(model_state)

# Inference in bfloat16 (often significantly faster on accelerators with native bfloat16 support)
@jax.jit
def generate_bf16(model_state, z):
    z = z.astype(jnp.bfloat16)
    model = nnx.merge(model_graphdef, model_state)
    output = model.decode(z)
    return output.astype(jnp.float32)  # Convert back for output

z = jax.random.normal(jax.random.key(0), (10, 20))
samples = generate_bf16(model_state_bf16, z)
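
Reduced precision changes outputs slightly, so it is worth comparing bfloat16 generations against the float32 baseline before switching the production path. This check reuses generate_jit from the JIT section above; what counts as an acceptable difference depends on the application.

import numpy as np

# Compare bfloat16 generations against the float32 baseline for the same latents
z = jax.random.normal(jax.random.key(0), (10, 20))
samples_fp32 = np.array(generate_jit(model_state, z))
samples_bf16 = np.array(generate_bf16(model_state_bf16, z))

max_diff = np.max(np.abs(samples_fp32 - samples_bf16))
print(f"Max absolute difference vs float32: {max_diff:.5f}")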

Cloud Deployment¤

Deploy models to cloud platforms.

Docker Deployment¤

# Dockerfile
FROM python:3.10-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy model and code
COPY models/ models/
COPY serve.py .

# Expose port
EXPOSE 8000

# Run server
CMD ["python", "serve.py"]
# Build image
docker build -t workshop-vae-server .

# Run container
docker run -p 8000:8000 workshop-vae-server

# Test
curl -X POST http://localhost:8000/generate \
  -H "Content-Type: application/json" \
  -d '{"num_samples": 5}'

Kubernetes Deployment¤

# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: workshop-vae-server
spec:
  replicas: 3
  selector:
    matchLabels:
      app: workshop-vae-server
  template:
    metadata:
      labels:
        app: workshop-vae-server
    spec:
      containers:
      - name: server
        image: workshop-vae-server:latest
        ports:
        - containerPort: 8000
        resources:
          requests:
            memory: "2Gi"
            cpu: "1000m"
          limits:
            memory: "4Gi"
            cpu: "2000m"
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
---
apiVersion: v1
kind: Service
metadata:
  name: workshop-vae-service
spec:
  selector:
    app: workshop-vae-server
  ports:
  - protocol: TCP
    port: 80
    targetPort: 8000
  type: LoadBalancer

# Deploy to Kubernetes
kubectl apply -f deployment.yaml

# Check status
kubectl get pods
kubectl get services

Cloud Functions (Serverless)¤

# Google Cloud Function
import functions_framework
import jax
import numpy as np

# Load model once (cold start)
model = None

def load_model():
    """Load model on cold start."""
    global model
    if model is None:
        model, _, _ = load_exported_model("gs://my-bucket/models/vae_v1")
    return model

@functions_framework.http
def generate_samples(request):
    """Cloud Function for generation."""
    # Load model
    model = load_model()

    # Parse request
    request_json = request.get_json()
    num_samples = request_json.get("num_samples", 1)
    seed = request_json.get("seed", 0)

    # Generate
    key = jax.random.key(seed)
    z = jax.random.normal(key, (num_samples, 20))
    samples = model.decode(z)

    return {
        "samples": np.array(samples).tolist(),
        "num_samples": num_samples,
    }

Monitoring and Logging¤

Monitor deployed models in production.

Logging Inference Metrics¤

import time
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def monitored_inference(model, input_data):
    """Inference with monitoring."""
    start_time = time.time()

    try:
        # Inference
        output = model(input_data)

        # Log success
        duration = time.time() - start_time
        logger.info(
            f"Inference successful: "
            f"batch_size={input_data.shape[0]}, "
            f"duration={duration:.3f}s"
        )

        return output

    except Exception as e:
        # Log error
        duration = time.time() - start_time
        logger.error(
            f"Inference failed: "
            f"error={str(e)}, "
            f"duration={duration:.3f}s"
        )
        raise

Prometheus Metrics¤

from prometheus_client import Counter, Histogram, start_http_server

# Define metrics
inference_counter = Counter(
    "model_inference_total",
    "Total number of inference requests"
)

inference_duration = Histogram(
    "model_inference_duration_seconds",
    "Inference duration in seconds"
)

inference_errors = Counter(
    "model_inference_errors_total",
    "Total number of inference errors"
)

def monitored_inference_prometheus(model, input_data):
    """Inference with Prometheus metrics."""
    inference_counter.inc()

    with inference_duration.time():
        try:
            output = model(input_data)
            return output
        except Exception as e:
            inference_errors.inc()
            raise

# Start Prometheus metrics server
start_http_server(9090)
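
The metrics server runs alongside the inference server, so Prometheus can scrape port 9090 while clients hit the Flask app on port 8000. The sketch below shows one way to wire the instrumented call into a reconstruction endpoint like the one defined earlier; the route name is illustrative.

# Inside the Flask app defined earlier
@app.route("/reconstruct_monitored", methods=["POST"])
def reconstruct_monitored():
    """Reconstruction endpoint that records Prometheus metrics."""
    file = request.files["image"]
    image = Image.open(io.BytesIO(file.read()))

    image_array = np.array(image).astype(np.float32) / 255.0
    image_array = image_array.reshape(1, *config["input_shape"])

    # Counts the request, times it, and increments the error counter on failure
    output = monitored_inference_prometheus(model, jnp.array(image_array))
    reconstruction = np.array(output["reconstruction"][0])

    return jsonify({"reconstruction": reconstruction.tolist()})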

Best Practices¤

DO¤

  • Export models with metadata - include config and metrics
  • Use JIT compilation - significant speedup for inference
  • Implement health checks - monitor server status
  • Add logging - track inference requests and errors
  • Use batching - process multiple requests efficiently
  • Set resource limits - prevent out-of-memory errors
  • Version models - track deployed model versions
  • Monitor latency - track inference performance
  • Use load balancers - distribute traffic across replicas
  • Test before deploying - validate in staging environment

DON'T¤

  • Don't skip JIT - skipping compilation leaves performance on the table
  • Don't ignore errors - log and handle gracefully
  • Don't process one at a time - use batching
  • Don't deploy without health checks - can't monitor status
  • Don't hardcode configurations - use environment variables
  • Don't skip resource limits - can crash containers
  • Don't deploy untested models - validate first
  • Don't ignore monitoring - can't debug production issues
  • Don't use debug mode - slow and verbose
  • Don't expose internal errors - sanitize error messages

Summary¤

Model deployment in Workshop:

  1. Export: Save models with checkpoints and metadata
  2. Serve: REST APIs, batch processing, cloud functions
  3. Optimize: JIT compilation, batching, mixed precision
  4. Deploy: Docker, Kubernetes, serverless platforms
  5. Monitor: Logging, metrics, health checks

Key considerations:

  • Performance: JIT, batching, mixed precision
  • Reliability: Health checks, error handling, monitoring
  • Scalability: Load balancing, auto-scaling, replicas
  • Maintainability: Versioning, logging, configuration

Next Steps¤

  • Inference Guide: Learn more about inference and generation
  • Training Guide: Return to training documentation
  • API Reference: Explore the complete API documentation
  • FAQ: Common questions about deployment