Extensions
Extensions in Artifex provide domain-specific functionality that enhances base generative models with specialized knowledge and constraints. This modular system allows you to add protein modeling, NLP preprocessing, audio processing, and vision augmentation capabilities to any model.
Overview
- Protein Extensions: add protein-specific constraints, backbone geometry, and amino acid features to geometric models
- NLP Extensions: tokenization, embeddings, and text preprocessing for language models
- Audio Extensions: spectral processing and temporal features for audio generation
- Vision Extensions: image augmentation and preprocessing for visual models
Quick Start
You create each extension from its config and pass the extensions to a model through the model's `extensions` argument:
```python
import jax
from flax import nnx

from artifex.generative_models.core.configuration import ProteinMixinConfig
from artifex.generative_models.extensions.protein import ProteinMixinExtension
from artifex.generative_models.models.geometric.point_cloud import PointCloudModel

key = jax.random.key(42)

# Configure the extension using a frozen dataclass config
mixin_config = ProteinMixinConfig(
    name="protein_mixin",
    weight=1.0,
    enabled=True,
    embedding_dim=64,
    num_aa_types=20,
)

# Create the extension instance
protein_mixin = ProteinMixinExtension(
    config=mixin_config,
    rngs=nnx.Rngs(params=key),
)

# Wrap in nnx.Dict so NNX tracks the extensions as submodules
extensions = nnx.Dict({"protein_mixin": protein_mixin})

# Create the model with extensions.
# `model_config` is the PointCloudModel configuration, defined elsewhere.
model = PointCloudModel(model_config, extensions=extensions, rngs=nnx.Rngs(params=key))
```
Protein Extensions
Protein extensions add domain knowledge about molecular structure to geometric models. This steers generation toward physically realistic proteins.
Available Extensions
| Extension | Description | Key Features |
|---|---|---|
| ProteinMixinExtension | Amino acid integration | 20 AA type embeddings, residue features |
| ProteinBackboneConstraint | Backbone geometry | N, CA, C, O atom indices, geometric constraints |
| BondLengthExtension | Bond distance monitoring | Violation detection, loss contribution |
| BondAngleExtension | Bond angle monitoring | Peptide bond angles, backbone geometry |
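The "violation detection, loss contribution" idea behind bond-length monitoring can be made concrete with a small pure-JAX sketch. This is a flat-bottom bond-length penalty and is not Artifex's `BondLengthExtension` implementation. Several details are assumptions for illustration only:
- the `(num_residues, 3, 3)` N/CA/C coordinate layout
- the `tolerance` parameter
- the names `bond_length_loss` and `IDEAL_BOND_LENGTHS`

```python
import jax
import jax.numpy as jnp

# Ideal backbone bond lengths in angstroms (the values used in the config examples)
IDEAL_BOND_LENGTHS = {"N-CA": 1.45, "CA-C": 1.52, "C-N": 1.33}


def bond_length_loss(backbone: jax.Array, tolerance: float = 0.02) -> jax.Array:
    """Hypothetical flat-bottom bond-length penalty for one chain.

    Assumes `backbone` has shape (num_residues, 3, 3): the N, CA and C
    coordinates of each residue, in that order.
    """
    n, ca, c = backbone[:, 0], backbone[:, 1], backbone[:, 2]
    # Intra-residue bonds
    d_n_ca = jnp.linalg.norm(ca - n, axis=-1)
    d_ca_c = jnp.linalg.norm(c - ca, axis=-1)
    # Peptide bond from C of residue i to N of residue i + 1
    d_c_n = jnp.linalg.norm(n[1:] - c[:-1], axis=-1)

    def penalty(d: jax.Array, ideal: float) -> jax.Array:
        # Deviations within `tolerance` cost nothing; larger ones are penalized quadratically
        excess = jnp.maximum(jnp.abs(d - ideal) - tolerance, 0.0)
        return jnp.mean(excess**2)

    return (
        penalty(d_n_ca, IDEAL_BOND_LENGTHS["N-CA"])
        + penalty(d_ca_c, IDEAL_BOND_LENGTHS["CA-C"])
        + penalty(d_c_n, IDEAL_BOND_LENGTHS["C-N"])
    )


# Two residues laid out on a straight line with ideal spacing
xs = jnp.array([0.0, 1.45, 2.97, 4.30, 5.75, 7.27])
ideal_chain = jnp.stack([xs, jnp.zeros(6), jnp.zeros(6)], axis=-1).reshape(2, 3, 3)
print(jax.jit(bond_length_loss)(ideal_chain))  # 0.0
```

The function uses only `jax.numpy` operations, so it can be wrapped in `jax.jit` or differentiated with `jax.grad`. A real extension would also need to handle batching and masks for missing residues.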
Usage Example
```python
import jax
from flax import nnx

from artifex.generative_models.core.configuration import (
    ProteinExtensionConfig,
    ProteinMixinConfig,
)
from artifex.generative_models.extensions.protein import (
    BondLengthExtension,
    ProteinMixinExtension,
)
from artifex.generative_models.extensions.protein.constraints import (
    ProteinBackboneConstraint,
)

key = jax.random.key(42)

# Collect several protein extensions under unique names
extensions_dict = {}

# Amino acid type integration with a frozen dataclass config
extensions_dict["protein_mixin"] = ProteinMixinExtension(
    config=ProteinMixinConfig(
        name="protein_mixin",
        weight=1.0,
        enabled=True,
        embedding_dim=64,
        num_aa_types=20,
    ),
    rngs=nnx.Rngs(params=key),
)

# Backbone constraints with explicit fields
extensions_dict["backbone"] = ProteinBackboneConstraint(
    config=ProteinExtensionConfig(
        name="backbone",
        weight=1.0,
        enabled=True,
        bond_length_weight=1.0,
        bond_angle_weight=0.5,
    ),
    rngs=nnx.Rngs(params=key),
)

# Bond length monitoring (ideal lengths in angstroms)
extensions_dict["bond_length"] = BondLengthExtension(
    config=ProteinExtensionConfig(
        name="bond_length",
        weight=1.0,
        enabled=True,
        bond_length_weight=1.0,
        ideal_bond_lengths={"N-CA": 1.45, "CA-C": 1.52, "C-N": 1.33},
    ),
    rngs=nnx.Rngs(params=key),
)

# Wrap in nnx.Dict for NNX compatibility
extensions = nnx.Dict(extensions_dict)
```
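The `ProteinMixinExtension` row above mentions embeddings for 20 amino acid types. Conceptually, such an embedding is a learned lookup table indexed by residue type. The plain-JAX sketch below illustrates that idea only. The alphabetical one-letter ordering, the initialization scale, and the names `AA_ALPHABET` and `encode_sequence` are hypothetical. They are not necessarily how `ProteinMixinExtension` orders or initializes its table.

```python
import jax
import jax.numpy as jnp

# The 20 standard amino acids as one-letter codes (this ordering is an assumption)
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INDEX = {aa: i for i, aa in enumerate(AA_ALPHABET)}


def encode_sequence(sequence: str) -> jax.Array:
    """Map a one-letter amino acid sequence to integer type indices."""
    return jnp.array([AA_TO_INDEX[aa] for aa in sequence], dtype=jnp.int32)


# Table of shape (num_aa_types, embedding_dim), mirroring the config fields
num_aa_types, embedding_dim = 20, 64
table = 0.02 * jax.random.normal(jax.random.key(0), (num_aa_types, embedding_dim))

indices = encode_sequence("MKTAY")
residue_features = table[indices]  # one embedding row per residue
print(residue_features.shape)  # (5, 64)
```

In a trained model the table would be a learnable parameter rather than a fixed random array.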
Documentation
- Backbone Extension - Backbone atom handling
- Constraints - Geometric constraints
- Mixin Extension - Amino acid integration
- Utilities - Protein utility functions
NLP Extensions
NLP extensions provide text processing capabilities for language models and multimodal systems.
Available Extensions
| Extension | Description | Key Features |
|---|---|---|
| Tokenization | Text tokenization | BPE, SentencePiece, character-level |
| Embeddings | Token embeddings | Positional encoding, learned embeddings |
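Positional encoding injects token order into otherwise order-agnostic embeddings. A common choice is the fixed sinusoidal encoding from the original Transformer, sketched below in JAX. This page does not say whether Artifex's Embeddings extension uses exactly this formula. The function name `sinusoidal_positional_encoding` is hypothetical.

```python
import jax.numpy as jnp


def sinusoidal_positional_encoding(seq_len: int, dim: int) -> jnp.ndarray:
    """Fixed sin/cos position encodings of shape (seq_len, dim); dim must be even."""
    positions = jnp.arange(seq_len)[:, None]  # (seq_len, 1)
    # Geometric progression of frequencies, as in the original Transformer
    inv_freq = jnp.exp(jnp.arange(0, dim, 2) * (-jnp.log(10000.0) / dim))  # (dim // 2,)
    angles = positions * inv_freq  # (seq_len, dim // 2)
    pe = jnp.zeros((seq_len, dim))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))  # even channels
    pe = pe.at[:, 1::2].set(jnp.cos(angles))  # odd channels
    return pe


pe = sinusoidal_positional_encoding(seq_len=16, dim=8)
print(pe.shape)  # (16, 8)
```

The result is typically added to token embeddings of matching shape, for example `token_embeddings + pe`. Learned embeddings replace `pe` with a trainable table.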
Documentation
- Tokenization - Tokenization methods
- Embeddings - Embedding systems
Audio Extensions
Audio extensions add signal processing capabilities for audio generation models.
Available Extensions
| Extension | Description | Key Features |
|---|---|---|
| Spectral | Frequency analysis | STFT, mel-spectrograms, spectrogram inversion |
| Temporal | Time-domain features | Envelope extraction, onset detection |
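The core of spectral processing is the short-time Fourier transform (STFT): split the signal into overlapping frames, window each frame, and take an FFT. The JAX sketch below computes a magnitude STFT that way. It is not Artifex's Spectral extension. The `frame_length` and `hop_length` parameter names are assumptions, and the mel filterbank step is omitted.

```python
import jax.numpy as jnp


def magnitude_stft(signal: jnp.ndarray, frame_length: int = 256, hop_length: int = 128) -> jnp.ndarray:
    """Magnitude STFT of a 1-D signal, shape (num_frames, frame_length // 2 + 1)."""
    num_frames = 1 + (signal.shape[0] - frame_length) // hop_length
    # Index matrix: row t holds the sample indices of frame t
    starts = hop_length * jnp.arange(num_frames)[:, None]
    frames = signal[starts + jnp.arange(frame_length)[None, :]]
    # Taper each frame with a Hann window to reduce spectral leakage
    window = jnp.hanning(frame_length)
    return jnp.abs(jnp.fft.rfft(frames * window, axis=-1))


sample_rate = 8000
t = jnp.arange(2048) / sample_rate
tone = jnp.sin(2 * jnp.pi * 1000.0 * t)  # 1 kHz sine
spec = magnitude_stft(tone)
print(spec.shape)  # (15, 129)
```

With a 256-sample frame at 8 kHz, each bin spans 31.25 Hz. The 1 kHz tone therefore peaks in bin 32 of every frame. A mel-spectrogram would multiply these magnitudes by a mel filterbank matrix.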
Documentation
- Spectral Processing - Frequency domain operations
- Temporal Features - Time domain processing
Vision Extensions
Vision extensions provide image preprocessing and augmentation for visual models.
Available Extensions
| Extension | Description | Key Features |
|---|---|---|
| Augmentation | Data augmentation | Flips, rotations, color jitter, cutout |
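Random augmentations such as flips are easiest to keep JIT-safe when the PRNG key is passed in explicitly rather than drawn from mutable state. The JAX sketch below shows a per-image random horizontal flip. The function name and the `(batch, height, width, channels)` layout are assumptions, not Artifex's Augmentation API.

```python
import jax
import jax.numpy as jnp


def random_horizontal_flip(key: jax.Array, images: jax.Array, p: float = 0.5) -> jax.Array:
    """Flip each image left-right with probability p.

    Assumes a (batch, height, width, channels) layout.
    """
    # One independent coin flip per image, drawn from an explicit key
    flip = jax.random.bernoulli(key, p, shape=(images.shape[0],))
    flipped = images[:, :, ::-1, :]
    # Broadcast the per-image decision over height, width and channels
    return jnp.where(flip[:, None, None, None], flipped, images)


key = jax.random.key(0)
images = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3)
augmented = jax.jit(random_horizontal_flip)(key, images)
print(augmented.shape)  # (2, 4, 4, 3)
```

In a training loop, split a fresh subkey for each step (`key, subkey = jax.random.split(key)`) so that batches receive different flips.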
Documentation
- Augmentation - Image augmentation methods
Extension Architecture
Configuration Classes
Extensions use frozen dataclass configurations from `core.configuration`:
```python
from artifex.generative_models.core.configuration import (
    ExtensionConfig,  # Base extension config
    ConstraintExtensionConfig,  # For constraint extensions
    ProteinExtensionConfig,  # Protein-specific constraints
    ProteinMixinConfig,  # Protein amino acid features
    ChemicalConstraintConfig,  # Chemical/molecular constraints
    ImageAugmentationConfig,  # Vision augmentation
    AudioSpectralConfig,  # Audio spectral processing
    TextEmbeddingConfig,  # NLP embeddings
)

# Base ExtensionConfig for simple extensions
config = ExtensionConfig(
    name="my_extension",  # Unique identifier
    weight=1.0,  # Loss contribution weight
    enabled=True,  # Enable/disable toggle
)

# Domain-specific configs have explicit fields (no extensions dict)
protein_config = ProteinExtensionConfig(
    name="backbone",
    weight=1.0,
    enabled=True,
    bond_length_weight=1.0,  # Explicit field, not in an extensions dict
    bond_angle_weight=0.5,
    ideal_bond_lengths={"N-CA": 1.45, "CA-C": 1.52},  # Angstroms
)
```
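Because these configs are frozen, you cannot reassign a field after construction; instead, you derive a modified copy. The standard-library sketch below demonstrates this with a hypothetical stand-in class, `DemoExtensionConfig`, that has the same three base fields. The real `ExtensionConfig` may define more fields. `dataclasses.replace` is the standard way to create such a variant.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class DemoExtensionConfig:
    """Hypothetical stand-in with the same base fields as ExtensionConfig."""

    name: str
    weight: float = 1.0
    enabled: bool = True


base = DemoExtensionConfig(name="bond_length")

try:
    base.weight = 2.0  # frozen dataclasses reject attribute assignment
except dataclasses.FrozenInstanceError as err:
    print(f"cannot mutate: {err}")

# Derive a modified copy instead; `base` is left unchanged
heavier = dataclasses.replace(base, weight=2.0)
print(base.weight, heavier.weight)  # 1.0 2.0
```

Immutability also makes configs safe to share between extensions and to treat as static values under `jax.jit`.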
Extension Registry
Extensions can be registered and discovered through the registry:
```python
from artifex.generative_models.extensions.registry import (
    register_extension,
    get_extension,
    list_extensions,
)

# List available extensions
available = list_extensions()
print(f"Available extensions: {available}")

# Get an extension class by name
ExtensionClass = get_extension("protein_mixin")
```
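At its core, a registry like this is a mapping from unique names to extension classes. The standard-library sketch below uses hypothetical helpers (`register`, `get`, `list_registered`). It is not Artifex's `registry` module, and the real signatures may differ. Rejecting duplicate names is one way to enforce the rule against conflicting extension names.

```python
from typing import Callable, Dict, List, Type

_REGISTRY: Dict[str, Type] = {}


def register(name: str) -> Callable[[Type], Type]:
    """Class decorator that records `cls` under a unique name."""

    def decorator(cls: Type) -> Type:
        if name in _REGISTRY:
            raise ValueError(f"extension name {name!r} is already registered")
        _REGISTRY[name] = cls
        return cls

    return decorator


def get(name: str) -> Type:
    """Look up a registered class, listing valid names on failure."""
    if name not in _REGISTRY:
        raise KeyError(f"unknown extension {name!r}; available: {sorted(_REGISTRY)}")
    return _REGISTRY[name]


def list_registered() -> List[str]:
    return sorted(_REGISTRY)


@register("toy_extension")
class ToyExtension:
    pass


print(list_registered())  # ['toy_extension']
```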
Documentation
- Extensions Base - Base extension classes
- Registry - Extension registration system
- Features - Feature extraction utilities
Creating Custom Extensions
You can create a custom extension by subclassing `ModelExtension` and implementing `__call__` and `loss_fn`:
```python
import dataclasses

import jax
import jax.numpy as jnp
from flax import nnx

from artifex.generative_models.core.configuration import ExtensionConfig
from artifex.generative_models.extensions.base import ModelExtension


# Define a custom frozen dataclass config for your extension
@dataclasses.dataclass(frozen=True)
class MyExtensionConfig(ExtensionConfig):
    """Custom extension configuration."""

    my_param: float = 1.0
    another_param: int = 10


class MyCustomExtension(ModelExtension):
    """Custom extension for domain-specific processing."""

    def __init__(
        self,
        config: MyExtensionConfig | ExtensionConfig,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__(config, rngs=rngs)
        # Use explicit config fields (frozen dataclass pattern)
        if isinstance(config, MyExtensionConfig):
            self.my_param = config.my_param
            self.another_param = config.another_param
        else:
            # Defaults for a plain ExtensionConfig
            self.my_param = 1.0
            self.another_param = 10

    def __call__(self, inputs, model_outputs, **kwargs) -> dict:
        """Process inputs and model outputs.

        Args:
            inputs: Input batch dictionary
            model_outputs: Model predictions
            **kwargs: Additional arguments

        Returns:
            Extension outputs dictionary
        """
        if not self.enabled:
            return {"extension_type": "my_custom"}
        result = self._process(inputs, model_outputs)
        return {"my_output": result, "extension_type": "my_custom"}

    def loss_fn(self, batch: dict, model_outputs, **kwargs) -> jax.Array:
        """Compute extension-specific loss.

        Args:
            batch: Input batch
            model_outputs: Model outputs

        Returns:
            Loss value (scalar JAX array)
        """
        if not self.enabled:
            return jnp.array(0.0)
        # Use pure JAX operations so the loss stays JIT-compatible
        return self._compute_loss(batch, model_outputs)

    # Illustrative placeholders: replace these with your domain logic.
    # They assume `model_outputs` is an array.
    def _process(self, inputs, model_outputs) -> jax.Array:
        return jnp.asarray(model_outputs) * self.my_param

    def _compute_loss(self, batch, model_outputs) -> jax.Array:
        return jnp.mean(jnp.square(jnp.asarray(model_outputs)))
```
Best Practices
DO
- Use frozen dataclass configs from `core.configuration`
- Use domain-specific configs (e.g., `ProteinExtensionConfig`) with explicit fields
- Wrap extensions in `nnx.Dict` for NNX compatibility
- Set each extension's `weight` so its loss is balanced against the others in multi-extension setups
- Disable unused extensions for efficiency
- Use pure JAX operations in `loss_fn` for JIT compatibility
DON'T
- Don't use the `extensions={}` dict pattern (old Pydantic style)
- Don't use raw dictionaries instead of `nnx.Dict`
- Don't forget to pass `rngs` to extension constructors
- Don't give two extensions the same name
- Don't enable extensions without proper configuration
- Don't mutate RNGs inside traced functions (JIT/grad)
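The `weight` and `enabled` fields suggest that extension losses are combined with the model's own loss. One plausible aggregation is a weighted sum over enabled extensions. The sketch below is an assumption about how such a total could be formed, not Artifex's training loop, and `combine_losses` is a hypothetical helper.

```python
from typing import Mapping


def combine_losses(
    base_loss: float,
    extension_losses: Mapping[str, float],
    weights: Mapping[str, float],
    enabled: Mapping[str, bool],
) -> float:
    """Hypothetical total loss: base loss plus weighted losses of enabled extensions."""
    total = base_loss
    for name, loss in extension_losses.items():
        # `enabled` is static configuration, so this Python branch is JIT-safe
        if enabled.get(name, True):
            total = total + weights[name] * loss
    return total


total = combine_losses(
    base_loss=1.0,
    extension_losses={"bond_length": 2.0, "bond_angle": 4.0},
    weights={"bond_length": 0.5, "bond_angle": 0.25},
    enabled={"bond_length": True, "bond_angle": False},
)
print(total)  # 2.0
```

The helper uses only `+` and `*`, so it works unchanged when the losses are scalar JAX arrays. Disabled extensions contribute nothing, which is why disabling unused ones also saves computation.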
Summary
Extensions provide a modular way to add domain-specific functionality:
- Protein: Physical constraints and amino acid features
- NLP: Tokenization and text embeddings
- Audio: Spectral and temporal processing
- Vision: Image augmentation
All extensions follow consistent patterns for configuration, registration, and integration with Artifex models.