Core Concepts¤
This guide introduces the fundamental concepts behind Artifex and generative modeling. Understanding these concepts will help you make the most of the library.
What is Generative Modeling?¤
Generative modeling is about learning probability distributions from data and generating new samples from those distributions.
The Core Problem¤
Given a dataset \(\mathcal{D} = \{x_1, x_2, ..., x_N\}\), we want to:
- Learn the underlying data distribution \(p(x)\)
- Generate new samples \(\tilde{x} \sim p(x)\) that look like the training data
- Evaluate the quality of generated samples
Why Generative Models?¤
-
Image Generation
Create realistic images, art, faces, or any visual content
-
Data Augmentation
Generate synthetic training data to improve discriminative models
-
Representation Learning
Learn meaningful latent representations of data
-
Anomaly Detection
Identify outliers by measuring likelihood under the learned distribution
-
Content Creation
Generate text, audio, video, 3D shapes, and more
-
Scientific Discovery
Generate molecules, proteins, materials for drug design and research
Key Concepts¤
1. Probability Distribution¤
A probability distribution \(p(x)\) assigns probabilities to different outcomes:
- Discrete: \(p(x) \in [0, 1]\) for each \(x\), \(\sum_x p(x) = 1\)
- Continuous: \(p(x) \geq 0\), \(\int p(x) dx = 1\)
Goal: Learn \(p(x)\) from data so we can sample new \(x \sim p(x)\).
2. Likelihood¤
The likelihood \(p(x|\theta)\) measures how probable data \(x\) is under model parameters \(\theta\).
Maximum Likelihood Estimation (MLE):
Models with tractable likelihoods (e.g., Flows, Autoregressive) directly optimize this.
3. Latent Variables¤
Latent variables \(z\) are unobserved variables that capture underlying structure:
- \(p(z)\): Prior distribution (usually standard normal)
- \(p(x|z)\): Likelihood (decoder/generator)
- \(p(z|x)\): Posterior (encoder/inference network)
Examples:
- VAE: Continuous latent space
- VQ-VAE: Discrete latent codebook
- Diffusion: Noisy latent trajectory
4. Sampling¤
Generating new data from the learned distribution:
Ancestral Sampling: Sample from the prior \(z \sim p(z)\), then generate \(x \sim p(x|z)\)
MCMC Sampling: Use Markov chains to sample from complex distributions
Diffusion Sampling: Iteratively denoise from pure noise to data
5. Encoder-Decoder Architecture¤
graph LR
X[Data x] -->|Encode| Z[Latent z]
Z -->|Decode| XR[Reconstructed x']
style X fill:#e1f5ff
style Z fill:#fff4e1
style XR fill:#e8f5e9
- Encoder: Maps data to latent space \(q(z|x)\)
- Decoder: Maps latent to data space \(p(x|z)\)
- Latent Space: Compressed, structured representation
Generative Model Types¤
Artifex supports six main types of generative models, each with different trade-offs:
1. Variational Autoencoders (VAE)¤
Idea: Learn latent representations with probabilistic encoding/decoding
Key Equation: Evidence Lower Bound (ELBO)
Architecture:
graph TD
X[Input x] --> E[Encoder q]
E --> M[Mean μ]
E --> S[Std σ]
M --> R[Reparameterize]
S --> R
R --> Z[Latent z]
Z --> D[Decoder p]
D --> XR[Output x']
Pros:
- Stable training
- Interpretable latent space
- Fast sampling
Cons:
- Lower sample quality compared to GANs/Diffusion
- Posterior approximation may be limited
Use Cases: Data compression, latent space exploration, representation learning
Artifex Example:
from flax import nnx
from artifex.generative_models.models.vae import VAE
from artifex.generative_models.core.configuration import VAEConfig, EncoderConfig, DecoderConfig
encoder = EncoderConfig(name="encoder", input_shape=(28, 28, 1), latent_dim=128, hidden_dims=(256, 128), activation="relu")
decoder = DecoderConfig(name="decoder", latent_dim=128, output_shape=(28, 28, 1), hidden_dims=(128, 256), activation="relu")
config = VAEConfig(name="beta_vae", encoder=encoder, decoder=decoder, encoder_type="dense", kl_weight=1.0)
rngs = nnx.Rngs(0)
model = VAE(config, rngs=rngs)
2. Generative Adversarial Networks (GAN)¤
Idea: Train generator and discriminator in adversarial game
Key Equation: Minimax objective
Architecture:
graph LR
Z[Noise z] --> G[Generator G]
G --> XF[Fake x']
XR[Real x] --> D[Discriminator D]
XF --> D
D --> R[Real/Fake]
style G fill:#fff4e1
style D fill:#ffe1f5
Pros:
- Highest sample quality
- No explicit likelihood needed
- Mode coverage (with proper training)
Cons:
- Training instability (mode collapse, vanishing gradients)
- Hyperparameter sensitive
- No direct likelihood evaluation
Use Cases: High-quality image generation, style transfer, image-to-image translation
Artifex Example:
from flax import nnx
from artifex.generative_models.models.gan import DCGAN
from artifex.generative_models.core.configuration import (
DCGANConfig, ConvGeneratorConfig, ConvDiscriminatorConfig
)
generator = ConvGeneratorConfig(
name="generator", latent_dim=100, hidden_dims=(512, 256, 128, 64),
output_shape=(1, 28, 28), kernel_size=(4, 4), stride=(2, 2), padding="SAME",
activation="relu"
)
discriminator = ConvDiscriminatorConfig(
name="discriminator", hidden_dims=(64, 128, 256, 512),
input_shape=(1, 28, 28), kernel_size=(4, 4), stride=(2, 2), padding="SAME",
activation="leaky_relu"
)
config = DCGANConfig(name="dcgan", generator=generator, discriminator=discriminator)
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
model = DCGAN(config, rngs=rngs)
3. Diffusion Models¤
Idea: Learn to denoise data through iterative refinement
Forward Process: Add noise gradually
Reverse Process: Learn to denoise
Architecture:
graph LR
X0[Clean x₀] -->|Add noise| X1[Noisy x₁]
X1 -->|Add noise| X2[Noisy x₂]
X2 -->|...| XT[Pure noise xₜ]
XT -->|Denoise| X2R[x₂]
X2R -->|Denoise| X1R[x₁]
X1R -->|Denoise| X0R[Clean x₀]
style X0 fill:#e8f5e9
style XT fill:#ffebee
style X0R fill:#e8f5e9
Pros:
- State-of-the-art sample quality
- Stable training
- Flexible conditioning and guidance
Cons:
- Slow sampling (many steps)
- Memory intensive
- Training time
Use Cases: Image/audio generation, inpainting, super-resolution, conditional generation
Artifex Example:
from flax import nnx
from artifex.generative_models.models.diffusion import DDPMModel
from artifex.generative_models.core.configuration import DDPMConfig, UNetBackboneConfig, NoiseScheduleConfig
backbone = UNetBackboneConfig(name="backbone", in_channels=1, out_channels=1, hidden_dims=(64, 128), channel_mult=(1, 2), activation="silu")
noise_schedule = NoiseScheduleConfig(name="schedule", schedule_type="cosine", num_timesteps=1000)
config = DDPMConfig(name="ddpm", input_shape=(28, 28, 1), backbone=backbone, noise_schedule=noise_schedule)
rngs = nnx.Rngs(0)
model = DDPMModel(config, rngs=rngs)
4. Normalizing Flows¤
Idea: Use invertible transformations with tractable Jacobians
Key Property: Change of variables
Architecture:
graph LR
Z[Simple z] -->|f| X[Complex x]
X -->|f⁻¹| Z2[Simple z]
style Z fill:#e8f5e9
style X fill:#e1f5ff
Pros:
- Exact likelihood computation
- Invertible (can go both ways)
- Stable training
Cons:
- Architecture constraints (must be invertible)
- May require many layers for expressiveness
Use Cases: Density estimation, exact inference, anomaly detection
Artifex Example:
from flax import nnx
from artifex.generative_models.models.flow import RealNVP
from artifex.generative_models.core.configuration import RealNVPConfig, CouplingNetworkConfig
coupling = CouplingNetworkConfig(name="coupling", hidden_dims=(256, 256), activation="relu")
config = RealNVPConfig(name="realnvp", input_dim=784, num_coupling_layers=8, coupling_network=coupling)
rngs = nnx.Rngs(0)
model = RealNVP(config, rngs=rngs)
5. Energy-Based Models (EBM)¤
Idea: Learn energy function, sample with MCMC
Key Equation: Gibbs distribution
where
is the partition function.
Pros:
- Flexible, can model any distribution
- Composable (combine multiple EBMs)
Cons:
- Expensive sampling (MCMC)
- Training complexity (contrastive divergence)
Use Cases: Compositional generation, constraint satisfaction, hybrid models
Artifex Example:
from flax import nnx
from artifex.generative_models.models.energy import EBM
from artifex.generative_models.core.configuration import EBMConfig, EnergyNetworkConfig, MCMCConfig
from artifex.generative_models.core.configuration.energy_config import SampleBufferConfig
energy_network = EnergyNetworkConfig(name="energy_net", hidden_dims=(256, 256), activation="swish")
mcmc = MCMCConfig(name="mcmc", n_steps=60, step_size=0.01)
sample_buffer = SampleBufferConfig(name="buffer", capacity=10000)
config = EBMConfig(name="ebm", input_dim=784, energy_network=energy_network, mcmc=mcmc, sample_buffer=sample_buffer)
rngs = nnx.Rngs(0)
model = EBM(config, rngs=rngs)
6. Autoregressive Models¤
Idea: Model distribution as product of conditionals
Architecture:
graph LR
X1[x₁] -->|p| X2[x₂]
X2 -->|p| X3[x₃]
X3 -->|...| XN[xₙ]
Pros:
- Explicit likelihood
- Flexible architectures (Transformers, CNNs)
- Strong theoretical foundation
Cons:
- Sequential generation (slow)
- Fixed ordering required
Use Cases: Text generation, ordered sequences, explicit probability modeling
Artifex Example:
from flax import nnx
from artifex.generative_models.models.autoregressive import PixelCNN
from artifex.generative_models.core.configuration import PixelCNNConfig
config = PixelCNNConfig(name="pixelcnn", image_shape=(28, 28, 1), hidden_channels=64, num_layers=8)
rngs = nnx.Rngs(0)
model = PixelCNN(config, rngs=rngs)
Model Comparison Matrix¤
| Feature | VAE | GAN | Diffusion | Flow | EBM | Autoregressive |
|---|---|---|---|---|---|---|
| Sample Quality | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
| Training Stability | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
| Sampling Speed | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐ |
| Exact Likelihood | ❌ | ❌ | ❌ | ✅ | ❌* | ✅ |
| Latent Space | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
| Mode Coverage | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
*EBM has exact likelihood but intractable partition function
Artifex Architecture¤
High-Level Design¤
Artifex follows a modular, protocol-based design:
graph TB
subgraph User["User Interface"]
Config[Configuration Classes]
end
subgraph Core["Core Components"]
Protocols[Protocols & Interfaces]
Device[Device Manager]
Loss[Loss Functions]
end
subgraph Models["Generative Models"]
VAE[VAE]
GAN[GAN]
Diff[Diffusion]
Flow[Flow]
EBM[EBM]
AR[Autoregressive]
end
subgraph Training["Training System"]
Trainer[Trainers]
Opt[Optimizers]
Callbacks[Callbacks]
end
Config --> Models
Models --> Training
Core --> Models
Core --> Training
Key Design Principles¤
- Protocol-Based: Type-safe interfaces using Python Protocols
- Configuration-Driven: Pydantic-based unified configuration system
- Factory Pattern: Centralized model creation
- Hardware-Aware: Automatic GPU/CPU/TPU detection and optimization
- Modular: Composable components for flexibility
Configuration System¤
All models use a unified configuration class:
from artifex.generative_models.core.configuration import ModelConfig
config = ModelConfig(
# Required fields
model_type="vae", # Type of model
input_shape=(28, 28, 1), # Input dimensions
# Architecture
latent_dim=128, # Latent space size
encoder_features=[64, 128], # Encoder layer sizes
decoder_features=[128, 64], # Decoder layer sizes
# Model-specific parameters
parameters={
"beta": 1.0, # VAE-specific: β-VAE weight
"kl_weight": 1.0, # KL divergence weight
"reconstruction_loss": "mse" # Reconstruction loss type
},
# Optional metadata
metadata={
"experiment_id": "vae_001",
"dataset": "mnist"
}
)
Benefits:
- Type-safe with Pydantic validation
- Serializable (save/load configurations)
- Versioned for reproducibility
- Extensible for custom models
Device Management¤
Artifex automatically handles GPU/CPU/TPU:
from artifex.generative_models.core.device_manager import get_device_manager
import jax
# Automatic device detection
manager = get_device_manager()
info = manager.get_device_info()
print(f"Using: {info['backend']}") # gpu, cpu, or tpu
print(f"Device count: {len(jax.devices())}")
# Explicit configuration
from artifex.generative_models.core.device_manager import configure_for_generative_models, MemoryStrategy
manager = configure_for_generative_models(
memory_strategy=MemoryStrategy.BALANCED, # 75% GPU memory
enable_mixed_precision=True # BF16/FP16
)
Protocol System¤
Models implement the GenerativeModel protocol:
from typing import Protocol
from flax import nnx
import jax
class GenerativeModel(Protocol):
"""Base protocol for all generative models."""
def __call__(self, x: jax.Array, *, rngs: nnx.Rngs = None) -> dict:
"""Forward pass."""
...
def generate(self, *, n_samples: int, rngs: nnx.Rngs = None) -> jax.Array:
"""Generate samples from the model."""
...
def loss_fn(self, batch: dict, outputs: dict) -> dict:
"""Compute loss."""
...
This enables:
- Type checking at development time
- Generic training loops that work with any model
- Consistent interfaces across model types
JAX and Flax NNX Basics¤
Artifex is built on JAX and Flax NNX. Here are the key concepts:
JAX: Functional Programming¤
JAX provides functional transformations:
import jax
import jax.numpy as jnp
# JIT compilation for speed
@jax.jit
def fast_function(x):
return jnp.sum(x ** 2)
# Automatic differentiation
def loss_fn(params, x):
return jnp.sum((params['w'] * x) ** 2)
grad_fn = jax.grad(loss_fn) # Get gradient function
gradients = grad_fn(params, x) # Compute gradients
# Vectorization
batch_fn = jax.vmap(fast_function) # Apply to batches
Flax NNX: Object-Oriented Neural Networks¤
Flax NNX provides a Pythonic API for neural networks:
from flax import nnx
class MyModel(nnx.Module):
def __init__(self, features: int, *, rngs: nnx.Rngs):
super().__init__()
self.dense1 = nnx.Linear(784, features, rngs=rngs)
self.dense2 = nnx.Linear(features, 10, rngs=rngs)
def __call__(self, x: jax.Array) -> jax.Array:
x = self.dense1(x)
x = nnx.relu(x)
x = self.dense2(x)
return x
# Create model
rngs = nnx.Rngs(0)
model = MyModel(features=128, rngs=rngs)
# Use model
x = jax.random.normal(jax.random.PRNGKey(0), (32, 784))
y = model(x)
Random Number Generation¤
JAX requires explicit RNG management:
from flax import nnx
# Create RNG
rngs = nnx.Rngs(seed=42)
# Use for model initialization
model = MyModel(rngs=rngs)
# Use for sampling
samples = jax.random.normal(rngs.sample(), (10, 784))
Multi-Modal Support¤
Artifex natively supports multiple data modalities:
Image¤
from artifex.data.datasets.image import CIFAR10Dataset
dataset = CIFAR10Dataset(root='./data', train=True)
Text¤
from artifex.data.datasets.text import WikipediaDataset
dataset = WikipediaDataset(
tokenizer='bpe',
max_length=512
)
Audio¤
from artifex.data.datasets.audio import LibriSpeechDataset
dataset = LibriSpeechDataset(
root='./data',
sample_rate=16000
)
Protein¤
from artifex.data.protein_dataset import ProteinDataset
dataset = ProteinDataset(
pdb_dir='./data/pdb',
with_constraints=True
)
Next Steps¤
Now that you understand the core concepts:
-
Build Your First Model
Hands-on tutorial to build a complete VAE from scratch
-
Explore Model Guides
Deep dives into each model type with examples
-
Check API Reference
Complete API documentation for all components
-
Learn Training
Training workflows, optimization, and distributed training
Further Reading¤
Generative Models¤
- VAE: Kingma & Welling (2013) - Auto-Encoding Variational Bayes
- GAN: Goodfellow et al. (2014) - Generative Adversarial Networks
- Diffusion: Ho et al. (2020) - Denoising Diffusion Probabilistic Models
- Flow: Dinh et al. (2016) - Density Estimation using Real NVP
JAX and Flax¤
Last Updated: 2025-10-13