
Adapters

Module: generative_models.core.adapters

Source: generative_models/core/adapters.py

Overview

Model adapter classes for different architectures.

This module provides standardized interfaces for scaling models of different architectures, so that each architecture is handled through a consistent API and shared optimization strategies.

All implementations follow JAX/Flax NNX best practices and provide hardware-aware optimization for each model type through sharding, memory estimation, and batch-size selection.
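A minimal sketch of what such a standardized interface could look like, assuming an abstract base class whose method names are borrowed from the listing below. The signatures, the ModelSpecs fields (name, num_parameters, bytes_per_parameter), and the ToyAdapter subclass are all hypothetical; this page does not document the real API.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelSpecs:
    # Hypothetical fields; the real ModelSpecs class may differ.
    name: str
    num_parameters: int
    bytes_per_parameter: int = 4  # e.g. float32


class ModelAdapter(ABC):
    """Hypothetical common interface shared by all architecture adapters."""

    @abstractmethod
    def get_model_specs(self) -> ModelSpecs:
        """Describe the wrapped model."""

    def estimate_memory_usage(self) -> int:
        # Default estimate: parameter memory only, in bytes.
        specs = self.get_model_specs()
        return specs.num_parameters * specs.bytes_per_parameter


class ToyAdapter(ModelAdapter):
    """Hypothetical concrete adapter used only for illustration."""

    def __init__(self, num_parameters: int):
        self._specs = ModelSpecs("toy", num_parameters)

    def get_model_specs(self) -> ModelSpecs:
        return self._specs


adapter = ToyAdapter(num_parameters=1_000_000)
print(adapter.estimate_memory_usage())  # 4000000
```

Callers depend only on the base class, so code that estimates memory or picks a batch size does not need to know which architecture it is handling.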

Classes

DiffusionAdapter

class DiffusionAdapter

EnergyAdapter

class EnergyAdapter

ModelAdapter

class ModelAdapter

ModelSpecs

class ModelSpecs

TransformerAdapter

class TransformerAdapter

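Functions

Names that appear more than once, such as __init__, are methods defined on several of the classes above; the generated listing does not record which class each definition belongs to.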

__init__

def __init__()

__init__

def __init__()

__init__

def __init__()

__init__

def __init__()

apply_sharding

def apply_sharding()

apply_sharding

def apply_sharding()

apply_sharding

def apply_sharding()

apply_sharding

def apply_sharding()

create_diffusion_adapter

def create_diffusion_adapter()

create_energy_adapter

def create_energy_adapter()

create_transformer_adapter

def create_transformer_adapter()
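The three create_*_adapter functions suggest one factory per architecture. One way such factories could share a single dispatch table is sketched below. The registry, the generic create_adapter helper, the keyword-only configuration, and the stand-in adapter classes are hypothetical; they only reuse the class and function names listed on this page.

```python
from typing import Callable, Dict


class _Adapter:
    # Hypothetical stand-in base; it just records its configuration.
    def __init__(self, **config):
        self.config = config


class DiffusionAdapter(_Adapter):
    pass


class EnergyAdapter(_Adapter):
    pass


class TransformerAdapter(_Adapter):
    pass


# Hypothetical registry mapping an architecture name to its adapter class.
_REGISTRY: Dict[str, Callable[..., _Adapter]] = {
    "diffusion": DiffusionAdapter,
    "energy": EnergyAdapter,
    "transformer": TransformerAdapter,
}


def create_adapter(kind: str, **config) -> _Adapter:
    """Hypothetical generic factory; raises ValueError for unknown kinds."""
    try:
        factory = _REGISTRY[kind]
    except KeyError:
        raise ValueError(f"unknown architecture: {kind!r}") from None
    return factory(**config)


def create_transformer_adapter(**config) -> _Adapter:
    # Thin per-architecture wrapper over the shared factory.
    return create_adapter("transformer", **config)


adapter = create_transformer_adapter(num_layers=12)
print(type(adapter).__name__, adapter.config)  # TransformerAdapter {'num_layers': 12}
```

Keeping the per-architecture functions as thin wrappers means adding an architecture only requires a new registry entry.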

estimate_memory_usage

def estimate_memory_usage()

estimate_memory_usage

def estimate_memory_usage()

estimate_memory_usage

def estimate_memory_usage()

estimate_memory_usage

def estimate_memory_usage()
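The formula behind estimate_memory_usage is not documented here. A common back-of-the-envelope estimate for training counts parameters, gradients, and optimizer state; with Adam, which keeps two moment estimates per parameter, that is four copies of the parameters. The hypothetical function below implements that rule of thumb, not the module's own formula, and it ignores activation memory.

```python
def estimate_training_memory_bytes(
    num_parameters: int,
    bytes_per_value: int = 4,   # float32
    optimizer_states: int = 2,  # Adam: first and second moments
) -> int:
    """Hypothetical estimate: parameters + gradients + optimizer state.

    Activation memory is deliberately excluded; it depends on batch size,
    sequence length, and architecture.
    """
    copies = 1 + 1 + optimizer_states  # params, grads, optimizer moments
    return num_parameters * bytes_per_value * copies


# A 125M-parameter model in float32 trained with Adam:
total = estimate_training_memory_bytes(125_000_000)
print(f"{total / 2**30:.2f} GiB")  # 1.86 GiB
```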

get_model_specs

def get_model_specs()

get_model_specs

def get_model_specs()

get_model_specs

def get_model_specs()

get_model_specs

def get_model_specs()

get_optimal_batch_size

def get_optimal_batch_size()

get_optimal_batch_size

def get_optimal_batch_size()

get_optimal_batch_size

def get_optimal_batch_size()

get_optimal_batch_size

def get_optimal_batch_size()
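How get_optimal_batch_size chooses a value is not documented here. A hypothetical sketch: take the largest power-of-two batch whose activation memory, added to a fixed model footprint, still fits a device memory budget. The per-example cost, the budget, the cap, and the power-of-two policy are all assumptions for illustration.

```python
def largest_power_of_two_batch(
    memory_budget_bytes: int,
    fixed_bytes: int,
    bytes_per_example: int,
    max_batch_size: int = 4096,
) -> int:
    """Hypothetical: largest power-of-two batch that fits the budget.

    Returns 0 if even a single example does not fit.
    """
    available = memory_budget_bytes - fixed_bytes
    if available < bytes_per_example:
        return 0
    batch = 1
    # Double while the next size still fits and stays within the cap.
    while batch * 2 <= max_batch_size and batch * 2 * bytes_per_example <= available:
        batch *= 2
    return batch


# 16 GiB device, 2 GiB taken by the model, 50 MiB of activations per example.
print(largest_power_of_two_batch(16 * 2**30, 2 * 2**30, 50 * 2**20))  # 256
```

Powers of two are a common choice because they tend to map well onto accelerator tiling; a real implementation might instead search every size or leave headroom for memory fragmentation.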

get_performance_characteristics

def get_performance_characteristics()

Module Statistics

  • Classes: 5
  • Functions: 24
  • Imports: 6