
Model Deployment

Guide to deploying trained Artifex models for inference in production environments.

Overview

This guide covers strategies for deploying trained generative models for production inference: exporting model checkpoints, optimizing inference speed with JIT compilation, and common serving patterns such as batch processing and API endpoints.


Model Export

Save and load trained models for deployment.

from artifex.generative_models.core import save_model, load_model

# Save trained model
save_model(model, "checkpoints/vae_model")

# Load for inference
model = load_model("checkpoints/vae_model")
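
After loading, it can be worth running a quick smoke test before putting a checkpoint behind traffic. A minimal sketch, assuming the `sample` API used throughout this guide:

import jax

# Confirm the restored model produces samples of the expected batch size
restored = load_model("checkpoints/vae_model")
test_batch = restored.sample(num_samples=2, rng=jax.random.key(0))
assert test_batch.shape[0] == 2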

Inference Optimization

Use JIT compilation to speed up production inference.

import functools

import jax

# JIT compile for faster inference. num_samples must be marked static
# because it determines the output shape, which JAX needs at trace time.
@functools.partial(jax.jit, static_argnames="num_samples")
def generate(model, rng_key, num_samples):
    return model.sample(num_samples=num_samples, rng=rng_key)

# Warm up JIT compilation with the production sample count;
# each distinct num_samples triggers a separate compilation.
_ = generate(model, jax.random.key(0), num_samples=64)

# Subsequent calls with the same num_samples reuse the compiled function
samples = generate(model, jax.random.key(42), num_samples=64)
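
When benchmarking, keep in mind that JAX dispatches work asynchronously, so a naive timer measures dispatch rather than computation. A minimal latency check, assuming the `generate` function defined above:

import time

# block_until_ready ensures the computation has actually
# finished before the clock stops.
start = time.perf_counter()
out = generate(model, jax.random.key(1), num_samples=64)
jax.block_until_ready(out)
print(f"latency: {time.perf_counter() - start:.4f}s")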

Serving Patterns

Common patterns for serving generative models.

Batch Processing

import jax
import jax.numpy as jnp

def batch_generate(model, rng_key, batch_size=64, total_samples=1000):
    """Generate samples in batches to bound peak memory use."""
    samples = []
    for i in range(0, total_samples, batch_size):
        batch = model.sample(
            num_samples=min(batch_size, total_samples - i),
            # Derive a distinct key for each batch from the base key
            rng=jax.random.fold_in(rng_key, i),
        )
        samples.append(batch)
    return jnp.concatenate(samples, axis=0)
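
For example, drawing 1,000 samples in batches of 64:

samples = batch_generate(model, jax.random.key(0), batch_size=64, total_samples=1000)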

API Endpoint

import jax
from fastapi import FastAPI

from artifex.generative_models.core import load_model

app = FastAPI()
model = load_model("checkpoints/vae_model")  # load once at startup
rng_key = jax.random.key(0)

@app.post("/generate")
async def generate_samples(num_samples: int = 16):
    # Split the key per request so repeated calls return fresh samples
    global rng_key
    rng_key, sample_key = jax.random.split(rng_key)
    samples = model.sample(num_samples=num_samples, rng=sample_key)
    return {"samples": samples.tolist()}