Diffusion Models API Reference¤

Complete API documentation for all diffusion model classes in Workshop.

Base Classes¤

DiffusionModel¤

workshop.generative_models.models.diffusion.base.DiffusionModel ¤

DiffusionModel(config: DiffusionConfig, *, rngs: Rngs)

Bases: GenerativeModel

Base class for diffusion models.

This implements a general diffusion model that can support various diffusion processes like DDPM (Denoising Diffusion Probabilistic Models) and DDIM (Denoising Diffusion Implicit Models).

Uses the nested DiffusionConfig architecture with:

- backbone: BackboneConfig (polymorphic) for the denoising network
- noise_schedule: NoiseScheduleConfig for the diffusion schedule

Backbone type is determined by config.backbone.backbone_type discriminator. Supported backbones: UNet, DiT, U-ViT, UNet2DCondition.
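
For orientation, model construction might look like the following (a minimal sketch; only the backbone and noise_schedule fields of DiffusionConfig are documented here, so the inner constructor arguments are assumptions):

from flax import nnx

# Hypothetical nested-config construction; check the config classes for
# the real field names of UNetBackboneConfig / NoiseScheduleConfig
config = DiffusionConfig(
    backbone=UNetBackboneConfig(),        # backbone_type discriminator selects UNet
    noise_schedule=NoiseScheduleConfig(), # e.g. a linear schedule
)
model = DiffusionModel(config, rngs=nnx.Rngs(0))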

Attributes:

Name Type Description
config - DiffusionConfig for the model
backbone - The denoising network (created from config.backbone)
noise_schedule NoiseSchedule NoiseSchedule instance for the diffusion process

Parameters:

Name Type Description Default
config DiffusionConfig DiffusionConfig with nested backbone and noise_schedule configs. The backbone field accepts any BackboneConfig type (UNetBackboneConfig, DiTBackboneConfig, etc.) and the appropriate backbone is created based on the backbone_type discriminator. required
rngs Rngs Random number generators required

config instance-attribute ¤

config = config

backbone instance-attribute ¤

backbone = create_backbone(backbone, rngs=rngs)

noise_schedule instance-attribute ¤

noise_schedule: NoiseSchedule = create_noise_schedule(noise_schedule)

betas instance-attribute ¤

betas = betas

alphas instance-attribute ¤

alphas = alphas

alphas_cumprod instance-attribute ¤

alphas_cumprod = alphas_cumprod

alphas_cumprod_prev instance-attribute ¤

alphas_cumprod_prev = alphas_cumprod_prev

sqrt_alphas_cumprod instance-attribute ¤

sqrt_alphas_cumprod = sqrt_alphas_cumprod

sqrt_one_minus_alphas_cumprod instance-attribute ¤

sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod

log_one_minus_alphas_cumprod instance-attribute ¤

log_one_minus_alphas_cumprod = log_one_minus_alphas_cumprod

sqrt_recip_alphas_cumprod instance-attribute ¤

sqrt_recip_alphas_cumprod = sqrt_recip_alphas_cumprod

sqrt_recipm1_alphas_cumprod instance-attribute ¤

sqrt_recipm1_alphas_cumprod = sqrt_recipm1_alphas_cumprod

posterior_variance instance-attribute ¤

posterior_variance = posterior_variance

posterior_log_variance_clipped instance-attribute ¤

posterior_log_variance_clipped = posterior_log_variance_clipped

posterior_mean_coef1 instance-attribute ¤

posterior_mean_coef1 = posterior_mean_coef1

posterior_mean_coef2 instance-attribute ¤

posterior_mean_coef2 = posterior_mean_coef2

q_sample ¤

q_sample(x_start: Array, t: Array, noise: Array | None = None) -> Array

Sample from the forward diffusion process q(x_t | x_0).

Parameters:

Name Type Description Default
x_start Array Starting clean data (x_0) required
t Array Timesteps required
noise Array | None Optional pre-generated noise None

Returns:

Type Description
Array Noisy samples x_t

predict_start_from_noise ¤

predict_start_from_noise(x_t: Array, t: Array, noise: Array) -> Array

Predict x_0 from noise model output.

Parameters:

Name Type Description Default
x_t Array Noisy input at timestep t required
t Array Timesteps required
noise Array Predicted noise required

Returns:

Type Description
Array Predicted x_0

q_posterior_mean_variance ¤

q_posterior_mean_variance(x_start: Array, x_t: Array, t: Array) -> tuple[Array, Array, Array]

Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0).

Parameters:

Name Type Description Default
x_start Array Clean data (x_0) required
x_t Array Noisy data (x_t) required
t Array Timesteps required

Returns:

Type Description
tuple[Array, Array, Array] Tuple of (posterior_mean, posterior_variance, posterior_log_variance_clipped)
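
For reference, the coefficients stored on the model implement the standard DDPM posterior (Ho et al., 2020):

\[ \tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t, \quad \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t \]

with posterior_mean_coef1 and posterior_mean_coef2 holding the coefficients on \(x_0\) and \(x_t\) respectively.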

p_mean_variance ¤

p_mean_variance(model_output: Array, x_t: Array, t: Array, clip_denoised: bool = True) -> dict[str, Array]

Compute the model's predicted mean and variance for x_{t-1}.

Parameters:

Name Type Description Default
model_output Array Predicted noise or x_0 required
x_t Array Noisy input at timestep t required
t Array Timesteps required
clip_denoised bool Whether to clip the denoised signal to [-1, 1] True

Returns:

Type Description
dict[str, Array] Dictionary with the predicted mean and variance

p_sample ¤

p_sample(model_output: Array, x_t: Array, t: Array, clip_denoised: bool = True) -> Array

Sample from the denoising process p(x_{t-1} | x_t).

Parameters:

Name Type Description Default
model_output Array Predicted noise required
x_t Array Noisy input at timestep t required
t Array Timesteps required
clip_denoised bool Whether to clip the denoised signal to [-1, 1] True

Returns:

Type Description
Array Denoised x_{t-1}

generate ¤

generate(n_samples: int = 1, *, shape: tuple[int, ...] | None = None, clip_denoised: bool = True) -> Array

Generate samples from random noise.

Parameters:

Name Type Description Default
n_samples int Number of samples to generate 1
shape tuple[int, ...] | None Shape of samples to generate (excluding batch dimension) None
clip_denoised bool Whether to clip the denoised signal to [-1, 1] True

Returns:

Type Description
Array Generated samples

loss_fn ¤

loss_fn(batch: Any, model_outputs: dict[str, Any]) -> dict[str, Any]

Compute loss for training.

Parameters:

Name Type Description Default
batch Any Input batch (should contain 'x' key with data) required
model_outputs dict[str, Any] Model outputs from forward pass required

Returns:

Type Description
dict[str, Any] Dictionary containing loss and metrics

Base class for all diffusion models, implementing the core diffusion process.

Purpose: Provides the foundational diffusion framework including forward diffusion (adding noise), reverse diffusion (denoising), and noise scheduling.

Initialization¤

DiffusionModel(
    config: ModelConfiguration,
    backbone_fn: Callable,
    *,
    rngs: nnx.Rngs
)

Parameters:

Parameter Type Description
config ModelConfiguration Model configuration with input dimensions and parameters
backbone_fn Callable Function that creates the backbone network (e.g., U-Net)
rngs nnx.Rngs Random number generators for initialization

Configuration Parameters:

Parameter Type Default Description
noise_steps int 1000 Number of diffusion timesteps
beta_start float 1e-4 Initial noise variance
beta_end float 0.02 Final noise variance

Methods¤

__call__(x, timesteps, *, rngs=None, training=False, **kwargs)¤

Forward pass through the diffusion model.

Parameters:

  • x (jax.Array): Input data (batch, *input_dim)
  • timesteps (jax.Array): Timestep indices (batch,)
  • rngs (nnx.Rngs | None): Random number generators
  • training (bool): Whether in training mode
  • **kwargs: Additional arguments passed to backbone

Returns:

  • dict[str, Any]: Dictionary containing "predicted_noise" and potentially other outputs

Example:

# Create model
model = DiffusionModel(config, backbone_fn, rngs=rngs)

# Forward pass
x = jax.random.normal(rngs.sample(), (4, 32, 32, 3))
t = jnp.array([100, 200, 300, 400])
outputs = model(x, t, rngs=rngs, training=True)

print(outputs["predicted_noise"].shape)  # (4, 32, 32, 3)
setup_noise_schedule()¤

Set up the noise schedule for the diffusion process.

Description:

Computes the noise schedule (betas) and derived quantities (alphas, alphas_cumprod, etc.) used throughout the diffusion process. The default implementation uses a linear schedule; a minimal sketch follows the attribute list below.

Computed Attributes:

  • betas: Noise variances at each timestep
  • alphas: \(\alpha_t = 1 - \beta_t\)
  • alphas_cumprod: \(\bar{\alpha}_t = \prod_{i=1}^t \alpha_i\)
  • sqrt_alphas_cumprod: \(\sqrt{\bar{\alpha}_t}\)
  • sqrt_one_minus_alphas_cumprod: \(\sqrt{1 - \bar{\alpha}_t}\)
  • posterior_variance: Variance of \(q(x_{t-1} | x_t, x_0)\)
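
A minimal sketch of the default linear schedule (assuming noise_steps=1000, beta_start=1e-4, beta_end=0.02 from the table above):

import jax.numpy as jnp

betas = jnp.linspace(1e-4, 0.02, 1000)  # linear beta schedule
alphas = 1.0 - betas
alphas_cumprod = jnp.cumprod(alphas)

sqrt_alphas_cumprod = jnp.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = jnp.sqrt(1.0 - alphas_cumprod)

# Posterior variance of q(x_{t-1} | x_t, x_0)
alphas_cumprod_prev = jnp.append(1.0, alphas_cumprod[:-1])
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)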
q_sample(x_start, t, noise=None, *, rngs=None)¤

Sample from the forward diffusion process \(q(x_t | x_0)\).

Parameters:

  • x_start (jax.Array): Clean data \(x_0\)
  • t (jax.Array): Timesteps (batch,)
  • noise (jax.Array | None): Optional pre-generated noise
  • rngs (nnx.Rngs | None): Random number generators

Returns:

  • jax.Array: Noisy samples \(x_t\)

Mathematical Formula:

\[ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]

Example:

# Add noise to clean images
x_clean = jax.random.normal(rngs.sample(), (8, 32, 32, 3))
t = jnp.array([100] * 8)

x_noisy = model.q_sample(x_clean, t, rngs=rngs)
print(f"Clean: {x_clean.shape}, Noisy: {x_noisy.shape}")
p_sample(model_output, x_t, t, *, rngs=None, clip_denoised=True)¤

Sample from the denoising process \(p(x_{t-1} | x_t)\).

Parameters:

  • model_output (jax.Array): Predicted noise from model
  • x_t (jax.Array): Noisy input at timestep \(t\)
  • t (jax.Array): Current timesteps
  • rngs (nnx.Rngs | None): Random number generators
  • clip_denoised (bool): Whether to clip to [-1, 1]

Returns:

  • jax.Array: Denoised sample \(x_{t-1}\)

Example:

# Single denoising step
x_t = noisy_sample
t = jnp.array([500])

# Get model prediction
outputs = model(x_t, t, rngs=rngs)
predicted_noise = outputs["predicted_noise"]

# Denoise one step
x_t_minus_1 = model.p_sample(predicted_noise, x_t, t, rngs=rngs)
generate(n_samples=1, *, rngs=None, shape=None, clip_denoised=True, **kwargs)¤

Generate samples from random noise.

Parameters:

  • n_samples (int): Number of samples to generate
  • rngs (nnx.Rngs | None): Random number generators
  • shape (tuple[int, ...] | None): Sample shape (excluding batch)
  • clip_denoised (bool): Whether to clip to [-1, 1]
  • **kwargs: Additional model arguments

Returns:

  • jax.Array: Generated samples (n_samples, *shape)

Example:

# Generate 16 samples
samples = model.generate(n_samples=16, rngs=rngs)
print(f"Generated: {samples.shape}")  # (16, 32, 32, 3)
predict_start_from_noise(x_t, t, noise)¤

Predict \(x_0\) from \(x_t\) and predicted noise.

Parameters:

  • x_t (jax.Array): Noisy sample at timestep \(t\)
  • t (jax.Array): Timesteps
  • noise (jax.Array): Predicted noise

Returns:

  • jax.Array: Predicted \(x_0\)

Mathematical Formula:

\[ x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} x_t - \frac{\sqrt{1 - \bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t}} \epsilon \]
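
Equivalently, using the precomputed buffers listed on the base class (a sketch, not the library source):

# x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
c1 = model.sqrt_recip_alphas_cumprod[t].reshape(-1, 1, 1, 1)
c2 = model.sqrt_recipm1_alphas_cumprod[t].reshape(-1, 1, 1, 1)
x0_pred = c1 * x_t - c2 * noise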
loss_fn(batch, model_outputs, *, rngs=None, **kwargs)¤

Compute the diffusion loss.

Parameters:

  • batch (Any): Input batch (dict with 'x' key or array)
  • model_outputs (dict[str, Any]): Model predictions
  • rngs (nnx.Rngs | None): Random number generators
  • **kwargs: Additional arguments

Returns:

  • dict[str, Any]: Dictionary with 'loss' and metrics

Example:

# Training loop
@nnx.jit
def train_step(model, optimizer, batch, rngs):
    def loss_fn_wrapper(model):
        # Add noise
        t = jax.random.randint(rngs.timestep(), (batch.shape[0],), 0, 1000)
        noise = jax.random.normal(rngs.noise(), batch.shape)
        x_noisy = model.q_sample(batch, t, noise, rngs=rngs)

        # Predict
        outputs = model(x_noisy, t, training=True, rngs=rngs)

        # Compute loss
        loss_dict = model.loss_fn({"x": batch}, outputs, rngs=rngs)
        return loss_dict["loss"]

    loss, grads = nnx.value_and_grad(loss_fn_wrapper)(model)
    optimizer.update(grads)
    return {"loss": loss}

DDPM (Denoising Diffusion Probabilistic Models)¤

DDPMModel¤

workshop.generative_models.models.diffusion.ddpm.DDPMModel ¤

DDPMModel(config: DDPMConfig, *, rngs: Rngs)

Bases: DiffusionModel

DDPM (Denoising Diffusion Probabilistic Models) implementation.

This model implements the denoising diffusion probabilistic model as described in the DDPM paper by Ho et al.

Uses nested DDPMConfig with:

- backbone: BackboneConfig (polymorphic) for the denoising network
- noise_schedule: NoiseScheduleConfig for the diffusion schedule
- loss_type: Loss function type (mse, l1, huber)
- clip_denoised: Whether to clip denoised samples

Backbone type is determined by config.backbone.backbone_type discriminator.

Parameters:

Name Type Description Default
config DDPMConfig DDPMConfig with nested backbone and noise_schedule configs. The backbone field accepts any BackboneConfig type and the appropriate backbone is created based on backbone_type. required
rngs Rngs Random number generators required

loss_type instance-attribute ¤

loss_type = loss_type

clip_denoised instance-attribute ¤

clip_denoised = clip_denoised

input_dim instance-attribute ¤

input_dim = input_shape

in_channels instance-attribute ¤

in_channels = input_shape[0]

noise_steps instance-attribute ¤

noise_steps = num_timesteps

beta_start instance-attribute ¤

beta_start = beta_start

beta_end instance-attribute ¤

beta_end = beta_end

beta_schedule instance-attribute ¤

beta_schedule = schedule_type

forward_diffusion ¤

forward_diffusion(x: Array, t: Array) -> tuple[Array, Array]

Forward diffusion process q(x_t | x_0).

Parameters:

Name Type Description Default
x Array Input data tensor required
t Array Timestep indices required

Returns:

Type Description
tuple[Array, Array] Tuple of (noisy_x, noise)

denoise_step ¤

denoise_step(x_t: Array, t: Array, predicted_noise: Array, clip_denoised: bool = True) -> Array

Perform a single denoising step: x_{t-1} = f(x_t, t, noise).

Parameters:

Name Type Description Default
x_t Array Noisy input at timestep t required
t Array Current timestep indices required
predicted_noise Array Predicted noise from the model required
clip_denoised bool Whether to clip values to [-1, 1] True

Returns:

Type Description
Array Denoised x_{t-1}

sample ¤

sample(n_samples_or_shape: int | tuple[int, ...], scheduler: str = 'ddpm', steps: int | None = None) -> Array

Sample from the diffusion model.

Parameters:

Name Type Description Default
n_samples_or_shape int | tuple[int, ...] Number of samples or full shape including batch dimension required
scheduler str Sampling scheduler to use ('ddpm', 'ddim') 'ddpm'
steps int | None Number of sampling steps (if None, use default) None

Returns:

Type Description
Array Generated samples

Standard DDPM implementation with support for both DDPM and DDIM sampling.

Purpose: Implements the foundational denoising diffusion probabilistic model with standard training and sampling procedures.

Initialization¤

DDPMModel(
    config: ModelConfiguration,
    *,
    rngs: nnx.Rngs,
    **kwargs
)

Parameters:

Parameter Type Description
config ModelConfiguration Model configuration
rngs nnx.Rngs Random number generators
**kwargs dict Additional arguments (e.g., backbone_fn)

Configuration Parameters:

Parameter Type Default Description
noise_steps int 1000 Number of diffusion timesteps
beta_start float 1e-4 Initial noise variance
beta_end float 0.02 Final noise variance
beta_schedule str "linear" Schedule type ("linear" or "cosine")

Methods¤

forward_diffusion(x, t, *, rngs=None)¤

Forward diffusion process \(q(x_t | x_0)\).

Parameters:

  • x (jax.Array): Clean input data
  • t (jax.Array): Timestep indices
  • rngs (nnx.Rngs | None): Random number generators

Returns:

  • tuple[jax.Array, jax.Array]: (noisy_x, noise) tuple

Example:

model = DDPMModel(config, rngs=rngs)

x_clean = jnp.ones((4, 32, 32, 3))
t = jnp.array([100, 200, 300, 400])

x_noisy, noise = model.forward_diffusion(x_clean, t, rngs=rngs)
denoise_step(x_t, t, predicted_noise, clip_denoised=True)¤

Single denoising step.

Parameters:

  • x_t (jax.Array): Noisy input at timestep \(t\)
  • t (jax.Array): Current timesteps
  • predicted_noise (jax.Array): Predicted noise
  • clip_denoised (bool): Whether to clip to [-1, 1]

Returns:

  • jax.Array: Denoised \(x_{t-1}\)
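
Example (a sketch of the full reverse process built from denoise_step; assumes the forward pass returns a dict with "predicted_noise", as documented above):

# Start from pure noise and denoise step by step
x = jax.random.normal(rngs.sample(), (4, 32, 32, 3))
for i in reversed(range(model.noise_steps)):
    t = jnp.full((4,), i)
    predicted_noise = model(x, t, rngs=rngs)["predicted_noise"]
    x = model.denoise_step(x, t, predicted_noise, clip_denoised=True)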
sample(n_samples_or_shape, scheduler="ddpm", steps=None, *, rngs=None, **kwargs)¤

Sample from the diffusion model.

Parameters:

  • n_samples_or_shape (int | tuple): Number of samples or full shape
  • scheduler (str): Sampling scheduler ("ddpm" or "ddim")
  • steps (int | None): Number of sampling steps
  • rngs (nnx.Rngs | None): Random number generators
  • **kwargs: Additional arguments

Returns:

  • jax.Array: Generated samples

Example:

# DDPM sampling (slow but high quality)
samples_ddpm = model.sample(16, scheduler="ddpm", rngs=rngs)

# DDIM sampling (fast)
samples_ddim = model.sample(16, scheduler="ddim", steps=50, rngs=rngs)

DDIM (Denoising Diffusion Implicit Models)¤

DDIMModel¤

workshop.generative_models.models.diffusion.ddim.DDIMModel ¤

DDIMModel(config: DDIMConfig, *, rngs: Rngs)

Bases: DDPMModel

DDIM (Denoising Diffusion Implicit Models) implementation.

This model implements deterministic sampling from diffusion models as described in the DDIM paper by Song et al. DDIM enables faster sampling with fewer steps while maintaining high quality.

Uses nested DDIMConfig with:

- backbone: BackboneConfig (polymorphic) for the denoising network
- noise_schedule: NoiseScheduleConfig for the diffusion schedule
- eta: Stochasticity parameter (0 = deterministic, 1 = DDPM)
- num_inference_steps: Number of sampling steps
- skip_type: Timestep skip strategy

Parameters:

Name Type Description Default
config DDIMConfig DDIMConfig with nested backbone and noise_schedule configs. The backbone field accepts any BackboneConfig type and the appropriate backbone is created based on backbone_type. required
rngs Rngs Random number generators required

eta instance-attribute ¤

eta = eta

ddim_steps instance-attribute ¤

ddim_steps = num_inference_steps

skip_type instance-attribute ¤

skip_type = skip_type

get_ddim_timesteps ¤

get_ddim_timesteps(ddim_steps: int) -> Array

Get timesteps for DDIM sampling.

Parameters:

Name Type Description Default
ddim_steps int Number of DDIM sampling steps required

Returns:

Type Description
Array Array of timesteps for DDIM sampling

ddim_step ¤

ddim_step(x_t: Array, t: Array, t_prev: Array, predicted_noise: Array, eta: float | None = None) -> Array

Perform a single DDIM sampling step.

Parameters:

Name Type Description Default
x_t Array Current sample at timestep t required
t Array Current timestep required
t_prev Array Previous timestep required
predicted_noise Array Predicted noise from the model required
eta float | None DDIM interpolation parameter (0 = deterministic, 1 = DDPM) None

Returns:

Type Description
Array Sample at timestep t_prev

ddim_sample ¤

ddim_sample(n_samples: int, steps: int | None = None, eta: float | None = None) -> Array

Generate samples using DDIM.

Parameters:

Name Type Description Default
n_samples int Number of samples to generate required
steps int | None Number of DDIM steps (default: self.ddim_steps) None
eta float | None DDIM interpolation parameter None

Returns:

Type Description
Array Generated samples

sample ¤

sample(n_samples: int, scheduler: str = 'ddim', steps: int | None = None) -> Array

Sample from the model.

Parameters:

Name Type Description Default
n_samples int Number of samples to generate required
scheduler str Sampling scheduler ("ddim" or "ddpm") 'ddim'
steps int | None Number of sampling steps None

Returns:

Type Description
Array Generated samples

ddim_reverse ¤

ddim_reverse(x0: Array, ddim_steps: int) -> Array

DDIM reverse process (encoding) from x_0 to noise.

This is useful for image editing applications where you want to encode a real image into the noise space and then decode it.

Parameters:

Name Type Description Default
x0 Array Clean image to encode required
ddim_steps int Number of DDIM steps required

Returns:

Type Description
Array Encoded noise

DDIM implementation with deterministic sampling and fast inference.

Purpose: Enables much faster sampling (10-20x) than DDPM while maintaining quality, and supports deterministic generation for image editing.

Initialization¤

DDIMModel(
    config: ModelConfiguration,
    *,
    rngs: nnx.Rngs
)

Configuration Parameters:

Parameter Type Default Description
eta float 0.0 Stochasticity (0=deterministic, 1=DDPM)
ddim_steps int 50 Number of sampling steps
skip_type str "uniform" Timestep selection ("uniform" or "quadratic")
noise_steps int 1000 Training timesteps

Methods¤

get_ddim_timesteps(ddim_steps)¤

Get timesteps for DDIM sampling.

Parameters:

  • ddim_steps (int): Number of sampling steps

Returns:

  • jax.Array: Timestep indices for DDIM

Example:

model = DDIMModel(config, rngs=rngs)

# Get 50 uniformly spaced timesteps
timesteps = model.get_ddim_timesteps(50)
print(timesteps)  # 50 uniformly spaced timesteps, descending toward 0
ddim_step(x_t, t, t_prev, predicted_noise, eta=None, *, rngs=None)¤

Single DDIM sampling step.

Parameters:

  • x_t (jax.Array): Current sample at timestep \(t\)
  • t (jax.Array): Current timestep
  • t_prev (jax.Array): Previous timestep
  • predicted_noise (jax.Array): Predicted noise
  • eta (float | None): DDIM parameter (0=deterministic)
  • rngs (nnx.Rngs | None): Random number generators

Returns:

  • jax.Array: Sample at timestep \(t_{prev}\)

Mathematical Formula:

\[ x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \epsilon_\theta(x_t, t) + \sigma_t \epsilon \]

Where \(\hat{x}_0\) is the predicted clean sample and \(\sigma_t = \eta \sqrt{(1-\bar{\alpha}_{t-1})/(1-\bar{\alpha}_t)}\sqrt{1-\bar{\alpha}_t/\bar{\alpha}_{t-1}}\)
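
Example (a sketch of the hand-rolled loop that ddim_sample presumably wraps; assumes the forward pass returns "predicted_noise"):

timesteps = model.get_ddim_timesteps(50)
x = jax.random.normal(rngs.sample(), (8, 32, 32, 3))
for i in range(len(timesteps) - 1):
    t = jnp.full((8,), timesteps[i])
    t_prev = jnp.full((8,), timesteps[i + 1])
    predicted_noise = model(x, t, rngs=rngs)["predicted_noise"]
    x = model.ddim_step(x, t, t_prev, predicted_noise, eta=0.0)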

ddim_sample(n_samples, steps=None, eta=None, *, rngs=None, **kwargs)¤

Generate samples using DDIM.

Parameters:

  • n_samples (int): Number of samples
  • steps (int | None): Number of DDIM steps
  • eta (float | None): Stochasticity parameter
  • rngs (nnx.Rngs | None): Random number generators
  • **kwargs: Additional arguments

Returns:

  • jax.Array: Generated samples

Example:

# Deterministic generation (50 steps)
samples = model.ddim_sample(16, steps=50, eta=0.0, rngs=rngs)

# Stochastic generation (more diversity)
samples = model.ddim_sample(16, steps=50, eta=0.5, rngs=rngs)
ddim_reverse(x0, ddim_steps, *, rngs=None, **kwargs)¤

DDIM reverse process (encoding) from \(x_0\) to noise.

Purpose: Encode a real image into the noise space for image editing.

Parameters:

  • x0 (jax.Array): Clean image to encode
  • ddim_steps (int): Number of reverse steps
  • rngs (nnx.Rngs | None): Random number generators
  • **kwargs: Additional arguments

Returns:

  • jax.Array: Encoded noise

Example:

# Encode real image to noise
real_image = load_image("photo.png")
noise_code = model.ddim_reverse(real_image, ddim_steps=50, rngs=rngs)

# Edit in noise space
edited_noise = noise_code + modification

# Decode back to image by running the DDIM updates manually from the
# edited noise (ddim_sample starts from fresh noise, so it cannot be
# seeded with edited_noise directly)
timesteps = model.get_ddim_timesteps(50)
x = edited_noise
for i in range(len(timesteps) - 1):
    t = jnp.full((x.shape[0],), timesteps[i])
    t_prev = jnp.full((x.shape[0],), timesteps[i + 1])
    predicted_noise = model(x, t, rngs=rngs)["predicted_noise"]
    x = model.ddim_step(x, t, t_prev, predicted_noise, eta=0.0)
edited_image = x

Score-Based Models¤

ScoreDiffusionModel¤

workshop.generative_models.models.diffusion.score.ScoreDiffusionModel ¤

ScoreDiffusionModel(config: ScoreDiffusionConfig, *, rngs: Rngs)

Bases: DiffusionModel

Score-based diffusion model.

This model is based on score matching principles where the model predicts the score (gradient of log-likelihood) instead of noise directly.

Uses nested ScoreDiffusionConfig with:

- backbone: BackboneConfig (polymorphic) for the denoising network
- noise_schedule: NoiseScheduleConfig for the diffusion schedule
- sigma_min: Minimum noise level
- sigma_max: Maximum noise level
- score_scaling: Score function scaling factor

Backbone type is determined by config.backbone.backbone_type discriminator.

Parameters:

Name Type Description Default
config ScoreDiffusionConfig ScoreDiffusionConfig with nested backbone and noise_schedule configs. The backbone field accepts any BackboneConfig type and the appropriate backbone is created based on backbone_type. required
rngs Rngs Random number generators for initialization required

sigma_min instance-attribute ¤

sigma_min = sigma_min

sigma_max instance-attribute ¤

sigma_max = sigma_max

score_scaling instance-attribute ¤

score_scaling = score_scaling

input_dim instance-attribute ¤

input_dim = input_shape

score ¤

score(x: Array, t: Array) -> Array

Compute the score function.

Parameters:

Name Type Description Default
x Array Input samples required
t Array Time steps required

Returns:

Type Description
Array Score values

loss ¤

loss(x: Array, *, rngs: Rngs | None = None) -> Array

Compute the score matching loss.

Parameters:

Name Type Description Default
x Array Input samples required
rngs Rngs | None Random number generators None

Returns:

Type Description
Array Loss value

sample ¤

sample(num_samples: int, *, rngs: Rngs | None = None, num_steps: int = 1000, return_trajectory: bool = False) -> Array | list[Array]

Generate samples using the reverse SDE.

Parameters:

Name Type Description Default
num_samples int Number of samples to generate required
rngs Rngs | None Random number generators None
num_steps int Number of integration steps 1000
return_trajectory bool If True, return the full trajectory False

Returns:

Type Description
Array | list[Array] Generated samples or trajectory

denoise ¤

denoise(x: Array, t: Array) -> Array

Predict denoised output.

Parameters:

Name Type Description Default
x Array Noisy input required
t Array Time steps required

Returns:

Type Description
Array Denoised output

Score-based diffusion model using score matching.

Purpose: Implements score-based generative modeling where the model predicts the score function (gradient of log-likelihood) instead of noise directly.

Initialization¤

ScoreDiffusionModel(
    *,
    config: ModelConfiguration,
    rngs: nnx.Rngs,
    **kwargs
)

Configuration Parameters:

Parameter Type Default Description
sigma_min float 0.01 Minimum noise level
sigma_max float 1.0 Maximum noise level
score_scaling float 1.0 Score scaling factor

Methods¤

score(x, t)¤

Compute the score function \(\nabla_x \log p_t(x)\).

Parameters:

  • x (jax.Array): Input samples
  • t (jax.Array): Time steps in [0, 1]

Returns:

  • jax.Array: Score values

Mathematical Formula:

\[ \nabla_x \log p_t(x) = -\frac{\epsilon}{\sigma_t} \]
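
In code, the conversion from predicted noise to a score is a one-liner (a sketch; sigma_t stands for the noise level at time t, per the formula above):

def score_from_noise(eps_pred, sigma_t):
    # Score of the Gaussian-perturbed density: -eps / sigma_t
    return -eps_pred / sigma_t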
sample(num_samples, *, rngs=None, num_steps=1000, return_trajectory=False)¤

Generate samples using reverse SDE.

Parameters:

  • num_samples (int): Number of samples
  • rngs (nnx.Rngs | None): Random number generators
  • num_steps (int): Number of integration steps
  • return_trajectory (bool): Return full trajectory

Returns:

  • jax.Array | list[jax.Array]: Samples or trajectory

Example:

model = ScoreDiffusionModel(config=config, rngs=rngs)

# Generate samples
samples = model.sample(16, num_steps=1000, rngs=rngs)

# Get full trajectory
trajectory = model.sample(4, num_steps=1000, return_trajectory=True, rngs=rngs)
print(f"Trajectory length: {len(trajectory)}")  # 1000 steps

Latent Diffusion Models¤

LDMModel¤

workshop.generative_models.models.diffusion.latent.LDMModel ¤

LDMModel(config: LatentDiffusionConfig, *, rngs: Rngs)

Bases: DDPMModel

Latent Diffusion Model implementation.

This model combines a VAE for encoding/decoding with a diffusion model that operates in the latent space.

Uses nested LatentDiffusionConfig with:

- backbone: BackboneConfig (polymorphic) for the denoising network
- noise_schedule: NoiseScheduleConfig for the diffusion schedule
- encoder: EncoderConfig for encoding to latent space
- decoder: DecoderConfig for decoding from latent space
- latent_scale_factor: Scaling factor for latent codes

Parameters:

Name Type Description Default
config LatentDiffusionConfig LatentDiffusionConfig with nested configs for backbone, noise_schedule, encoder, and decoder. The backbone field accepts any BackboneConfig type and the appropriate backbone is created based on backbone_type. required
rngs Rngs Random number generators required

original_input_dim instance-attribute ¤

original_input_dim = input_shape

latent_dim instance-attribute ¤

latent_dim = latent_dim

scale_factor instance-attribute ¤

scale_factor = latent_scale_factor

encoder instance-attribute ¤

encoder = MLPEncoder(config=encoder, rngs=rngs)

decoder instance-attribute ¤

decoder = MLPDecoder(config=decoder, rngs=rngs)

use_pretrained_vae instance-attribute ¤

use_pretrained_vae = False

input_dim instance-attribute ¤

input_dim = (latent_dim,)

output_dim property ¤

output_dim

Get output dimensions.

encode ¤

encode(x: Array) -> tuple[Array, Array]

Encode input to latent space.

Parameters:

Name Type Description Default
x Array Input tensor required

Returns:

Type Description
tuple[Array, Array] Tuple of (latent_code, posterior_params)

decode ¤

decode(z: Array) -> Array

Decode latent code to output space.

Parameters:

Name Type Description Default
z Array Latent code required

Returns:

Type Description
Array Decoded output

reparameterize ¤

reparameterize(mean: Array, logvar: Array, *, rngs: Rngs) -> Array

Reparameterization trick.

Parameters:

Name Type Description Default
mean Array Mean of the latent distribution required
logvar Array Log variance of the latent distribution required
rngs Rngs Random number generators required

Returns:

Type Description
Array Sampled latent code

denoise ¤

denoise(x: Array, t: Array, **kwargs) -> Array

Predict noise from noisy input using the backbone.

Parameters:

Name Type Description Default
x Array Noisy input required
t Array Timestep indices required
**kwargs - Additional arguments for the backbone (e.g., conditioning) {}

Returns:

Type Description
Array Predicted noise

sample ¤

sample(num_samples: int, *, rngs: Rngs | None = None, return_trajectory: bool = False) -> Array | list[Array]

Sample from the model.

Parameters:

Name Type Description Default
num_samples int Number of samples to generate required
rngs Rngs | None Random number generators None
return_trajectory bool If True, return the full trajectory False

Returns:

Type Description
Array | list[Array] Generated samples or trajectory

loss ¤

loss(x: Array, t: Array | None = None, *, rngs: Rngs | None = None) -> Array

Compute LDM loss.

Parameters:

Name Type Description Default
x Array Input images required
t Array | None Timesteps (optional) None
rngs Rngs | None Random number generators None

Returns:

Type Description
Array Loss value

Latent Diffusion Model combining VAE and diffusion in latent space.

Purpose: Applies diffusion in a compressed latent space for efficient high-resolution generation. Foundation of Stable Diffusion.

Initialization¤

LDMModel(
    *,
    config: ModelConfiguration,
    rngs: nnx.Rngs,
    **kwargs
)

Configuration Parameters:

Parameter Type Default Description
latent_dim int 8 Latent space dimension
encoder_hidden_dims list[int] [32, 64] Encoder layer sizes
decoder_hidden_dims list[int] [64, 32] Decoder layer sizes
encoder_type str "simple" Encoder type ("simple" or "vae")
decoder_type str "simple" Decoder type
scale_factor float 0.18215 Latent scaling factor

Methods¤

encode(x)¤

Encode input to latent space.

Parameters:

  • x (jax.Array): Input images

Returns:

  • tuple[jax.Array, jax.Array]: (mean, logvar) of latent distribution

Example:

model = LDMModel(config=config, rngs=rngs)

# Encode images to latent space
images = jax.random.normal(rngs.sample(), (8, 64, 64, 3))
mean, logvar = model.encode(images)

print(f"Latent mean: {mean.shape}")      # (8, 16)
print(f"Latent logvar: {logvar.shape}")  # (8, 16)
decode(z)¤

Decode latent code to image space.

Parameters:

  • z (jax.Array): Latent codes

Returns:

  • jax.Array: Decoded images

Example:

# Sample latent code
z = jax.random.normal(rngs.sample(), (8, 16))

# Decode to images
images = model.decode(z)
print(f"Decoded images: {images.shape}")  # (8, 64, 64, 3)
reparameterize(mean, logvar, *, rngs)¤

Reparameterization trick for sampling.

Parameters:

  • mean (jax.Array): Mean of latent distribution
  • logvar (jax.Array): Log variance of latent distribution
  • rngs (nnx.Rngs): Random number generators

Returns:

  • jax.Array: Sampled latent code

Mathematical Formula:

\[ z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]
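
A standalone sketch of the trick (a hypothetical helper taking an explicit PRNG key; the model method draws its key from rngs instead):

import jax
import jax.numpy as jnp

def reparameterize(mean, logvar, key):
    std = jnp.exp(0.5 * logvar)               # sigma from log-variance
    eps = jax.random.normal(key, mean.shape)  # eps ~ N(0, I)
    return mean + std * eps                   # z = mu + sigma * eps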
sample(num_samples, *, rngs=None, return_trajectory=False)¤

Generate samples (automatically encoded/decoded).

Parameters:

  • num_samples (int): Number of samples
  • rngs (nnx.Rngs | None): Random number generators
  • return_trajectory (bool): Return full trajectory

Returns:

  • jax.Array | list[jax.Array]: Generated images

Example:

# Generate high-resolution images efficiently
samples = model.sample(16, rngs=rngs)
print(f"Generated: {samples.shape}")  # (16, 64, 64, 3)

# Diffusion runs in the compressed 16D latent space, which is far
# cheaper per step than pixel-space diffusion

Diffusion Transformers¤

DiTModel¤

workshop.generative_models.models.diffusion.dit.DiTModel ¤

DiTModel(config: DiTConfig, backbone_fn: Callable[[DiTConfig, Rngs], Module] | None = None, *, rngs: Rngs)

Bases: GenerativeModel

Diffusion model using Transformer backbone instead of U-Net.

Implements Diffusion Transformers (DiT) which replace the U-Net backbone with a Vision Transformer for improved scalability and performance.

Uses nested DiTConfig with:

- noise_schedule: NoiseScheduleConfig for the diffusion schedule
- patch_size, hidden_size, depth, num_heads, mlp_ratio: Transformer architecture
- num_classes: Number of classes for conditional generation
- cfg_scale: Classifier-free guidance scale

Parameters:

Name Type Description Default
config DiTConfig DiTConfig with DiT-specific parameters and noise_schedule required
rngs Rngs Random number generators required
backbone_fn Callable[[DiTConfig, Rngs], Module] | None Optional custom backbone function None

config instance-attribute ¤

config = config

num_classes instance-attribute ¤

num_classes = num_classes

cfg_scale instance-attribute ¤

cfg_scale = cfg_scale

learn_sigma instance-attribute ¤

learn_sigma = learn_sigma

input_dim instance-attribute ¤

input_dim = input_shape

in_channels instance-attribute ¤

in_channels = input_shape[0]

img_size instance-attribute ¤

img_size = input_shape[1]

noise_schedule instance-attribute ¤

noise_schedule: NoiseSchedule = create_noise_schedule(noise_schedule)

betas instance-attribute ¤

betas = betas

alphas instance-attribute ¤

alphas = alphas

alphas_cumprod instance-attribute ¤

alphas_cumprod = alphas_cumprod

backbone instance-attribute ¤

backbone = backbone_fn(config, rngs)

sample_step ¤

sample_step(x_t: Array, t: Array, *, rngs: Rngs | None = None, y: Array | None = None, cfg_scale: float | None = None) -> Array

Single sampling step with optional classifier-free guidance.

Parameters:

Name Type Description Default
x_t Array Current noisy sample [batch, height, width, channels] required
t Array Current timestep [batch] required
rngs Rngs | None Random number generators None
y Array | None Optional class labels for conditional generation None
cfg_scale float | None Classifier-free guidance scale None

Returns:

Type Description
Array Denoised sample [batch, height, width, channels]

generate ¤

generate(n_samples: int = 1, *, rngs: Rngs, num_steps: int = 1000, y: Array | None = None, cfg_scale: float | None = None, img_size: int | None = None) -> Array

Generate samples using DiT.

Parameters:

Name Type Description Default
n_samples int Number of samples to generate 1
rngs Rngs Random number generators required
num_steps int Number of diffusion steps 1000
y Array | None Optional class labels for conditional generation None
cfg_scale float | None Classifier-free guidance scale None
img_size int | None Image size (uses config default if not specified) None

Returns:

Type Description
Array Generated samples [n_samples, height, width, channels]

Diffusion model using Vision Transformer backbone.

Purpose: Replaces U-Net with Transformer for better scalability and state-of-the-art quality at large model sizes.

Initialization¤

DiTModel(
    config: ModelConfiguration,
    *,
    rngs: nnx.Rngs,
    backbone_fn: Optional[Callable] = None,
    **kwargs
)

Configuration Parameters:

Parameter Type Default Description
img_size int 32 Image size
patch_size int 2 Patch size for Vision Transformer
hidden_size int 512 Transformer hidden dimension
depth int 12 Number of transformer layers
num_heads int 8 Number of attention heads
mlp_ratio float 4.0 MLP expansion ratio
num_classes int | None None Number of classes for conditioning
dropout_rate float 0.0 Dropout rate
learn_sigma bool False Learn variance
cfg_scale float 1.0 Classifier-free guidance scale

Methods¤

__call__(x, t, y=None, *, deterministic=False, cfg_scale=None)¤

Forward pass through DiT model.

Parameters:

  • x (jax.Array): Input images (batch, H, W, C)
  • t (jax.Array): Timesteps (batch,)
  • y (jax.Array | None): Optional class labels
  • deterministic (bool): Whether to apply dropout
  • cfg_scale (float | None): Classifier-free guidance scale

Returns:

  • jax.Array: Predicted noise

Example:

model = DiTModel(config, rngs=rngs)

# Forward pass
x = jax.random.normal(rngs.sample(), (4, 32, 32, 3))
t = jnp.array([100, 200, 300, 400])
y = jnp.array([0, 1, 2, 3])  # Class labels

noise_pred = model(x, t, y=y, deterministic=False)
generate(n_samples=1, *, rngs, num_steps=1000, y=None, cfg_scale=None, img_size=None, **kwargs)¤

Generate samples using DiT.

Parameters:

  • n_samples (int): Number of samples
  • rngs (nnx.Rngs): Random number generators
  • num_steps (int): Number of diffusion steps
  • y (jax.Array | None): Class labels for conditional generation
  • cfg_scale (float | None): Classifier-free guidance scale
  • img_size (int | None): Image size
  • **kwargs: Additional arguments

Returns:

  • jax.Array: Generated samples

Example:

# Unconditional generation
samples = model.generate(n_samples=16, rngs=rngs, num_steps=1000)

# Conditional generation with classifier-free guidance
class_labels = jnp.array([i % 10 for i in range(16)])
samples = model.generate(
    n_samples=16,
    y=class_labels,
    cfg_scale=4.0,  # Strong conditioning
    rngs=rngs,
    num_steps=1000
)

Guidance Techniques¤

ClassifierFreeGuidance¤

workshop.generative_models.models.diffusion.guidance.ClassifierFreeGuidance ¤

ClassifierFreeGuidance(guidance_scale: float = 7.5, unconditional_conditioning: Any | None = None)

Classifier-free guidance for conditional diffusion models.

This implements the classifier-free guidance technique that allows trading off between sample diversity and adherence to conditioning.

Parameters:

Name Type Description Default
guidance_scale float Guidance strength (higher = more conditioning) 7.5
unconditional_conditioning Any | None Unconditional conditioning token/embedding None

guidance_scale instance-attribute ¤

guidance_scale = guidance_scale

unconditional_conditioning instance-attribute ¤

unconditional_conditioning = unconditional_conditioning

Classifier-free guidance for conditional generation.

Purpose: Enables strong conditioning without needing a separate classifier by training a single model to handle both conditional and unconditional generation.

Initialization¤

ClassifierFreeGuidance(
    guidance_scale: float = 7.5,
    unconditional_conditioning: Any | None = None
)

Parameters:

Parameter Type Default Description
guidance_scale float 7.5 Guidance strength (higher = stronger conditioning)
unconditional_conditioning Any | None None Unconditional token/embedding

Methods¤

__call__(model, x, t, conditioning, *, rngs=None, **kwargs)¤

Apply classifier-free guidance.

Parameters:

  • model (DiffusionModel): Diffusion model
  • x (jax.Array): Noisy input
  • t (jax.Array): Timesteps
  • conditioning (Any): Conditioning information
  • rngs (nnx.Rngs | None): Random number generators
  • **kwargs: Additional arguments

Returns:

  • jax.Array: Guided noise prediction

Mathematical Formula:

\[ \tilde{\epsilon} = \epsilon_\theta(x_t, \emptyset) + w \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \emptyset)) \]

Example:

from workshop.generative_models.models.diffusion.guidance import ClassifierFreeGuidance

# Create guidance
cfg = ClassifierFreeGuidance(guidance_scale=7.5)

# Use during sampling
x_t = noisy_sample
t = timesteps
conditioning = class_labels

guided_noise = cfg(model, x_t, t, conditioning, rngs=rngs)

ClassifierGuidance¤

workshop.generative_models.models.diffusion.guidance.ClassifierGuidance ¤

ClassifierGuidance(classifier: Module, guidance_scale: float = 1.0, class_label: int | None = None)

Classifier guidance for diffusion models.

Uses a pre-trained classifier to guide the generation process towards desired classes.

Parameters:

Name Type Description Default
classifier Module Pre-trained classifier model required
guidance_scale float Guidance strength 1.0
class_label int | None Target class label for guidance None

classifier instance-attribute ¤

classifier = classifier

guidance_scale instance-attribute ¤

guidance_scale = guidance_scale

class_label instance-attribute ¤

class_label = class_label

Classifier guidance using a pre-trained classifier.

Purpose: Uses gradients from a pre-trained classifier to guide generation towards desired classes.

Initialization¤

ClassifierGuidance(
    classifier: nnx.Module,
    guidance_scale: float = 1.0,
    class_label: int | None = None
)

Parameters:

Parameter Type Description
classifier nnx.Module Pre-trained classifier model
guidance_scale float Guidance strength
class_label int | None Target class for guidance

Methods¤

__call__(model, x, t, *, rngs=None, class_label=None, **kwargs)¤

Apply classifier guidance.

Mathematical Formula:

\[ \tilde{\epsilon} = \epsilon_\theta(x_t) - w \sqrt{1 - \bar{\alpha}_t} \nabla_{x_t} \log p_\phi(y | x_t) \]

Example:

from workshop.generative_models.models.diffusion.guidance import ClassifierGuidance

# Load pre-trained classifier
classifier = load_classifier()

# Create classifier guidance
cg = ClassifierGuidance(
    classifier=classifier,
    guidance_scale=1.0,
    class_label=5  # Generate class 5
)

# Use during sampling
guided_noise = cg(model, x_t, t, rngs=rngs)
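
The gradient term \(\nabla_{x_t} \log p_\phi(y \mid x_t)\) in the formula can be sketched as follows (an assumption: the classifier is called as classifier(x, t) and returns per-class logits):

import jax
import jax.numpy as jnp

def classifier_log_prob_grad(classifier, x_t, t, y):
    def log_prob(x):
        logits = classifier(x, t)
        log_probs = jax.nn.log_softmax(logits)
        # Sum of log-probabilities of the target classes over the batch
        return log_probs[jnp.arange(x.shape[0]), y].sum()

    # grad_x log p(y | x_t), plugged into the guided-noise formula above
    return jax.grad(log_prob)(x_t)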

GuidedDiffusionModel¤

workshop.generative_models.models.diffusion.guidance.GuidedDiffusionModel ¤

GuidedDiffusionModel(config, *, rngs: Rngs, guidance_method: str | None = None, guidance_scale: float = 7.5, classifier: Module | None = None)

Bases: DiffusionModel

Diffusion model with built-in guidance support.

This extends the base diffusion model to support various guidance techniques during generation.

Uses the polymorphic backbone system - backbone type is determined by config.backbone.backbone_type discriminator.

Parameters:

Name Type Description Default
config - Model configuration with nested BackboneConfig. The backbone is created based on backbone_type. required
rngs Rngs Random number generators required
guidance_method str | None Type of guidance ("classifier_free", "classifier", None) None
guidance_scale float Guidance strength 7.5
classifier Module | None Classifier for classifier guidance None

guidance_method instance-attribute ¤

guidance_method = guidance_method

guidance_scale instance-attribute ¤

guidance_scale = guidance_scale

guidance instance-attribute ¤

guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale, unconditional_conditioning=getattr(config, 'unconditional_token', None))

guided_sample_step ¤

guided_sample_step(x: Array, t: Array, conditioning: Any | None = None, **kwargs) -> Array

Single sampling step with guidance.

Parameters:

Name Type Description Default
x Array Current sample required
t Array Timesteps required
conditioning Any | None Conditioning information None
**kwargs - Additional arguments {}

Returns:

Type Description
Array Guided noise prediction

Note

NNX models store RNGs at init time, no need to pass rngs.

generate ¤

generate(n_samples: int = 1, *, rngs: Rngs | None = None, conditioning: Any | None = None, shape: tuple[int, ...] | None = None, clip_denoised: bool = True, **kwargs) -> Array

Generate samples with guidance.

Parameters:

Name Type Description Default
n_samples int Number of samples to generate 1
rngs Rngs | None Random number generators None
conditioning Any | None Conditioning information for guided generation None
shape tuple[int, ...] | None Sample shape None
clip_denoised bool Whether to clip denoised samples True
**kwargs - Additional keyword arguments {}

Returns:

Type Description
Array Generated samples

Diffusion model with built-in guidance support.

Purpose: Extends base diffusion model to support various guidance techniques during generation.

Initialization¤

GuidedDiffusionModel(
    config: ModelConfiguration,
    backbone_fn: Callable,
    *,
    rngs: nnx.Rngs,
    guidance_method: str | None = None,
    guidance_scale: float = 7.5,
    classifier: nnx.Module | None = None
)

Parameters:

Parameter Type Description
guidance_method str | None Guidance type ("classifier_free", "classifier", None)
guidance_scale float Guidance strength
classifier nnx.Module | None Classifier for classifier guidance

Example:

from workshop.generative_models.models.diffusion.guidance import GuidedDiffusionModel

# Create model with classifier-free guidance
model = GuidedDiffusionModel(
    config,
    backbone_fn,
    rngs=rngs,
    guidance_method="classifier_free",
    guidance_scale=7.5
)

# Generate with conditioning
samples = model.generate(
    n_samples=16,
    conditioning=class_labels,
    rngs=rngs
)

Guidance Utility Functions¤

apply_guidance(noise_pred_cond, noise_pred_uncond, guidance_scale)¤

Apply classifier-free guidance formula.

Parameters:

  • noise_pred_cond (jax.Array): Conditional noise prediction
  • noise_pred_uncond (jax.Array): Unconditional noise prediction
  • guidance_scale (float): Guidance strength

Returns:

  • jax.Array: Guided noise prediction

Example:

from workshop.generative_models.models.diffusion.guidance import apply_guidance

# Get predictions
noise_cond = model(x_t, t, conditioning=labels, rngs=rngs)
noise_uncond = model(x_t, t, conditioning=None, rngs=rngs)

# Apply guidance
guided = apply_guidance(noise_cond, noise_uncond, guidance_scale=7.5)

linear_guidance_schedule(step, total_steps, start_scale=1.0, end_scale=7.5)¤

Linear guidance scale schedule.

Parameters:

  • step (int): Current step
  • total_steps (int): Total number of steps
  • start_scale (float): Starting guidance scale
  • end_scale (float): Ending guidance scale

Returns:

  • float: Guidance scale for current step

Example:

from workshop.generative_models.models.diffusion.guidance import linear_guidance_schedule

# Gradually increase guidance during sampling
for step in range(total_steps):
    scale = linear_guidance_schedule(step, total_steps, start_scale=1.0, end_scale=7.5)
    # Use scale for this step

cosine_guidance_schedule(step, total_steps, start_scale=1.0, end_scale=7.5)¤

Cosine guidance scale schedule.

Example:

from workshop.generative_models.models.diffusion.guidance import cosine_guidance_schedule

# Use cosine schedule (higher guidance at beginning and end)
for step in range(total_steps):
    scale = cosine_guidance_schedule(step, total_steps)
    # Use scale for this step

Auxiliary Classes¤

SimpleEncoder¤

workshop.generative_models.models.diffusion.latent.SimpleEncoder ¤

SimpleEncoder(input_dim: tuple[int, ...], latent_dim: int, hidden_dims: list | None = None, *, rngs: Rngs)

Bases: Module

Simple encoder for latent diffusion model.

Parameters:

Name Type Description Default
input_dim tuple[int, ...] Input dimensions (H, W, C) for images or (D,) for vectors required
latent_dim int Latent dimension required
hidden_dims list | None Hidden layer dimensions None
rngs Rngs Random number generators required

input_dim instance-attribute ¤

input_dim = input_dim

latent_dim instance-attribute ¤

latent_dim = latent_dim

hidden_dims instance-attribute ¤

hidden_dims = hidden_dims or [32, 64]

is_image instance-attribute ¤

is_image = isinstance(input_dim, (tuple, list)) and len(input_dim) >= 2

flat_dim instance-attribute ¤

flat_dim = input_dim[0] * input_dim[1] * input_dim[2]

layers instance-attribute ¤

layers = List([])

mean_layer instance-attribute ¤

mean_layer = Linear(in_features, latent_dim, rngs=rngs)

logvar_layer instance-attribute ¤

logvar_layer = Linear(in_features, latent_dim, rngs=rngs)

Simple MLP encoder for Latent Diffusion Models.

Purpose: Encodes images to latent space with mean and log variance.

Initialization¤

SimpleEncoder(
    input_dim: tuple[int, ...],
    latent_dim: int,
    hidden_dims: list | None = None,
    *,
    rngs: nnx.Rngs
)
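
Example (a sketch; the __call__ signature is an assumption based on the mean_layer/logvar_layer attributes above):

encoder = SimpleEncoder((32, 32, 3), latent_dim=16, rngs=rngs)

images = jax.random.normal(rngs.sample(), (8, 32, 32, 3))
mean, logvar = encoder(images)   # assumed to return the two heads
print(mean.shape, logvar.shape)  # (8, 16) (8, 16)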

SimpleDecoder¤

workshop.generative_models.models.diffusion.latent.SimpleDecoder ¤

SimpleDecoder(latent_dim: int, output_dim: tuple[int, ...], hidden_dims: list | None = None, *, rngs: Rngs | None = None)

Bases: Module

Simple decoder for latent diffusion model.

Parameters:

Name Type Description Default
latent_dim int Latent dimension required
output_dim tuple[int, ...] Output dimensions (H, W, C) for images or (D,) for vectors required
hidden_dims list | None Hidden layer dimensions (in reverse order) None
rngs Rngs | None Random number generators None

latent_dim instance-attribute ¤

latent_dim = latent_dim

output_dim instance-attribute ¤

output_dim = output_dim

hidden_dims instance-attribute ¤

hidden_dims = hidden_dims or [64, 32]

is_image instance-attribute ¤

is_image = isinstance(output_dim, (tuple, list)) and len(output_dim) >= 2

flat_dim instance-attribute ¤

flat_dim = output_dim[0] * output_dim[1] * output_dim[2]

layers instance-attribute ¤

layers = List([])

output_layer instance-attribute ¤

output_layer = Linear(in_features, flat_dim, rngs=rngs)

Simple MLP decoder for Latent Diffusion Models.

Purpose: Decodes latent codes back to image space.

Initialization¤

SimpleDecoder(
    latent_dim: int,
    output_dim: tuple[int, ...],
    hidden_dims: list | None = None,
    *,
    rngs: nnx.Rngs
)

Configuration Reference¤

ModelConfiguration for Diffusion Models¤

Complete reference of configuration parameters for all diffusion models.

Base Parameters¤

Parameter Type Required Description
name str Yes Model name
model_class str Yes Model class name
input_dim tuple[int, ...] Yes Input dimensions (H, W, C)
hidden_dims list[int] No Hidden layer dimensions
output_dim int | tuple No Output dimensions
activation str No Activation function
parameters dict No Model-specific parameters

DDPM Parameters¤

{
    "noise_steps": 1000,      # Number of timesteps
    "beta_start": 1e-4,       # Initial noise level
    "beta_end": 0.02,         # Final noise level
    "beta_schedule": "linear" # Noise schedule
}

DDIM Parameters¤

{
    "noise_steps": 1000,      # Training steps
    "ddim_steps": 50,         # Sampling steps
    "eta": 0.0,               # Stochasticity
    "skip_type": "uniform",   # Timestep selection
    "beta_start": 1e-4,
    "beta_end": 0.02
}

Score-Based Parameters¤

{
    "sigma_min": 0.01,        # Minimum noise level
    "sigma_max": 1.0,         # Maximum noise level
    "score_scaling": 1.0,     # Score scaling factor
    "noise_steps": 1000
}

Latent Diffusion Parameters¤

{
    "latent_dim": 16,                    # Latent space dimension
    "encoder_hidden_dims": [64, 128],    # Encoder architecture
    "decoder_hidden_dims": [128, 64],    # Decoder architecture
    "encoder_type": "simple",            # Encoder type
    "scale_factor": 0.18215,             # Latent scaling
    "noise_steps": 1000
}

DiT Parameters¤

{
    "img_size": 32,           # Image size
    "patch_size": 4,          # Patch size
    "hidden_size": 512,       # Transformer dimension
    "depth": 12,              # Number of layers
    "num_heads": 8,           # Attention heads
    "mlp_ratio": 4.0,         # MLP expansion
    "num_classes": 10,        # Number of classes
    "dropout_rate": 0.1,      # Dropout rate
    "learn_sigma": False,     # Learn variance
    "cfg_scale": 2.0,         # Guidance scale
    "noise_steps": 1000
}

Quick Reference¤

Model Selection Guide¤

Model Best For Sampling Speed Memory Quality
DDPMModel Standard use, learning Slow (1000 steps) High ⭐⭐⭐⭐⭐
DDIMModel Fast inference Fast (50 steps) High ⭐⭐⭐⭐
ScoreDiffusionModel Research, flexibility Medium High ⭐⭐⭐⭐
LDMModel High-res, efficiency Fast Medium ⭐⭐⭐⭐
DiTModel Scalability, SOTA Medium Very High ⭐⭐⭐⭐⭐

Common Usage Patterns¤

# Basic DDPM
model = DDPMModel(config, rngs=rngs)
samples = model.generate(16, rngs=rngs)

# Fast DDIM
model = DDIMModel(config, rngs=rngs)
samples = model.ddim_sample(16, steps=50, rngs=rngs)

# Latent Diffusion
model = LDMModel(config=config, rngs=rngs)
samples = model.sample(16, rngs=rngs)

# DiT with conditioning
model = DiTModel(config, rngs=rngs)
samples = model.generate(16, y=labels, cfg_scale=4.0, rngs=rngs)

See Also¤