DPO Trainer¤
Module: artifex.generative_models.training.rl.dpo
The DPO (Direct Preference Optimization) Trainer enables learning from preference pairs without requiring an explicit reward model or RL optimization loop.
Overview¤
DPO directly optimizes the policy to prefer chosen responses over rejected ones:
- No Reward Model: Learns directly from preferences
- Stable Training: Uses supervised-learning-style updates
- SimPO Support: Reference-free variant for simpler setup
- Label Smoothing: Robustness to noisy preferences
Quick Start¤
```python
from artifex.generative_models.training import DPOConfig, DPOTrainer
from flax import nnx
import optax

# Create the policy model and a frozen reference copy
# (PolicyModel stands in for your own nnx.Module policy network)
model = PolicyModel(rngs=nnx.Rngs(0))
reference_model = PolicyModel(rngs=nnx.Rngs(0))  # frozen reference policy
optimizer = nnx.Optimizer(model, optax.adam(1e-5), wrt=nnx.Param)

# Configure DPO
config = DPOConfig(
    beta=0.1,
    label_smoothing=0.0,
    reference_free=False,
)
trainer = DPOTrainer(model, reference_model, optimizer, config)

# Train on preference pairs (per-example log probabilities, shape (batch,))
batch = {
    "chosen_log_probs": chosen_log_probs,
    "rejected_log_probs": rejected_log_probs,
    "ref_chosen_log_probs": ref_chosen,
    "ref_rejected_log_probs": ref_rejected,
}
metrics = trainer.train_step(batch)
```
Configuration¤
artifex.generative_models.training.rl.configs.DPOConfig (dataclass)¤
Configuration for Direct Preference Optimization.
DPO enables preference learning without an explicit reward model. SimPO mode (reference_free=True) eliminates the need for a reference model.
Attributes:

| Name | Type | Description |
|---|---|---|
| `beta` | `float` | Reward scaling parameter. Higher values = stronger preference. Default `0.1`. |
| `label_smoothing` | `float` | Label smoothing for the preference loss. Default `0.0`. |
| `reference_free` | `bool` | Whether to use SimPO-style reference-free training. When `True`, no reference model is needed. Default `False`. |
Configuration Options¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| `beta` | `float` | `0.1` | Temperature parameter for reward scaling |
| `label_smoothing` | `float` | `0.0` | Label smoothing for robustness |
| `reference_free` | `bool` | `False` | Use SimPO (reference-free) mode |
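For example, the options can be combined as in this illustrative sketch (the values are placeholders, not recommendations):

```python
# Illustrative only: sharper preferences with mild tolerance for noisy labels
config = DPOConfig(
    beta=0.3,
    label_smoothing=0.05,
    reference_free=False,
)
```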
Algorithm¤
Standard DPO¤
DPO fits the policy to the Bradley-Terry preference model by minimizing:

\[
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{ref}) = -\,\mathbb{E}_{(x, y_w, y_l)}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{ref}(y_l \mid x)}\right)\right]
\]

Where:

- \(y_w\) is the preferred (chosen) response
- \(y_l\) is the rejected response
- \(\pi_{ref}\) is the reference policy (frozen)
- \(\beta\) controls the implicit reward scaling
- \(\sigma\) is the logistic sigmoid
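For concreteness, here is a minimal JAX sketch of this objective written against the precomputed log-probability batch format described under "Data Format" below. It illustrates the math rather than the trainer's internal code, and `dpo_loss_sketch` is a hypothetical name:

```python
import jax
import jax.numpy as jnp

def dpo_loss_sketch(batch: dict[str, jnp.ndarray], beta: float = 0.1):
    """Standard DPO loss from precomputed per-example log probabilities."""
    # Log-ratios between the policy and the frozen reference, per example
    chosen_ratio = batch["chosen_log_probs"] - batch["ref_chosen_log_probs"]
    rejected_ratio = batch["rejected_log_probs"] - batch["ref_rejected_log_probs"]
    margin = chosen_ratio - rejected_ratio

    # -log(sigmoid(beta * margin)), averaged over the batch
    loss = -jax.nn.log_sigmoid(beta * margin).mean()

    # Fraction of pairs whose implicit chosen reward beats the rejected one
    reward_accuracy = (margin > 0).astype(jnp.float32).mean()
    return loss, reward_accuracy
```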
SimPO (Reference-Free)¤
SimPO eliminates the reference model by using length-normalized log probabilities:
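(Shown here in the standard published SimPO form; the paper's objective also includes a target reward margin term that `DPOConfig` does not expose, so the implemented variant may differ slightly.)

\[
\mathcal{L}_{\text{SimPO}} = -\log \sigma\!\left(\frac{\beta}{|y_w|} \log \pi_\theta(y_w \mid x) \;-\; \frac{\beta}{|y_l|} \log \pi_\theta(y_l \mid x)\right)
\]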
Enable with `reference_free=True`:
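(A minimal sketch; `model` and `optimizer` are assumed to be set up as in the Quick Start.)

```python
config = DPOConfig(beta=0.1, reference_free=True)

# No reference model is needed in SimPO mode
trainer = DPOTrainer(model, reference_model=None, optimizer=optimizer, config=config)
```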
Label Smoothing¤
For robustness to noisy preference labels:
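(Assuming the usual conservative-DPO formulation of label smoothing; the implementation's exact form may differ.) With smoothing \(\epsilon\) = `label_smoothing`, the loss mixes both possible label assignments:

\[
\mathcal{L}_{\epsilon} = -(1 - \epsilon)\,\log \sigma(\beta \Delta) \;-\; \epsilon\,\log \sigma(-\beta \Delta)
\]

where \(\Delta\) is the difference between the chosen and rejected log-ratios. Setting, for example, `label_smoothing=0.1` corresponds to assuming roughly 10% of preference labels may be flipped.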
API Reference¤
artifex.generative_models.training.rl.dpo.DPOTrainer¤

```python
DPOTrainer(
    model: Module,
    reference_model: Module | None,
    optimizer: Optimizer,
    config: DPOConfig | None = None,
)
```
Direct Preference Optimization trainer.

Implements DPO for preference learning:

- Learns from preference pairs (chosen, rejected)
- Uses log-ratio between policy and reference model
- SimPO mode eliminates need for reference model
Attributes:

| Name | Type | Description |
|---|---|---|
| `model` | `Module` | Policy model to train. |
| `reference_model` | `Module \| None` | Frozen reference model (`None` in SimPO mode). |
| `optimizer` | `Optimizer` | Flax NNX optimizer. |
| `config` | `DPOConfig` | DPO configuration. |
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Policy model to train. | required |
| `reference_model` | `Module \| None` | Frozen reference model. Can be `None` if `config.reference_free=True` (SimPO mode). | required |
| `optimizer` | `Optimizer` | Flax NNX optimizer for the model. | required |
| `config` | `DPOConfig \| None` | DPO configuration. Uses defaults if not provided. | `None` |
compute_log_probs¤
Compute log probabilities for sequences.
For simplicity, this computes the mean log probability across the sequence. In practice, you'd want to compute per-token log probs and sum/average appropriately.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model to compute log probs with. | required |
| `sequences` | `Array` | Input sequences with shape `(batch_size, seq_len)`. | required |

Returns:

| Type | Description |
|---|---|
| `Array` | Log probabilities with shape `(batch_size,)`. |
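As the note above explains, `compute_log_probs` uses a simple mean over the sequence. Below is a hedged sketch of the more careful per-token approach, assuming the model produces next-token logits of shape `(batch, seq_len, vocab)`; the helper name and mask convention are illustrative, not part of the library API:

```python
import jax
import jax.numpy as jnp

def per_token_log_probs(logits, tokens, mask):
    """Sum of per-token log probs over response (non-padding) positions.

    logits: (batch, seq_len, vocab) next-token logits from the model.
    tokens: (batch, seq_len) target token ids.
    mask:   (batch, seq_len) 1.0 for response tokens, 0.0 for prompt/padding.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log prob assigned to each observed token
    token_log_probs = jnp.take_along_axis(log_probs, tokens[..., None], axis=-1)[..., 0]
    # A real implementation also needs to shift logits so position t scores token t+1
    return (token_log_probs * mask).sum(axis=-1)  # shape: (batch,)
```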
compute_log_ratios¤
compute_loss¤
Compute DPO loss.
DPO loss: `-log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))`
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `dict[str, Array]` | Dictionary containing `"chosen"` (chosen sequences) and `"rejected"` (rejected sequences). | required |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of `(loss, metrics_dict)`. |
train_step¤

Run a single DPO/SimPO update on a preference batch and return the training metrics listed below.
Training Metrics¤
| Metric | Description |
|---|---|
| `dpo_loss` | DPO/SimPO loss value |
| `reward_accuracy` | Fraction of pairs where chosen reward > rejected reward |
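A minimal training-loop sketch that consumes these metrics; `preference_dataloader` and the raw batch keys are assumptions, and `prepare_dpo_batch` is the helper sketched under "Preparing Preference Data" below:

```python
for step, raw in enumerate(preference_dataloader):  # hypothetical data iterator
    batch = prepare_dpo_batch(
        model,
        reference_model,
        raw["prompts"],
        raw["chosen_responses"],
        raw["rejected_responses"],
    )
    metrics = trainer.train_step(batch)
    if step % 100 == 0:
        print(
            f"step={step} "
            f"dpo_loss={float(metrics['dpo_loss']):.4f} "
            f"reward_accuracy={float(metrics['reward_accuracy']):.3f}"
        )
```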
Data Format¤
Standard DPO¤
Requires log probabilities from both policy and reference model:
```python
batch = {
    # Policy log probs
    "chosen_log_probs": policy_chosen,       # shape: (batch,)
    "rejected_log_probs": policy_rejected,   # shape: (batch,)
    # Reference model log probs (frozen)
    "ref_chosen_log_probs": ref_chosen,      # shape: (batch,)
    "ref_rejected_log_probs": ref_rejected,  # shape: (batch,)
}
```
SimPO (Reference-Free)¤
Only requires policy log probabilities:
```python
batch = {
    "chosen_log_probs": chosen_log_probs,
    "rejected_log_probs": rejected_log_probs,
    # No reference model log probs needed
}
```
Beta Parameter¤
The beta parameter controls the sharpness of the preference:
- Lower beta (0.01-0.05): Softer preferences, more exploration
- Standard beta (0.1): Default, good balance
- Higher beta (0.5-1.0): Sharper preferences, stronger alignment
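As a quick numerical illustration of this sharpness (a standalone sketch, not library code): the implied preference probability is \(\sigma(\beta \cdot \text{margin})\), so the same log-ratio margin is penalized very differently across beta values.

```python
import jax
import jax.numpy as jnp

margin = jnp.array(1.0)  # log-ratio margin between chosen and rejected
for beta in (0.01, 0.1, 0.5, 1.0):
    pref_prob = jax.nn.sigmoid(beta * margin)   # implied P(chosen preferred)
    loss = -jax.nn.log_sigmoid(beta * margin)   # per-pair DPO loss
    print(f"beta={beta:<4} pref_prob={float(pref_prob):.3f} loss={float(loss):.3f}")
```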
Preparing Preference Data¤
```python
def prepare_dpo_batch(
    model,
    ref_model,
    prompts,
    chosen_responses,
    rejected_responses,
):
    """Prepare a batch for DPO training.

    Args:
        model: Policy model being trained.
        ref_model: Frozen reference model.
        prompts: Input prompts.
        chosen_responses: Preferred completions.
        rejected_responses: Non-preferred completions.

    Returns:
        Batch dict for the DPO trainer.
    """
    # compute_log_probs is a user-supplied helper that scores a response given
    # its prompt and returns per-example log probabilities of shape (batch,)

    # Compute log probs from the policy
    chosen_log_probs = compute_log_probs(model, prompts, chosen_responses)
    rejected_log_probs = compute_log_probs(model, prompts, rejected_responses)

    # Compute log probs from the reference (no gradients flow to it)
    ref_chosen = compute_log_probs(ref_model, prompts, chosen_responses)
    ref_rejected = compute_log_probs(ref_model, prompts, rejected_responses)

    return {
        "chosen_log_probs": chosen_log_probs,
        "rejected_log_probs": rejected_log_probs,
        "ref_chosen_log_probs": ref_chosen,
        "ref_rejected_log_probs": ref_rejected,
    }
```
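Usage then mirrors the Quick Start (all inputs are assumed to be tokenized and batched already):

```python
batch = prepare_dpo_batch(model, reference_model, prompts, chosen_responses, rejected_responses)
metrics = trainer.train_step(batch)
```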
Use Cases¤
DPO is recommended for:
- Alignment: When you have human preference data
- No reward model: Simpler than RLHF pipeline
- Fine-tuning LLMs: Preference tuning for language models
- Image generation: Preference-based image quality tuning
Comparison with RL Methods¤
| Aspect | DPO | PPO/GRPO |
|---|---|---|
| Requires reward model | No | Yes |
| Training stability | High | Medium |
| Sample efficiency | High | Lower |
| Flexibility | Less | More |
| Online learning | No | Yes |
Related Documentation¤
- RL Training Guide - Comprehensive RL training guide
- PPO Trainer - Policy gradient with value function
- GRPO Trainer - Memory-efficient RL
- REINFORCE Trainer - Basic policy gradient