# Quickstart Guide
Get started with Artifex in 5 minutes! This guide walks you through installing Artifex and training your first generative model.
## Prerequisites
- Python 3.10 or higher
- 8GB RAM (16GB recommended)
- Optional: NVIDIA GPU with CUDA 12.0+ for faster training
## Step 1: Install Artifex
Choose your preferred installation method:
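For example, with uv from a clone of the repository (the same command appears under Key Commands below); the `artifex` PyPI package name is an assumption:

```bash
# From a repository clone (see Key Commands below)
uv sync --all-extras

# Or, assuming a published package exists on PyPI:
pip install artifex
```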
Verify installation:

```bash
python -c "import jax; print(f'JAX backend: {jax.default_backend()}')"
# Should print: JAX backend: gpu (or cpu)
```
## Step 2: Train Your First VAE
Create a new Python file `train_vae.py`:

```python
import jax
import jax.numpy as jnp
import optax
from datarax import from_source
from datarax.core.config import ElementOperatorConfig
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator
from datarax.sources import TfdsDataSourceConfig, TFDSSource
from flax import nnx

from artifex.generative_models.models.vae import VAE
from artifex.generative_models.core.configuration import (
    VAEConfig,
    EncoderConfig,
    DecoderConfig,
)


# 1. Load MNIST with datarax
def normalize(element, _key):
    """Normalize images from [0, 255] to [0, 1]."""
    image = element.data["image"].astype(jnp.float32) / 255.0
    return element.replace(data={**element.data, "image": image})


source = TFDSSource(
    TfdsDataSourceConfig(name="mnist", split="train", shuffle=True),
    rngs=nnx.Rngs(0),
)
normalize_op = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=normalize, rngs=nnx.Rngs(1)
)
pipeline = from_source(source, batch_size=32) >> OperatorNode(normalize_op)

# 2. Configure the model with nested configs
encoder = EncoderConfig(
    name="mnist_encoder",
    input_shape=(28, 28, 1),
    latent_dim=32,
    hidden_dims=(64, 128),
    activation="relu",
)
decoder = DecoderConfig(
    name="mnist_decoder",
    latent_dim=32,
    output_shape=(28, 28, 1),
    hidden_dims=(128, 64),
    activation="relu",
)
config = VAEConfig(
    name="mnist_vae",
    encoder=encoder,
    decoder=decoder,
    encoder_type="dense",
    kl_weight=1.0,
)

# 3. Create model and optimizer
rngs = nnx.Rngs(0)
model = VAE(config, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

print("✓ Model created successfully!")
print(f"  Latent dimension: {model.latent_dim}")
state_leaves = jax.tree.leaves(nnx.state(model))
param_count = sum(p.size for p in state_leaves if hasattr(p, 'size'))
print(f"  Parameters: ~{param_count/1e6:.2f}M")


# 4. Training step (JIT-compiled for a 10-50x speedup)
@nnx.jit  # Compiles to XLA for GPU/TPU acceleration
def train_step(model, optimizer, batch):
    """Single training step with automatic differentiation.

    JIT compilation provides significant speedups by:
    - Fusing operations to reduce memory transfers
    - Optimizing computation graphs for target hardware
    - Enabling XLA optimizations (constant folding, etc.)
    """

    def loss_fn(model):
        outputs = model(batch)
        loss_dict = model.loss_fn(x=batch, outputs=outputs)
        return loss_dict['loss'], loss_dict

    (loss, loss_dict), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    return loss, loss_dict


# 5. Train for one epoch
print("\nTraining for one epoch...")
step = 0
for batch in pipeline:
    images = batch["image"]
    # The JIT-compiled train_step runs ~10-50x faster than eager execution
    loss, loss_dict = train_step(model, optimizer, images)
    if step % 500 == 0:
        print(f"Step {step:5d} | Loss: {loss:.4f} | "
              f"Recon: {loss_dict['reconstruction_loss']:.4f} | "
              f"KL: {loss_dict['kl_loss']:.4f}")
    step += 1
print("\n✓ Training complete!")

# 6. Generate samples
print("\nGenerating samples...")
samples = model.sample(n_samples=8)
print(f"✓ Generated {samples.shape[0]} samples with shape {samples.shape[1:]}")

# 7. Reconstruct an image
test_image = images[:1]  # First image of the last batch
reconstructed = model.reconstruct(test_image, deterministic=True)
print(f"✓ Reconstructed image with shape {reconstructed.shape}")

print("\n🎉 Success! You've trained your first VAE with Artifex!")
```
Run the script:
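```bash
python train_vae.py
```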
Expected output:

```text
✓ Model created successfully!
  Latent dimension: 32
  Parameters: ~0.18M

Training for one epoch...
Step     0 | Loss: 13.2877 | Recon: 0.2709 | KL: 13.0168
Step   500 | Loss: 0.3743 | Recon: 0.0815 | KL: 0.2928
Step  1000 | Loss: 0.1551 | Recon: 0.0623 | KL: 0.0928
...

✓ Training complete!

Generating samples...
✓ Generated 8 samples with shape (28, 28, 1)
✓ Reconstructed image with shape (1, 28, 28, 1)

🎉 Success! You've trained your first VAE with Artifex!
```
## Step 3: Visualize Results (Optional)
Add visualization to your script:

```python
import matplotlib.pyplot as plt

# Visualize generated samples
fig, axes = plt.subplots(2, 4, figsize=(10, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Samples from VAE')
plt.tight_layout()
plt.savefig('vae_samples.png')
print("✓ Saved samples to vae_samples.png")

# Visualize reconstruction
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(test_image[0].squeeze(), cmap='gray')
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(reconstructed[0].squeeze(), cmap='gray')
axes[1].set_title('Reconstructed')
axes[1].axis('off')
plt.tight_layout()
plt.savefig('vae_reconstruction.png')
print("✓ Saved reconstruction to vae_reconstruction.png")
```
**Generated VAE Samples:**

![Generated VAE samples](vae_samples.png)

**Original vs Reconstructed:**

![Original vs reconstructed](vae_reconstruction.png)
## What You Just Did
In just a few minutes, you:
- ✅ **Installed Artifex**: Set up the complete environment
- ✅ **Created a VAE**: Built a variational autoencoder from nested configs
- ✅ **Trained the model**: Ran one epoch of training with loss monitoring
- ✅ **Generated samples**: Created new images from the learned distribution
- ✅ **Reconstructed images**: Tested the encoder-decoder pipeline
## Key Concepts
### Configuration System
```python
config = VAEConfig(
    name="mnist_vae",      # Unique model name
    encoder=encoder,       # Nested EncoderConfig (input shape, hidden dims, ...)
    decoder=decoder,       # Nested DecoderConfig (output shape, hidden dims, ...)
    encoder_type="dense",  # Encoder architecture family
    kl_weight=1.0,         # Weight of the KL term in the loss
)
```
Artifex uses a unified configuration system based on frozen dataclasses for type-safe, validated configurations.
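Because the configs are frozen, you derive variants by copying rather than mutating in place. A minimal sketch assuming standard `dataclasses` semantics:

```python
import dataclasses

# Frozen dataclasses raise FrozenInstanceError on mutation;
# create a modified copy instead
light_kl_config = dataclasses.replace(config, kl_weight=0.5)
```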
### Direct Model Instantiation
Models are created directly from their configuration objects, providing full control and type safety.
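As in the script above, the config and RNGs go straight to the model class:

```python
from flax import nnx

# Instantiate the model directly from its configuration object
model = VAE(config, rngs=nnx.Rngs(0))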
### RNG Management
JAX requires explicit random number generators for reproducibility and functional purity.
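`nnx.Rngs` bundles named streams so each source of randomness is seeded explicitly; the GAN example below uses the same pattern:

```python
from flax import nnx

# Separate, explicitly seeded streams for init, dropout, and sampling
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
```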
### Training Loop
The training loop follows standard JAX/Flax patterns with optimizations:

- **Forward pass**: `model(batch)` → outputs
- **Compute loss**: `model.loss_fn()` → loss values
- **Backward pass**: `nnx.value_and_grad()` → gradients
- **Update weights**: `optimizer.update()` → new parameters

**Performance tips:**

- Use `@nnx.jit` on `train_step` for a 10-50x speedup
- For large models, use `@nnx.jit(donate_argnums=(1,))` to donate optimizer memory
- Avoid Python control flow inside JIT-compiled functions (use `jax.lax.cond` instead; see the sketch below)
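Inside a JIT-compiled function, array values are traced, so a Python `if` on them raises an error; `jax.lax.cond` builds the branch into the compiled graph instead. A generic sketch (not Artifex-specific):

```python
import jax
import jax.numpy as jnp


@jax.jit
def clip_loss(loss):
    # A Python `if loss > 100.0:` would fail on a traced value;
    # jax.lax.cond selects a branch inside the compiled graph instead
    return jax.lax.cond(
        loss > 100.0,
        lambda x: jnp.asarray(100.0, dtype=x.dtype),
        lambda x: x,
        loss,
    )
```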
## Try Different Models
### Train a Diffusion Model
```python
import jax
import jax.numpy as jnp
import optax
from datarax import from_source
from datarax.core.config import ElementOperatorConfig
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator
from datarax.sources import TfdsDataSourceConfig, TFDSSource
from flax import nnx

from artifex.generative_models.models.diffusion import DDPMModel
from artifex.generative_models.core.configuration import (
    DDPMConfig,
    UNetBackboneConfig,
    NoiseScheduleConfig,
)
from artifex.generative_models.training.trainers import (
    DiffusionTrainer,
    DiffusionTrainingConfig,
)


# 1. Load Fashion-MNIST with datarax
def normalize(element, _key):
    """Normalize images to [-1, 1] for diffusion models."""
    image = element.data["image"].astype(jnp.float32) / 127.5 - 1.0
    return element.replace(data={**element.data, "image": image})


source = TFDSSource(
    TfdsDataSourceConfig(name="fashion_mnist", split="train", shuffle=True),
    rngs=nnx.Rngs(0),
)
normalize_op = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=normalize, rngs=nnx.Rngs(1)
)
pipeline = from_source(source, batch_size=64) >> OperatorNode(normalize_op)

# 2. Create DDPM configuration
backbone = UNetBackboneConfig(
    name="unet_backbone",
    in_channels=1,
    out_channels=1,
    hidden_dims=(32, 64, 128),
    channel_mult=(1, 2, 4),
    activation="silu",
)
noise_schedule = NoiseScheduleConfig(
    name="cosine_schedule",
    schedule_type="cosine",
    num_timesteps=1000,
    beta_start=1e-4,
    beta_end=2e-2,
)
config = DDPMConfig(
    name="fashion_ddpm",
    input_shape=(28, 28, 1),  # HWC format
    backbone=backbone,
    noise_schedule=noise_schedule,
)

# 3. Create model and optimizer
rngs = nnx.Rngs(42)
model = DDPMModel(config, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=nnx.Param)

# 4. Configure trainer with SOTA techniques (min-SNR weighting, EMA)
trainer = DiffusionTrainer(
    noise_schedule=model.noise_schedule,
    config=DiffusionTrainingConfig(
        loss_weighting="min_snr",  # Min-SNR weighting for faster convergence
        snr_gamma=5.0,
        ema_decay=0.9999,
    ),
)

# JIT-compile the train_step for performance
jit_train_step = nnx.jit(trainer.train_step)

# 5. Training loop
rng = jax.random.PRNGKey(0)
step = 0
for batch in pipeline:
    rng, step_rng = jax.random.split(rng)
    _, metrics = jit_train_step(model, optimizer, {"image": batch["image"]}, step_rng)
    trainer.update_ema(model)  # EMA updates outside JIT
    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}")
    step += 1

# 6. Generate samples
samples = model.sample(n_samples_or_shape=8, steps=100)
print(f"Generated samples shape: {samples.shape}")
```
### Train a GAN
```python
from flax import nnx

from artifex.generative_models.models.gan import DCGAN
from artifex.generative_models.core.configuration import (
    DCGANConfig,
    ConvGeneratorConfig,
    ConvDiscriminatorConfig,
)

# Create DCGAN configuration with convolutional networks
generator = ConvGeneratorConfig(
    name="dcgan_generator",
    latent_dim=100,
    hidden_dims=(512, 256, 128, 64),
    output_shape=(1, 28, 28),  # CHW format
    activation="relu",
    batch_norm=True,
    kernel_size=(4, 4),
    stride=(2, 2),
    padding="SAME",
)
discriminator = ConvDiscriminatorConfig(
    name="dcgan_discriminator",
    hidden_dims=(64, 128, 256, 512),
    input_shape=(1, 28, 28),  # CHW format
    activation="leaky_relu",
    leaky_relu_slope=0.2,
    batch_norm=True,
    kernel_size=(4, 4),
    stride=(2, 2),
    padding="SAME",
)
config = DCGANConfig(
    name="mnist_dcgan",
    generator=generator,
    discriminator=discriminator,
)

rngs = nnx.Rngs(params=0, dropout=1, sample=2)
model = DCGAN(config, rngs=rngs)
print(f"✓ DCGAN created with latent_dim={config.generator.latent_dim}")
```
## Next Steps
Now that you have a working setup, explore more:
- **Learn Core Concepts**: Understand generative modeling fundamentals and Artifex architecture
- **Build Your First Model**: Step-by-step tutorial to build a VAE from scratch with real data
- **Explore Model Guides**: Deep dive into VAEs, GANs, Diffusion, Flows, and more
- **Check Examples**: Ready-to-run examples for various models and use cases
## Common Next Questions
### How do I use real data?
See the Data Pipeline Guide for loading CIFAR-10, ImageNet, and custom datasets.
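As a preview, loading CIFAR-10 follows the same datarax pattern as the MNIST pipeline in Step 2 (the `cifar10` name follows TFDS conventions; normalization and batch size here are placeholders):

```python
from datarax import from_source
from datarax.sources import TfdsDataSourceConfig, TFDSSource
from flax import nnx

# Same pattern as the MNIST pipeline in Step 2, pointed at CIFAR-10
source = TFDSSource(
    TfdsDataSourceConfig(name="cifar10", split="train", shuffle=True),
    rngs=nnx.Rngs(0),
)
pipeline = from_source(source, batch_size=32)
```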
### How do I save and load models?
The legacy `flax.training.checkpoints` API targets Linen models; for NNX models like those in this guide, the Flax documentation recommends checkpointing the model state with Orbax. A minimal sketch (the checkpoint directory and the model rebuild step are illustrative):

```python
import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/checkpoints')

# Save: split the model into graph definition and state, then save the state
graphdef, state = nnx.split(model)
checkpointer.save(ckpt_dir / 'state', state)

# Load: restore into an abstract state of the same structure, then merge
abstract_model = nnx.eval_shape(lambda: VAE(config, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
state = checkpointer.restore(ckpt_dir / 'state', abstract_state)
model = nnx.merge(graphdef, state)
```
See Training Guide for details on checkpointing.
### How do I train on multiple GPUs?
Artifex supports distributed training out of the box. See Distributed Training Guide.
### What if I get errors?
If you encounter issues, open an issue on GitHub.
## Quick Reference
### Model Types
| Type | Model Class | Config Class | Use Case |
|---|---|---|---|
| VAE | `VAE` | `VAEConfig` | Latent representations, data compression |
| GAN | `DCGAN`, `GAN` | `DCGANConfig`, `GANConfig` | High-quality image generation |
| Diffusion | `DDPMModel` | `DDPMConfig` | State-of-the-art generation, controllable |
| Flow | `FlowModel` | `FlowConfig` | Exact likelihood, invertible transformations |
| EBM | `EnergyBasedModel` | `EBMConfig` | Energy-based modeling, composable |
### Key Commands
```bash
# Install
uv sync --all-extras

# Run tests
pytest tests/ -v

# Format code
ruff format src/

# Type check
pyright src/

# Build docs
mkdocs serve
```
## Getting Help
- **Documentation**: Comprehensive guides and API reference
- **Examples**: Ready-to-run code in `examples/`
- **Issues**: GitHub Issues
- **Discussions**: GitHub Discussions
Congratulations! 🎉 You've completed the quickstart guide. You're now ready to build more sophisticated generative models with Artifex!
Next recommended step: Core Concepts to understand the architecture better, or First Model Tutorial to build a complete VAE with real data.