Model Factory
The factory module provides a centralized, type-safe interface for creating generative models in Artifex. It uses dataclass-based configurations to determine the model type automatically, which eliminates the need for string-based model class specifications.
Overview
- Unified Interface: a single create_model() function for all model types
- Type-Safe: dataclass configs with automatic validation
- Extensible: register custom builders for new model types
- Modality Support: optional modality adaptation for domain-specific models
Quick Start
Basic Model Creation
from artifex.generative_models.factory import create_model
from artifex.generative_models.core.configuration import (
VAEConfig,
EncoderConfig,
DecoderConfig,
)
from flax import nnx
# Create configuration
encoder = EncoderConfig(
name="encoder",
input_shape=(28, 28, 1),
latent_dim=32,
hidden_dims=(256, 128),
activation="relu",
)
decoder = DecoderConfig(
name="decoder",
output_shape=(28, 28, 1),
latent_dim=32,
hidden_dims=(128, 256),
activation="relu",
)
config = VAEConfig(
name="my_vae",
encoder=encoder,
decoder=decoder,
kl_weight=1.0,
)
# Create model - type is inferred from config
rngs = nnx.Rngs(params=42, dropout=43, sample=44)
model = create_model(config, rngs=rngs)
Model Type Inference
The factory automatically infers model type from the configuration class:
| Config Class | Model Type | Created Model |
|---|---|---|
| VAEConfig, BetaVAEConfig, ConditionalVAEConfig, VQVAEConfig | vae | VAE variants |
| GANConfig, DCGANConfig, WGANConfig, LSGANConfig | gan | GAN variants |
| DiffusionConfig, DDPMConfig, ScoreDiffusionConfig | diffusion | Diffusion models |
| EBMConfig, DeepEBMConfig | ebm | Energy-based models |
| FlowConfig | flow | Normalizing flows |
| AutoregressiveConfig, TransformerConfig, PixelCNNConfig, WaveNetConfig | autoregressive | Autoregressive models |
| GeometricConfig, PointCloudConfig, MeshConfig, VoxelConfig, GraphConfig | geometric | Geometric models |
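Conceptually, inference is a lookup from the config's class to a model-type string. The sketch below is a minimal stdlib-only illustration under stated assumptions. The config classes are simplified hypothetical stand-ins. CONFIG_TO_MODEL_TYPE and infer_model_type are illustrative names, not Artifex's API. The error messages mirror the ones shown under Error Handling.

```python
from dataclasses import dataclass, is_dataclass


# Hypothetical, simplified stand-ins for Artifex config classes.
@dataclass
class VAEConfig:
    name: str


@dataclass
class BetaVAEConfig(VAEConfig):
    beta: float = 4.0


@dataclass
class FlowConfig:
    name: str


# Illustrative lookup table mirroring a few rows of the table above.
CONFIG_TO_MODEL_TYPE: dict[type, str] = {
    VAEConfig: "vae",
    BetaVAEConfig: "vae",
    FlowConfig: "flow",
}


def infer_model_type(config: object) -> str:
    """Map a dataclass config instance to its model-type string."""
    # Reject dicts (and dataclass classes themselves): only instances are accepted.
    if not is_dataclass(config) or isinstance(config, type):
        raise TypeError(f"Expected dataclass config, got {type(config).__name__}.")
    # Exact-class lookup: every supported config class must be listed explicitly.
    model_type = CONFIG_TO_MODEL_TYPE.get(type(config))
    if model_type is None:
        raise TypeError(f"Unknown config type: {type(config).__name__}")
    return model_type


print(infer_model_type(BetaVAEConfig(name="beta_vae")))  # vae
```

Exact-class lookup keeps the mapping explicit, which matches the table above, where every variant appears by name.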
API Reference
create_model
The main function for model creation.
def create_model(
    config: DataclassConfig,
    *,
    modality: str | None = None,
    rngs: nnx.Rngs,
    **kwargs,
) -> Any:
    """Create a model from configuration.
    Args:
        config: Dataclass model configuration (DDPMConfig, VAEConfig, etc.)
        modality: Optional modality for adaptation ('image', 'text', 'audio', etc.)
        rngs: Random number generators
        **kwargs: Additional arguments passed to the builder
    Returns:
        Created model instance
    Raises:
        TypeError: If config is not a supported dataclass config
        ValueError: If no builder is found for the model type
    """
Example:
from flax import nnx
from artifex.generative_models.factory import create_model
from artifex.generative_models.core.configuration import (
DDPMConfig,
UNetBackboneConfig,
NoiseScheduleConfig,
)
# Create diffusion model config
backbone = UNetBackboneConfig(
name="unet",
in_channels=3,
out_channels=3,
base_channels=64,
channel_mults=(1, 2, 4),
)
noise_schedule = NoiseScheduleConfig(
name="schedule",
schedule_type="linear",
num_timesteps=1000,
beta_start=1e-4,
beta_end=2e-2,
)
config = DDPMConfig(
name="ddpm",
input_shape=(3, 32, 32),
backbone=backbone,
noise_schedule=noise_schedule,
)
# Create model (same RNG streams as in the Quick Start)
rngs = nnx.Rngs(params=42, dropout=43, sample=44)
model = create_model(config, rngs=rngs)
create_model_with_extensions
Create a model with extensions for enhanced functionality.
def create_model_with_extensions(
    config: DataclassConfig,
    *,
    extensions_config: dict[str, ExtensionConfig] | None = None,
    modality: str | None = None,
    rngs: nnx.Rngs,
    **kwargs,
) -> tuple[Any, dict[str, ModelExtension]]:
    """Create a model and its extensions from configuration.
    Returns:
        Tuple of (model, extensions_dict)
    """
Example:
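from artifex.generative_models.factory import create_model_with_extensions
# Assumes config, rngs, augmentation_config, and reg_config were created earlier
# (augmentation_config and reg_config are ExtensionConfig instances)
# Create model with extensions
model, extensions = create_model_with_extensions(
config,
extensions_config={
"augmentation": augmentation_config,
"regularization": reg_config,
},
rngs=rngs,
)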
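The return value is a plain pair, so callers unpack it directly. The sketch below illustrates only that shape. It builds one extension per entry of extensions_config, keyed by the same names, and returns an empty dict when no extensions are requested. ExtensionConfigSketch, ModelExtensionSketch, and build_with_extensions are hypothetical. Keying the result by the input names is an assumption, not documented behavior.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ExtensionConfigSketch:
    """Hypothetical stand-in for ExtensionConfig."""
    name: str
    weight: float = 1.0


class ModelExtensionSketch:
    """Hypothetical stand-in for ModelExtension."""

    def __init__(self, config: ExtensionConfigSketch) -> None:
        self.config = config


def build_with_extensions(
    build_model: Callable[[], Any],
    extensions_config: Optional[dict[str, ExtensionConfigSketch]] = None,
) -> tuple[Any, dict[str, ModelExtensionSketch]]:
    model = build_model()
    # None means "no extensions": return an empty dict rather than None.
    extensions = {
        key: ModelExtensionSketch(cfg)
        for key, cfg in (extensions_config or {}).items()
    }
    return model, extensions


model, extensions = build_with_extensions(
    lambda: "model",
    {"regularization": ExtensionConfigSketch(name="l2", weight=0.01)},
)
print(sorted(extensions))  # ['regularization']
```

Returning an empty dict instead of None lets callers iterate over the extensions without a separate None check.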
ModelFactory
The underlying factory class for advanced usage.
class ModelFactory:
    """Centralized factory for all generative models."""
    def __init__(self):
        """Initialize with default builders."""
    def create(
        self,
        config: DataclassConfig,
        *,
        modality: str | None = None,
        rngs: nnx.Rngs,
        **kwargs,
    ) -> Any:
        """Create a model from dataclass configuration."""
Builders
Each model family has a dedicated builder that handles model instantiation:
VAE Builder
Creates VAE variants based on configuration type:
- VAEConfig → VAE
- BetaVAEConfig → BetaVAE
- ConditionalVAEConfig → ConditionalVAE
- VQVAEConfig → VQVAE
GAN Builder
Creates GAN variants:
- GANConfig → GAN
- DCGANConfig → DCGAN
- WGANConfig → WGAN
- LSGANConfig → LSGAN
Diffusion Builder
Creates diffusion models:
- DDPMConfig → DDPMModel
- ScoreDiffusionConfig → ScoreDiffusionModel
- DiffusionConfig → DiffusionModel
Flow Builder
Creates normalizing flows:
- FlowConfig → NormalizingFlow
EBM Builder
Creates energy-based models:
- EBMConfig → EBM
- DeepEBMConfig → DeepEBM
Autoregressive Builder
Creates autoregressive models:
- TransformerConfig → Transformer
- PixelCNNConfig → PixelCNN
- WaveNetConfig → WaveNet
Geometric Builder
Creates geometric models:
- PointCloudConfig → PointCloudModel
- MeshConfig → MeshModel
- VoxelConfig → VoxelModel
- GraphConfig → GraphModel
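Within a family, a builder has to pick the variant for the most specific config class. Suppose, hypothetically, that BetaVAEConfig subclassed VAEConfig. Then an isinstance check against VAEConfig would also match a BetaVAEConfig, so the order of checks would matter. The minimal sketch below resolves this by walking the config's method resolution order, most specific class first. It uses hypothetical stand-in classes, and VARIANTS and build_vae are illustrative names, not Artifex's.

```python
from dataclasses import dataclass


# Hypothetical stand-ins; whether Artifex's configs subclass each other is an assumption.
@dataclass
class VAEConfig:
    name: str


@dataclass
class BetaVAEConfig(VAEConfig):
    beta: float = 4.0


class VAE:
    def __init__(self, config: VAEConfig) -> None:
        self.config = config


class BetaVAE(VAE):
    pass


# Config class -> model class, as in the VAE Builder mapping above.
VARIANTS: dict[type, type] = {VAEConfig: VAE, BetaVAEConfig: BetaVAE}


def build_vae(config: VAEConfig) -> VAE:
    # __mro__ starts with the concrete class, so BetaVAEConfig is tried before
    # VAEConfig; an unlisted subclass falls back to its nearest listed parent.
    for cls in type(config).__mro__:
        if cls in VARIANTS:
            return VARIANTS[cls](config)
    raise TypeError(f"Unsupported VAE config: {type(config).__name__}")


print(type(build_vae(BetaVAEConfig(name="b"))).__name__)  # BetaVAE
```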
Registry
The model type registry manages builder registration:
from artifex.generative_models.factory.registry import ModelTypeRegistry
# Create custom registry
registry = ModelTypeRegistry()
# Register custom builder (CustomBuilder is defined under Custom Builders below)
registry.register("custom_type", CustomBuilder())
# Get builder
builder = registry.get_builder("custom_type")
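A registry of this kind is essentially a dictionary from model-type strings to builder objects. The minimal sketch below assumes nothing about ModelTypeRegistry's internals, and SimpleRegistry is a hypothetical name. The ValueError on a missing key mirrors the "ValueError: If builder not found for model type" entry in the create_model docstring.

```python
from typing import Any


class SimpleRegistry:
    """Hypothetical registry sketch: model-type string -> builder object."""

    def __init__(self) -> None:
        self._builders: dict[str, Any] = {}

    def register(self, model_type: str, builder: Any) -> None:
        # This sketch silently replaces an existing entry with the same key.
        self._builders[model_type] = builder

    def get_builder(self, model_type: str) -> Any:
        try:
            return self._builders[model_type]
        except KeyError:
            raise ValueError(
                f"No builder registered for model type: {model_type!r}"
            ) from None

    def registered_types(self) -> list[str]:
        return sorted(self._builders)
```

Whether a second register call with the same key should overwrite or raise is a design choice. Overwriting makes it easy to swap in a replacement builder, while raising catches accidental name collisions.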
Modality Adaptation
The factory supports optional modality adaptation for domain-specific models:
# Create image-adapted model
model = create_model(config, modality="image", rngs=rngs)
# Create text-adapted model
model = create_model(config, modality="text", rngs=rngs)
# Create audio-adapted model
model = create_model(config, modality="audio", rngs=rngs)
Available Modalities:
- image: Convolutional layers, FID/IS metrics
- text: Tokenization, perplexity metrics
- audio: Spectrograms, MFCCs
- protein: Structure prediction, sequence modeling
- geometric: Point clouds, meshes
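One way to picture the modality argument is as an optional post-build step keyed by the modality name. The sketch below is hypothetical. ADAPTERS, apply_modality, and the placeholder adapter (which only tags the model) are illustrative. Raising ValueError for an unlisted modality is an assumption about behavior this page does not document.

```python
from typing import Any, Callable, Optional

# Modality names from the list above.
MODALITIES = ("image", "text", "audio", "protein", "geometric")


def _make_adapter(modality: str) -> Callable[[Any], Any]:
    def adapt(model: Any) -> Any:
        # Placeholder: a real adapter would add modality-specific layers or metrics.
        model.modality = modality
        return model

    return adapt


ADAPTERS: dict[str, Callable[[Any], Any]] = {m: _make_adapter(m) for m in MODALITIES}


def apply_modality(model: Any, modality: Optional[str]) -> Any:
    if modality is None:
        return model  # adaptation is optional; the model is returned unchanged
    if modality not in ADAPTERS:
        raise ValueError(f"Unknown modality: {modality!r}")
    return ADAPTERS[modality](model)
```

Keeping adaptation separate from building, as this sketch does, lets one builder serve several modalities.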
Custom Builders
Create custom builders for new model types:
from flax import nnx
from artifex.generative_models.factory import ModelFactory
from artifex.generative_models.factory.registry import ModelBuilder
class CustomBuilder(ModelBuilder):
    """Builder for custom model type."""
    def build(self, config, *, rngs: nnx.Rngs, **kwargs):
        """Build the model from configuration."""
        # CustomModel is your own model class (e.g. an nnx.Module subclass)
        return CustomModel(config, rngs=rngs, **kwargs)
# Register with factory
factory = ModelFactory()
factory.registry.register("custom", CustomBuilder())
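The contract a builder fulfils is small: accept a config plus keyword-only rngs and any extra keyword arguments, and return a model. The stdlib-only sketch below models that contract with an abstract base class. ModelBuilderSketch, CustomConfig, and the toy CustomModel are hypothetical stand-ins, not the real ModelBuilder or an nnx model.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


class ModelBuilderSketch(ABC):
    """Hypothetical stand-in for a builder base class."""

    @abstractmethod
    def build(self, config: Any, *, rngs: Any, **kwargs: Any) -> Any:
        """Return a model instance for config."""


@dataclass
class CustomConfig:
    name: str
    hidden_dim: int = 64


class CustomModel:
    """Toy model that only records what it was built with."""

    def __init__(self, config: CustomConfig, *, rngs: Any, **kwargs: Any) -> None:
        self.config = config
        self.rngs = rngs
        self.options = kwargs


class CustomBuilder(ModelBuilderSketch):
    def build(self, config: CustomConfig, *, rngs: Any, **kwargs: Any) -> CustomModel:
        # Extra keyword arguments passed to the builder are forwarded unchanged.
        return CustomModel(config, rngs=rngs, **kwargs)


model = CustomBuilder().build(CustomConfig(name="custom"), rngs=None, dropout_rate=0.1)
print(model.options)  # {'dropout_rate': 0.1}
```

Declaring build as abstract means a subclass that forgets to implement it fails when it is instantiated, rather than at the first build call.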
Best Practices
DO
- ✅ Use dataclass configs instead of dictionaries
- ✅ Validate configs before passing them to the factory
- ✅ Use type hints for better IDE support
- ✅ Pass all required RNG streams to nnx.Rngs (the Quick Start uses params, dropout, and sample)
DON'T
- ❌ Pass dictionary configs (raises TypeError)
- ❌ Use string-based model class specifications
- ❌ Forget to provide the rngs parameter
Error Handling
The factory raises descriptive exceptions for unsupported input:
# TypeError: Dictionary configs not supported
create_model({"model_class": "vae"}, rngs=rngs)
# Raises: TypeError: Expected dataclass config, got dict.
# TypeError: Unknown config type
# (UnknownConfig stands for any dataclass not mapped to a model type)
create_model(UnknownConfig(), rngs=rngs)
# Raises: TypeError: Unknown config type: UnknownConfig
# ValueError: Builder not found
# (Only possible with custom registries)
Module Reference
| Module | Description |
|---|---|
| core | Core factory implementation and create_model function |
| registry | Model type registry and builder base class |
| vae | VAE model builder |
| gan | GAN model builder |
| diffusion | Diffusion model builder |
| flow | Normalizing flow builder |
| ebm | Energy-based model builder |
| autoregressive | Autoregressive model builder |
| geometric | Geometric model builder |
Related Documentation
- Configuration System - Understanding dataclass configs
- Model Gallery - Available model implementations
- Factory Guide - Detailed factory usage guide