
Model Factory

The factory module provides a centralized, type-safe interface for creating generative models in Artifex. It uses dataclass-based configurations to determine model type automatically, eliminating the need for string-based model class specifications.

Overview

  • Unified Interface: a single create_model() function for all model types
  • Type-Safe: dataclass configs with automatic validation
  • Extensible: register custom builders for new model types
  • Modality Support: optional modality adaptation for domain-specific models

Quick Start

Basic Model Creation

from artifex.generative_models.factory import create_model
from artifex.generative_models.core.configuration import (
    VAEConfig,
    EncoderConfig,
    DecoderConfig,
)
from flax import nnx

# Create configuration
encoder = EncoderConfig(
    name="encoder",
    input_shape=(28, 28, 1),
    latent_dim=32,
    hidden_dims=(256, 128),
    activation="relu",
)
decoder = DecoderConfig(
    name="decoder",
    output_shape=(28, 28, 1),
    latent_dim=32,
    hidden_dims=(128, 256),
    activation="relu",
)
config = VAEConfig(
    name="my_vae",
    encoder=encoder,
    decoder=decoder,
    kl_weight=1.0,
)

# Create model - type is inferred from config
rngs = nnx.Rngs(params=42, dropout=43, sample=44)
model = create_model(config, rngs=rngs)

Model Type Inference

The factory infers the model type from the class of the configuration object:

| Config class | Model type | Created model |
|---|---|---|
| VAEConfig, BetaVAEConfig, ConditionalVAEConfig, VQVAEConfig | vae | VAE variants |
| GANConfig, DCGANConfig, WGANConfig, LSGANConfig | gan | GAN variants |
| DiffusionConfig, DDPMConfig, ScoreDiffusionConfig | diffusion | Diffusion models |
| EBMConfig, DeepEBMConfig | ebm | Energy-based models |
| FlowConfig | flow | Normalizing flows |
| AutoregressiveConfig, TransformerConfig, PixelCNNConfig, WaveNetConfig | autoregressive | Autoregressive models |
| GeometricConfig, PointCloudConfig, MeshConfig, VoxelConfig, GraphConfig | geometric | Geometric models |
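
The lookup can be pictured as a table from config class to model-type key. Below is a minimal, stdlib-only sketch of that idea, not Artifex's implementation. The config classes are hypothetical stand-ins. The sketch also assumes that variant configs such as BetaVAEConfig subclass their family's base config, so walking the class hierarchy (`__mro__`) lets a variant resolve to its family key.

```python
from dataclasses import dataclass


@dataclass
class VAEConfig:  # hypothetical stand-in for the real config class
    name: str


@dataclass
class BetaVAEConfig(VAEConfig):  # assumption: variants subclass their family's base config
    beta: float = 4.0


@dataclass
class GANConfig:  # hypothetical stand-in
    name: str


# Config class -> model-type key used to choose a builder.
MODEL_TYPE_BY_CONFIG: dict[type, str] = {VAEConfig: "vae", GANConfig: "gan"}


def infer_model_type(config: object) -> str:
    """Walk the config's class hierarchy, most specific class first, until a known class is found."""
    for cls in type(config).__mro__:
        if cls in MODEL_TYPE_BY_CONFIG:
            return MODEL_TYPE_BY_CONFIG[cls]
    raise TypeError(f"Unknown config type: {type(config).__name__}")
```

Because `__mro__` lists the concrete class first, an explicit entry for a variant would take precedence over its parent's entry.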

API Reference

create_model

The main function for model creation.

def create_model(
    config: DataclassConfig,
    *,
    modality: str | None = None,
    rngs: nnx.Rngs,
    **kwargs,
) -> Any:
    """Create a model from configuration.

    Args:
        config: Dataclass model configuration (DDPMConfig, VAEConfig, etc.)
        modality: Optional modality for adaptation ('image', 'text', 'audio', etc.)
        rngs: Random number generators
        **kwargs: Additional arguments passed to the builder

    Returns:
        Created model instance

    Raises:
        TypeError: If config is not a supported dataclass config
        ValueError: If builder not found for model type
    """

Example:

from flax import nnx
from artifex.generative_models.factory import create_model
from artifex.generative_models.core.configuration import (
    DDPMConfig,
    UNetBackboneConfig,
    NoiseScheduleConfig,
)

# Create diffusion model config
backbone = UNetBackboneConfig(
    name="unet",
    in_channels=3,
    out_channels=3,
    base_channels=64,
    channel_mults=(1, 2, 4),
)
noise_schedule = NoiseScheduleConfig(
    name="schedule",
    schedule_type="linear",
    num_timesteps=1000,
    beta_start=1e-4,
    beta_end=2e-2,
)
config = DDPMConfig(
    name="ddpm",
    input_shape=(3, 32, 32),
    backbone=backbone,
    noise_schedule=noise_schedule,
)

# Create model (same RNG streams as in the Quick Start)
rngs = nnx.Rngs(params=42, dropout=43, sample=44)
model = create_model(config, rngs=rngs)

create_model_with_extensions

Create a model together with the extensions described in extensions_config. The function returns a (model, extensions) tuple.

def create_model_with_extensions(
    config: DataclassConfig,
    *,
    extensions_config: dict[str, ExtensionConfig] | None = None,
    modality: str | None = None,
    rngs: nnx.Rngs,
    **kwargs,
) -> tuple[Any, dict[str, ModelExtension]]:
    """Create a model and its extensions from configuration.

    Returns:
        Tuple of (model, extensions_dict)
    """

Example:

from artifex.generative_models.factory import create_model_with_extensions

# Create model with extensions (config, augmentation_config, and reg_config are defined elsewhere)
model, extensions = create_model_with_extensions(
    config,
    extensions_config={
        "augmentation": augmentation_config,
        "regularization": reg_config,
    },
    rngs=rngs,
)
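
The contract amounts to three steps: build the model, build one extension per entry, and return both. The sketch below shows that shape under stated assumptions. `ExtensionConfig`, `ModelExtension`, and `create_model_with_extensions_sketch` are hypothetical stand-ins. The real function also takes `modality` and `rngs`, which are omitted here, and a plain callable replaces the factory.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ExtensionConfig:
    """Hypothetical extension config; the real Artifex class may define different fields."""
    weight: float = 1.0


class ModelExtension:
    """Hypothetical extension object that remembers its name and config."""

    def __init__(self, name: str, config: ExtensionConfig) -> None:
        self.name = name
        self.config = config


def create_model_with_extensions_sketch(
    build_model: Callable[[Any], Any],
    config: Any,
    *,
    extensions_config: Optional[dict[str, ExtensionConfig]] = None,
) -> tuple[Any, dict[str, ModelExtension]]:
    """Build the model, then one extension per entry; always return a (model, extensions) pair."""
    model = build_model(config)
    # With no extensions requested the dict is empty, so callers can always unpack two values.
    extensions = {
        name: ModelExtension(name, ext_config)
        for name, ext_config in (extensions_config or {}).items()
    }
    return model, extensions
```

Because the function always returns a pair, `model, extensions = ...` works whether or not any extensions were requested.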

ModelFactory

The factory class underlying create_model(). Use it directly for advanced tasks such as registering custom builders on its registry.

class ModelFactory:
    """Centralized factory for all generative models."""

    def __init__(self):
        """Initialize with default builders."""

    def create(
        self,
        config: DataclassConfig,
        *,
        modality: str | None = None,
        rngs: nnx.Rngs,
        **kwargs,
    ) -> Any:
        """Create a model from dataclass configuration."""

Builders

Each model family has a dedicated builder that handles model instantiation:

VAE Builder

Creates VAE variants based on configuration type:

  • VAEConfig → VAE
  • BetaVAEConfig → BetaVAE
  • ConditionalVAEConfig → ConditionalVAE
  • VQVAEConfig → VQVAE
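
Variant configs may share a base class. If so, a builder that tests `isinstance(config, VAEConfig)` first would also match a BetaVAEConfig and build the wrong model. One order-independent approach is to dispatch on the exact config type. The sketch below uses hypothetical stand-in classes and assumes that BetaVAEConfig subclasses VAEConfig. The other builders in this section could follow the same pattern.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class VAEConfig:  # hypothetical stand-in
    name: str


@dataclass
class BetaVAEConfig(VAEConfig):  # assumption: the variant subclasses the base config
    beta: float = 4.0


class VAE:  # hypothetical stand-in for the real model class
    def __init__(self, config: VAEConfig, *, rngs: Any = None) -> None:
        self.config = config
        self.rngs = rngs


class BetaVAE(VAE):
    pass


class VAEBuilderSketch:
    """Choose the model class from the exact config type, so check order never matters."""

    _MODEL_BY_CONFIG = {VAEConfig: VAE, BetaVAEConfig: BetaVAE}

    def build(self, config: VAEConfig, *, rngs: Any = None) -> VAE:
        model_cls = self._MODEL_BY_CONFIG.get(type(config))
        if model_cls is None:
            raise TypeError(f"Unsupported VAE config: {type(config).__name__}")
        return model_cls(config, rngs=rngs)
```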


GAN Builder

Creates GAN variants:

  • GANConfig → GAN
  • DCGANConfig → DCGAN
  • WGANConfig → WGAN
  • LSGANConfig → LSGAN


Diffusion Builder

Creates diffusion models:

  • DDPMConfig → DDPMModel
  • ScoreDiffusionConfig → ScoreDiffusionModel
  • DiffusionConfig → DiffusionModel


Flow Builder

Creates normalizing flows:

  • FlowConfig → NormalizingFlow


EBM Builder

Creates energy-based models:

  • EBMConfig → EBM
  • DeepEBMConfig → DeepEBM


Autoregressive Builder

Creates autoregressive models:

  • TransformerConfig → Transformer
  • PixelCNNConfig → PixelCNN
  • WaveNetConfig → WaveNet


Geometric Builder

Creates geometric models:

  • PointCloudConfig → PointCloudModel
  • MeshConfig → MeshModel
  • VoxelConfig → VoxelModel
  • GraphConfig → GraphModel


Registry

The model type registry manages builder registration:

from artifex.generative_models.factory.registry import ModelTypeRegistry

# Create custom registry
registry = ModelTypeRegistry()

# Register a custom builder (CustomBuilder is defined under Custom Builders below)
registry.register("custom_type", CustomBuilder())

# Get builder
builder = registry.get_builder("custom_type")
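
A registry of this kind can be as small as a dictionary with a clear error for unknown keys. The sketch below is hypothetical: `ModelTypeRegistrySketch` and its behavior on re-registration are assumptions. It keeps the `register` and `get_builder` names used above. It raises ValueError for an unregistered type, matching the ValueError that create_model documents for a missing builder.

```python
from typing import Any, Protocol


class Builder(Protocol):
    """Anything with a build(config, *, rngs, **kwargs) method."""

    def build(self, config: Any, *, rngs: Any, **kwargs: Any) -> Any: ...


class ModelTypeRegistrySketch:
    """Hypothetical registry mapping a model-type string to a builder instance."""

    def __init__(self) -> None:
        self._builders: dict[str, Builder] = {}

    def register(self, model_type: str, builder: Builder) -> None:
        # In this sketch, registering the same type again replaces the earlier builder.
        self._builders[model_type] = builder

    def get_builder(self, model_type: str) -> Builder:
        if model_type not in self._builders:
            known = ", ".join(sorted(self._builders)) or "none"
            raise ValueError(f"No builder found for model type {model_type!r} (registered: {known})")
        return self._builders[model_type]
```

Listing the registered types in the error message helps catch a misspelled model-type key.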


Modality Adaptation

The factory supports optional modality adaptation for domain-specific models:

# Create image-adapted model
model = create_model(config, modality="image", rngs=rngs)

# Create text-adapted model
model = create_model(config, modality="text", rngs=rngs)

# Create audio-adapted model
model = create_model(config, modality="audio", rngs=rngs)

Available Modalities:

  • image: Convolutional layers, FID/IS metrics
  • text: Tokenization, perplexity metrics
  • audio: Spectrograms, MFCCs
  • protein: Structure prediction, sequence modeling
  • geometric: Point clouds, meshes
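
One way to structure optional adaptation is to build the model first. Then, only when a modality is given, the model passes through an adapter chosen by name. The sketch below assumes that design. The adapters are hypothetical placeholders that only record the modality. Raising ValueError for an unknown modality is also an assumption, not documented Artifex behavior.

```python
from typing import Any, Callable, Optional

Adapter = Callable[[Any], Any]


def _make_adapter(modality: str) -> Adapter:
    """Hypothetical adapter: records the modality on the model instead of changing layers or metrics."""

    def adapt(model: Any) -> Any:
        model.modality = modality
        return model

    return adapt


# Keys follow the modality names listed above.
ADAPTERS: dict[str, Adapter] = {
    name: _make_adapter(name) for name in ("image", "text", "audio", "protein", "geometric")
}


def apply_modality(model: Any, modality: Optional[str]) -> Any:
    """Return the model untouched when modality is None; otherwise run the named adapter."""
    if modality is None:
        return model
    if modality not in ADAPTERS:
        raise ValueError(f"Unsupported modality: {modality!r}")
    return ADAPTERS[modality](model)
```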

Custom Builders

Create custom builders for new model types:

from artifex.generative_models.factory.registry import ModelBuilder
from flax import nnx

class CustomBuilder(ModelBuilder):
    """Builder for custom model type."""

    def build(self, config, *, rngs: nnx.Rngs, **kwargs):
        """Build the model from configuration."""
        # CustomModel stands for your own model class (for example, an nnx.Module subclass)
        return CustomModel(config, rngs=rngs, **kwargs)

# Register with factory
from artifex.generative_models.factory import ModelFactory

factory = ModelFactory()
factory.registry.register("custom", CustomBuilder())

Best Practices

DO

  • ✅ Use dataclass configs instead of dictionaries
  • ✅ Validate configs before passing to factory
  • ✅ Use type hints for better IDE support
  • ✅ Pass all required RNG streams to nnx.Rngs

DON'T

  • ❌ Pass dictionary configs (will raise TypeError)
  • ❌ Use string-based model class specifications
  • ❌ Forget to provide rngs parameter
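
The first two DO items can be put into practice in two ways. A dataclass can validate itself in `__post_init__`, and legacy dictionary configs can be converted explicitly before they reach the factory. The sketch is hypothetical: `EncoderConfigSketch` only borrows a few field names from the Quick Start, and `from_dict` is an illustrative helper, not an Artifex API.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass(frozen=True)
class EncoderConfigSketch:
    """Hypothetical config that borrows a few EncoderConfig field names from the Quick Start."""
    name: str
    latent_dim: int
    hidden_dims: tuple[int, ...] = (256, 128)

    def __post_init__(self) -> None:
        # Fail fast on values no model could use.
        if self.latent_dim <= 0:
            raise ValueError(f"latent_dim must be positive, got {self.latent_dim}")
        if any(h <= 0 for h in self.hidden_dims):
            raise ValueError(f"hidden_dims must all be positive, got {self.hidden_dims}")


def from_dict(cls: type, data: dict[str, Any]) -> Any:
    """Convert a legacy dict config into a dataclass, rejecting keys the dataclass does not define."""
    allowed = {f.name for f in fields(cls)}
    unknown = set(data) - allowed
    if unknown:
        raise ValueError(f"Unknown config keys for {cls.__name__}: {sorted(unknown)}")
    return cls(**data)
```

Rejecting unknown keys surfaces leftovers such as a "model_class" entry instead of silently ignoring them.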

Error Handling

The factory raises descriptive errors for unsupported input:

# TypeError: Dictionary configs not supported
create_model({"model_class": "vae"}, rngs=rngs)
# Raises: TypeError: Expected dataclass config, got dict.

# TypeError: Unknown config type
create_model(UnknownConfig(), rngs=rngs)
# Raises: TypeError: Unknown config type: UnknownConfig

# ValueError: Builder not found
# (Only possible with custom registries)

Module Reference

| Module | Description |
|---|---|
| core | Core factory implementation and create_model function |
| registry | Model type registry and builder base class |
| vae | VAE model builder |
| gan | GAN model builder |
| diffusion | Diffusion model builder |
| flow | Normalizing flow builder |
| ebm | Energy-based model builder |
| autoregressive | Autoregressive model builder |
| geometric | Geometric model builder |