Quickstart Guide¤

Get started with Artifex in 5 minutes! This guide walks you through installing Artifex and training your first generative model.

Prerequisites¤

  • Python 3.10 or higher
  • 8GB RAM (16GB recommended)
  • Optional: NVIDIA GPU with CUDA 12.0+ for faster training

Step 1: Install Artifex¤

Choose your preferred installation method:

# Clone the repository
git clone https://github.com/avitai/artifex.git
cd artifex

# Install with uv (fastest)
uv venv && source .venv/bin/activate
uv sync --all-extras

# Or with pip (editable install with dev extras)
python -m venv .venv && source .venv/bin/activate
pip install -e '.[dev]'

# Or skip the clone and install the released package from PyPI
pip install artifex

Verify installation:

python -c "import jax; print(f'JAX backend: {jax.default_backend()}')"
# Should print: JAX backend: gpu (or cpu)

Step 2: Train Your First VAE¤

Create a new Python file train_vae.py:

import jax
import jax.numpy as jnp
import optax
from datarax import from_source
from datarax.core.config import ElementOperatorConfig
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator
from datarax.sources import TfdsDataSourceConfig, TFDSSource
from flax import nnx

from artifex.generative_models.models.vae import VAE
from artifex.generative_models.core.configuration import (
    VAEConfig,
    EncoderConfig,
    DecoderConfig,
)

# 1. Load MNIST with datarax
def normalize(element, _key):
    """Normalize images from [0, 255] to [0, 1]."""
    image = element.data["image"].astype(jnp.float32) / 255.0
    return element.replace(data={**element.data, "image": image})

source = TFDSSource(
    TfdsDataSourceConfig(name="mnist", split="train", shuffle=True),
    rngs=nnx.Rngs(0),
)
normalize_op = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=normalize, rngs=nnx.Rngs(1)
)
pipeline = from_source(source, batch_size=32) >> OperatorNode(normalize_op)

# 2. Configure the model with nested configs
encoder = EncoderConfig(
    name="mnist_encoder",
    input_shape=(28, 28, 1),
    latent_dim=32,
    hidden_dims=(64, 128),
    activation="relu",
)

decoder = DecoderConfig(
    name="mnist_decoder",
    latent_dim=32,
    output_shape=(28, 28, 1),
    hidden_dims=(128, 64),
    activation="relu",
)

config = VAEConfig(
    name="mnist_vae",
    encoder=encoder,
    decoder=decoder,
    encoder_type="dense",
    kl_weight=1.0,
)

# 3. Create model and optimizer
rngs = nnx.Rngs(0)
model = VAE(config, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

print("✓ Model created successfully!")
print(f"  Latent dimension: {model.latent_dim}")
state_leaves = jax.tree.leaves(nnx.state(model))
param_count = sum(p.size for p in state_leaves if hasattr(p, 'size'))
print(f"  Parameters: ~{param_count/1e6:.2f}M")

# 4. Training step (JIT-compiled for a 10-50x speedup)
@nnx.jit  # Compiles to XLA for GPU/TPU acceleration
def train_step(model, optimizer, batch):
    """Single training step with automatic differentiation.

    JIT compilation provides significant speedups by:
    - Fusing operations to reduce memory transfers
    - Optimizing computation graphs for target hardware
    - Enabling XLA optimizations (constant folding, etc.)
    """
    def loss_fn(model):
        outputs = model(batch)
        loss_dict = model.loss_fn(x=batch, outputs=outputs)
        return loss_dict['loss'], loss_dict

    (loss, loss_dict), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    return loss, loss_dict

# 5. Train for one epoch
print("\nTraining for one epoch...")
step = 0

for batch in pipeline:
    images = batch["image"]

    # JIT-compiled train_step runs ~10-50x faster than eager execution
    loss, loss_dict = train_step(model, optimizer, images)

    if step % 500 == 0:
        print(f"Step {step:5d} | Loss: {loss:.4f} | "
              f"Recon: {loss_dict['reconstruction_loss']:.4f} | "
              f"KL: {loss_dict['kl_loss']:.4f}")
    step += 1

print("\n✓ Training complete!")

# 6. Generate samples
print("\nGenerating samples...")
samples = model.sample(n_samples=8)
print(f"✓ Generated {samples.shape[0]} samples with shape {samples.shape[1:]}")

# 7. Reconstruct an image
test_image = images[:1]  # Use last batch's first image
reconstructed = model.reconstruct(test_image, deterministic=True)
print(f"✓ Reconstructed image with shape {reconstructed.shape}")

print("\n🎉 Success! You've trained your first VAE with Artifex!")

Run the script:

python train_vae.py

Expected output:

✓ Model created successfully!
  Latent dimension: 32
  Parameters: ~0.18M

Training for one epoch...
Step     0 | Loss: 13.2877 | Recon: 0.2709 | KL: 13.0168
Step   500 | Loss: 0.3743 | Recon: 0.0815 | KL: 0.2928
Step  1000 | Loss: 0.1551 | Recon: 0.0623 | KL: 0.0928
...

✓ Training complete!

Generating samples...
✓ Generated 8 samples with shape (28, 28, 1)
✓ Reconstructed image with shape (1, 28, 28, 1)

🎉 Success! You've trained your first VAE with Artifex!

Step 3: Visualize Results (Optional)¤

Add visualization to your script:

import matplotlib.pyplot as plt

# Visualize generated samples
fig, axes = plt.subplots(2, 4, figsize=(10, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Samples from VAE')
plt.tight_layout()
plt.savefig('vae_samples.png')
print("✓ Saved samples to vae_samples.png")

# Visualize reconstruction
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(test_image[0].squeeze(), cmap='gray')
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(reconstructed[0].squeeze(), cmap='gray')
axes[1].set_title('Reconstructed')
axes[1].axis('off')
plt.tight_layout()
plt.savefig('vae_reconstruction.png')
print("✓ Saved reconstruction to vae_reconstruction.png")

Generated VAE samples: see vae_samples.png.

Original vs. reconstructed: see vae_reconstruction.png.

What You Just Did¤

In just a few minutes, you:

  1. Installed Artifex - Set up the complete environment
  2. Created a VAE - Built a variational autoencoder with configuration
  3. Trained the model - Ran one full epoch with loss monitoring
  4. Generated samples - Created new images from the learned distribution
  5. Reconstructed images - Tested the encoder-decoder pipeline

Key Concepts¤

Configuration System¤

config = VAEConfig(
    name="mnist_vae",            # Every config has a name
    encoder=encoder,             # Nested EncoderConfig (input_shape, latent_dim, hidden_dims, ...)
    decoder=decoder,             # Nested DecoderConfig (output_shape, latent_dim, hidden_dims, ...)
    encoder_type="dense",        # Encoder architecture family
    kl_weight=1.0,               # Model-specific hyperparameter
)

Artifex uses a unified configuration system based on frozen dataclasses for type-safe, validated configurations.
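
As a minimal sketch of what "frozen" means in practice (a generic Python illustration; ToyConfig is a stand-in, not an Artifex class): mutation raises an error, and updated variants are created as copies.

import dataclasses

@dataclasses.dataclass(frozen=True)
class ToyConfig:            # stand-in for an Artifex config (illustration only)
    latent_dim: int = 32

cfg = ToyConfig()
try:
    cfg.latent_dim = 64     # frozen: in-place mutation raises
except dataclasses.FrozenInstanceError:
    cfg = dataclasses.replace(cfg, latent_dim=64)  # create an updated copy instead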

Direct Model Instantiation¤

model = VAE(config, rngs=rngs)

Models are created directly from their configuration objects, providing full control and type safety.

RNG Management¤

rngs = nnx.Rngs(0)  # Random number generator

JAX requires explicit random number generators for reproducibility and functional purity.
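
A short sketch of how explicit RNG streams look in practice with Flax NNX (the params/dropout stream names are conventional, not required by Artifex):

import jax
from flax import nnx

rngs = nnx.Rngs(params=0, dropout=1)   # named streams, each with its own seed
key = rngs.dropout()                   # draw a fresh key from the 'dropout' stream
noise = jax.random.normal(key, (4,))   # use the key explicitly for any randomness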

Training Loop¤

The training loop follows standard JAX/Flax patterns with optimizations:

  1. Forward pass: model(batch) → outputs
  2. Compute loss: model.loss_fn() → loss values
  3. Backward pass: nnx.value_and_grad() → gradients
  4. Update weights: optimizer.update() → new parameters

Performance Tips:

  • Use @nnx.jit on train_step for 10-50x speedup
  • For large models, use @nnx.jit(donate_argnums=(1,)) to donate optimizer memory
  • Avoid Python control flow that depends on traced values inside JIT functions (use jax.lax.cond instead; see the sketch below)
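
To illustrate the last tip (plain JAX, not an Artifex API): a branch that depends on a traced value must go through jax.lax.cond, because a Python if on a tracer fails under jit.

import jax
import jax.numpy as jnp

@jax.jit
def scaled(x, use_double):
    # A Python `if use_double:` would fail here because `use_double` is traced.
    return jax.lax.cond(use_double, lambda v: v * 2.0, lambda v: v, x)

x = jnp.arange(4.0)
print(scaled(x, True))    # [0. 2. 4. 6.]
print(scaled(x, False))   # [0. 1. 2. 3.]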

Try Different Models¤

Train a Diffusion Model¤

import jax
import jax.numpy as jnp
import optax
from datarax import from_source
from datarax.core.config import ElementOperatorConfig
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator
from datarax.sources import TfdsDataSourceConfig, TFDSSource
from flax import nnx

from artifex.generative_models.models.diffusion import DDPMModel
from artifex.generative_models.core.configuration import (
    DDPMConfig,
    UNetBackboneConfig,
    NoiseScheduleConfig,
)
from artifex.generative_models.training.trainers import (
    DiffusionTrainer,
    DiffusionTrainingConfig,
)

# 1. Load Fashion-MNIST with datarax
def normalize(element, _key):
    """Normalize images to [-1, 1] for diffusion models."""
    image = element.data["image"].astype(jnp.float32) / 127.5 - 1.0
    return element.replace(data={**element.data, "image": image})

source = TFDSSource(
    TfdsDataSourceConfig(name="fashion_mnist", split="train", shuffle=True),
    rngs=nnx.Rngs(0),
)
normalize_op = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=normalize, rngs=nnx.Rngs(1)
)
pipeline = from_source(source, batch_size=64) >> OperatorNode(normalize_op)

# 2. Create DDPM configuration
backbone = UNetBackboneConfig(
    name="unet_backbone",
    in_channels=1,
    out_channels=1,
    hidden_dims=(32, 64, 128),
    channel_mult=(1, 2, 4),
    activation="silu",
)

noise_schedule = NoiseScheduleConfig(
    name="cosine_schedule",
    schedule_type="cosine",
    num_timesteps=1000,
    beta_start=1e-4,
    beta_end=2e-2,
)

config = DDPMConfig(
    name="fashion_ddpm",
    input_shape=(28, 28, 1),  # HWC format
    backbone=backbone,
    noise_schedule=noise_schedule,
)

# 3. Create model and optimizer
rngs = nnx.Rngs(42)
model = DDPMModel(config, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=nnx.Param)

# 4. Configure trainer with SOTA techniques (min-SNR weighting, EMA)
trainer = DiffusionTrainer(
    noise_schedule=model.noise_schedule,
    config=DiffusionTrainingConfig(
        loss_weighting="min_snr",  # Min-SNR weighting for faster convergence
        snr_gamma=5.0,
        ema_decay=0.9999,
    ),
)

# JIT-compile the train_step for performance
jit_train_step = nnx.jit(trainer.train_step)

# 5. Training loop
rng = jax.random.PRNGKey(0)
step = 0

for batch in pipeline:
    rng, step_rng = jax.random.split(rng)
    _, metrics = jit_train_step(model, optimizer, {"image": batch["image"]}, step_rng)
    trainer.update_ema(model)  # EMA updates outside JIT

    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}")
    step += 1

# 6. Generate samples
samples = model.sample(n_samples_or_shape=8, steps=100)
print(f"Generated samples shape: {samples.shape}")

Train a GAN¤

from flax import nnx
from artifex.generative_models.models.gan import DCGAN
from artifex.generative_models.core.configuration import (
    DCGANConfig,
    ConvGeneratorConfig,
    ConvDiscriminatorConfig,
)

# Create DCGAN configuration with convolutional networks
generator = ConvGeneratorConfig(
    name="dcgan_generator",
    latent_dim=100,
    hidden_dims=(512, 256, 128, 64),
    output_shape=(1, 28, 28),  # CHW format
    activation="relu",
    batch_norm=True,
    kernel_size=(4, 4),
    stride=(2, 2),
    padding="SAME",
)

discriminator = ConvDiscriminatorConfig(
    name="dcgan_discriminator",
    hidden_dims=(64, 128, 256, 512),
    input_shape=(1, 28, 28),  # CHW format
    activation="leaky_relu",
    leaky_relu_slope=0.2,
    batch_norm=True,
    kernel_size=(4, 4),
    stride=(2, 2),
    padding="SAME",
)

config = DCGANConfig(
    name="mnist_dcgan",
    generator=generator,
    discriminator=discriminator,
)

rngs = nnx.Rngs(params=0, dropout=1, sample=2)
model = DCGAN(config, rngs=rngs)
print(f"✓ DCGAN created with latent_dim={config.generator.latent_dim}")

Next Steps¤

Now that you have a working setup, explore more:

  • Core Concepts: understand generative modeling fundamentals and Artifex architecture
  • First Model Tutorial: a step-by-step guide to building a VAE from scratch with real data
  • Model Guides (VAE Guide, Model Implementations): deep dives into VAEs, GANs, Diffusion, Flows, and more
  • Examples: ready-to-run examples for various models and use cases

Common Next Questions¤

How do I use real data?¤

See the Data Pipeline Guide for loading CIFAR-10, ImageNet, and custom datasets.
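
As a quick sketch, assuming the same datarax TFDSSource pattern used for MNIST above also applies to other TFDS datasets, swapping in CIFAR-10 looks like this:

from flax import nnx
from datarax import from_source
from datarax.sources import TfdsDataSourceConfig, TFDSSource

# Assumes CIFAR-10 is available through TFDS under the name "cifar10"
source = TFDSSource(
    TfdsDataSourceConfig(name="cifar10", split="train", shuffle=True),
    rngs=nnx.Rngs(0),
)
pipeline = from_source(source, batch_size=64)  # add OperatorNode steps as in the MNIST example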

How do I save and load models?¤

# Save: checkpoint the model's state (flax.training.checkpoints expects a pytree)
from flax import nnx
from flax.training import checkpoints

checkpoints.save_checkpoint('checkpoints/', nnx.state(model), step=100)

# Load: restore the state and merge it back into the model
state = checkpoints.restore_checkpoint('checkpoints/', target=nnx.state(model))
nnx.update(model, state)

See Training Guide for details on checkpointing.

How do I train on multiple GPUs?¤

Artifex supports distributed training out of the box. See Distributed Training Guide.
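
Before launching a multi-GPU run, it can help to confirm what JAX actually sees (plain JAX, independent of Artifex):

import jax

# Lists the accelerators visible to this process; multi-GPU training needs more than one
print(f"Backend: {jax.default_backend()}, devices: {jax.device_count()}")
print(jax.devices())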

What if I get errors?¤

If you encounter issues, open an issue on GitHub.

Quick Reference¤

Model Types¤

Type      | Model Class      | Config Class           | Use Case
VAE       | VAE              | VAEConfig              | Latent representations, data compression
GAN       | DCGAN, GAN       | DCGANConfig, GANConfig | High-quality image generation
Diffusion | DDPMModel        | DDPMConfig             | State-of-the-art generation, controllable
Flow      | FlowModel        | FlowConfig             | Exact likelihood, invertible transformations
EBM       | EnergyBasedModel | EBMConfig              | Energy-based modeling, composable

Key Commands¤

# Install
uv sync --all-extras

# Run tests
pytest tests/ -v

# Format code
ruff format src/

# Type check
pyright src/

# Build docs
mkdocs serve

Congratulations! 🎉 You've completed the quickstart guide. You're now ready to build more sophisticated generative models with Artifex!

Next recommended step: Core Concepts to understand the architecture better, or First Model Tutorial to build a complete VAE with real data.