PPO Trainer¤
Module: artifex.generative_models.training.rl.ppo
The PPO (Proximal Policy Optimization) Trainer provides stable policy gradient training through clipped surrogate objectives and Generalized Advantage Estimation (GAE).
Overview¤
PPO is a state-of-the-art policy gradient method that maintains training stability through:
- Clipped Surrogate Loss: Prevents large policy updates
- Generalized Advantage Estimation: Balances bias-variance in advantage computation
- Value Function Learning: Learns state values for advantage estimation
- Entropy Bonus: Encourages exploration
Quick Start¤
```python
from artifex.generative_models.training import PPOConfig, PPOTrainer
from flax import nnx
import optax

# Create actor-critic model (must return (action_logits, value))
model = ActorCriticModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(3e-4), wrt=nnx.Param)

# Configure PPO
config = PPOConfig(
    gamma=0.99,
    gae_lambda=0.95,
    clip_param=0.2,
    vf_coeff=0.5,
    entropy_coeff=0.01,
    max_grad_norm=0.5,
)
trainer = PPOTrainer(model, optimizer, config)

# Compute advantages from a collected rollout
# (rewards and dones have shape (T,); values has shape (T+1,) including the bootstrap value)
advantages = trainer.compute_gae(rewards, values, dones)
returns = advantages + values[:-1]  # standard GAE return targets

# Training step with the prepared batch
batch = {
    "states": observations,
    "actions": actions,
    "old_log_probs": old_log_probs,
    "returns": returns,
    "advantages": advantages,
}
loss, metrics = trainer.train_step(batch)
```
Configuration¤
artifex.generative_models.training.rl.configs.PPOConfig
dataclass
¤
```python
PPOConfig(
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    clip_param: float = 0.2,
    vf_coeff: float = 0.5,
    entropy_coeff: float = 0.01,
    max_grad_norm: float = 0.5,
)
```
Configuration for Proximal Policy Optimization.
Implements PPO with clipped surrogate objective and GAE.
Attributes:
| Name | Type | Description |
|---|---|---|
| `gamma` | `float` | Discount factor for computing returns. Default 0.99. |
| `gae_lambda` | `float` | Lambda for Generalized Advantage Estimation. Default 0.95. |
| `clip_param` | `float` | Clipping parameter for surrogate objective. Default 0.2. |
| `vf_coeff` | `float` | Coefficient for value function loss. Default 0.5. |
| `entropy_coeff` | `float` | Coefficient for entropy bonus. Default 0.01. |
| `max_grad_norm` | `float` | Maximum gradient norm for clipping. Default 0.5. |
Configuration Options¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| `gamma` | `float` | 0.99 | Discount factor for GAE |
| `gae_lambda` | `float` | 0.95 | Lambda for GAE (bias-variance trade-off) |
| `clip_param` | `float` | 0.2 | Clipping parameter epsilon |
| `vf_coeff` | `float` | 0.5 | Value function loss coefficient |
| `entropy_coeff` | `float` | 0.01 | Entropy bonus coefficient |
| `max_grad_norm` | `float` | 0.5 | Maximum gradient norm for clipping |
Algorithm¤
Clipped Surrogate Objective¤
PPO uses a clipped surrogate objective to prevent large policy updates:

\[
L^{CLIP}(\theta) = -\,\mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]
\]

Where \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\) is the probability ratio, \(\hat{A}_t\) is the advantage estimate, and \(\epsilon\) is `clip_param`.
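The same computation can be written directly in JAX. The following is a minimal illustrative sketch of the clipped loss (it mirrors what `compute_clipped_loss` does conceptually, but is not the library's internal code, and the helper name is hypothetical):

```python
import jax.numpy as jnp


def clipped_surrogate_loss(log_probs, old_log_probs, advantages, clip_param=0.2):
    """Clipped surrogate policy loss (negated so it can be minimized)."""
    ratio = jnp.exp(log_probs - old_log_probs)  # r_t(theta)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    return -jnp.mean(jnp.minimum(unclipped, clipped))
```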
Generalized Advantage Estimation¤
GAE computes advantages as an exponentially weighted sum of TD residuals:

\[
\hat{A}_t = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l}
\]

Where \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) is the TD residual.
Lambda parameter:
- \(\lambda = 0\): TD(0) - low variance, high bias
- \(\lambda = 1\): Monte Carlo - high variance, low bias
- \(\lambda = 0.95\): Good balance (default)
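As a reference for the recursion above, here is a minimal standalone GAE computation. It is an illustrative sketch (not the trainer's internal implementation) and assumes `values` carries one extra bootstrap entry, matching the `compute_gae` signature documented below:

```python
import jax.numpy as jnp


def gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Illustrative GAE: rewards/dones have shape (T,), values has shape (T+1,)."""
    T = rewards.shape[0]
    advantages = jnp.zeros(T)
    last_adv = 0.0
    for t in reversed(range(T)):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]
        last_adv = delta + gamma * lam * nonterminal * last_adv
        advantages = advantages.at[t].set(last_adv)
    return advantages
```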
Value Function Loss¤
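The value head is fit with a mean squared error between predicted state values and the return targets (reported as the `value_loss` metric below):

\[
L^{VF}(\theta) = \mathbb{E}_t\left[\left(V_\theta(s_t) - \hat{R}_t\right)^2\right]
\]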
Full Objective¤
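Combining the three terms with the configured coefficients gives the overall loss minimized in `train_step` (standard PPO form; the implementation's exact sign conventions may differ):

\[
L(\theta) = L^{CLIP}(\theta) + c_{vf}\, L^{VF}(\theta) - c_{ent}\, S[\pi_\theta]
\]

where \(c_{vf}\) is `vf_coeff`, \(c_{ent}\) is `entropy_coeff`, and \(S[\pi_\theta]\) is the policy entropy.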
API Reference¤
artifex.generative_models.training.rl.ppo.PPOTrainer
¤
Proximal Policy Optimization trainer.
Implements PPO with:

- Clipped surrogate objective for stable policy updates
- GAE for advantage estimation
- Value function fitting
- Entropy bonus for exploration
- Gradient clipping
The model must be an Actor-Critic that returns (action_logits, value).
Attributes:
| Name | Type | Description |
|---|---|---|
| `model` | `Module` | Actor-Critic network. |
| `optimizer` | `Optimizer` | Flax NNX optimizer. |
| `config` | `PPOConfig` | PPO configuration. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Actor-Critic network that returns (logits, value). | *required* |
| `optimizer` | `Optimizer` | Flax NNX optimizer for the model. | *required* |
| `config` | `PPOConfig \| None` | PPO configuration. Uses defaults if not provided. | `None` |
compute_gae
¤
Compute Generalized Advantage Estimation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rewards` | `Array` | Rewards with shape (T,). | *required* |
| `values` | `Array` | Values with shape (T+1,), including next state value. | *required* |
| `dones` | `Array` | Done flags with shape (T,). | *required* |

Returns:

| Type | Description |
|---|---|
| `Array` | Advantages with shape (T,). |
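For example, given a collected rollout (the variable names are illustrative):

```python
advantages = trainer.compute_gae(rewards, values, dones)  # shape (T,)
returns = advantages + values[:-1]  # common choice of return targets for train_step
```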
compute_clipped_loss
¤
Compute clipped surrogate policy loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `log_probs` | `Array` | Current policy log probabilities. | *required* |
| `old_log_probs` | `Array` | Old policy log probabilities. | *required* |
| `advantages` | `Array` | Advantage estimates. | *required* |

Returns:

| Type | Description |
|---|---|
| `Array` | Clipped surrogate loss. |
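Typical usage, assuming the log probabilities correspond to the actions actually taken:

```python
policy_loss = trainer.compute_clipped_loss(log_probs, old_log_probs, advantages)
```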
compute_value_loss
¤
Compute the value function loss (mean squared error between predicted values and target returns).
compute_entropy
¤
Compute the policy entropy used for the exploration bonus.
train_step
¤
Perform a single PPO training step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `dict[str, Array]` | Dictionary containing `"states"` (batch of states), `"actions"` (actions taken), `"old_log_probs"` (log probs from old policy), `"returns"` (target returns), and `"advantages"` (advantage estimates). | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of (loss, metrics_dict). |
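For example, with a batch assembled as in the Quick Start:

```python
loss, metrics = trainer.train_step(batch)
print(metrics["policy_loss"], metrics["value_loss"], metrics["entropy"])
```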
Training Metrics¤
| Metric | Description |
|---|---|
| `policy_loss` | Clipped surrogate policy loss |
| `value_loss` | Value function MSE loss |
| `entropy` | Policy entropy (exploration measure) |
Model Requirements¤
PPO requires an actor-critic model that outputs both action logits and value estimates:

```python
class ActorCriticModel(nnx.Module):
    def __call__(self, observations) -> tuple[jax.Array, jax.Array]:
        """Forward pass returning (action_logits, values).

        Args:
            observations: State observations.

        Returns:
            Tuple of:
            - action_logits: Action logits, shape (batch, num_actions)
            - values: State value estimates, shape (batch,)
        """
        ...
```
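A minimal concrete implementation might look like the following sketch. The layer sizes, `obs_dim`, and `num_actions` are illustrative assumptions, not part of the library:

```python
import jax
from flax import nnx


class ActorCriticModel(nnx.Module):
    """Illustrative MLP actor-critic with a shared trunk."""

    def __init__(self, obs_dim: int = 8, num_actions: int = 4, hidden: int = 64, *, rngs: nnx.Rngs):
        self.trunk = nnx.Linear(obs_dim, hidden, rngs=rngs)
        self.policy_head = nnx.Linear(hidden, num_actions, rngs=rngs)
        self.value_head = nnx.Linear(hidden, 1, rngs=rngs)

    def __call__(self, observations: jax.Array) -> tuple[jax.Array, jax.Array]:
        h = nnx.relu(self.trunk(observations))
        logits = self.policy_head(h)              # (batch, num_actions)
        values = self.value_head(h).squeeze(-1)   # (batch,)
        return logits, values
```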
Hyperparameter Guidelines¤
Clip Parameter (epsilon)¤
- 0.1-0.2: Standard range, 0.2 is most common
- Lower values → more conservative updates
- Higher values → larger policy changes allowed
GAE Lambda¤
- 0.95: Good default for most tasks
- 0.99: Lower bias, higher variance (longer-horizon tasks)
- 0.9: Higher bias, lower variance (shorter-horizon tasks)
Value Function Coefficient¤
- 0.5: Standard choice
- Higher values → more emphasis on accurate value estimation
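Putting these guidelines together, a configuration leaning toward conservative updates on a longer-horizon task might look like this sketch (the specific values are illustrative, not library recommendations):

```python
config = PPOConfig(
    gamma=0.99,
    gae_lambda=0.99,   # lower bias for longer horizons
    clip_param=0.1,    # more conservative policy updates
    vf_coeff=0.5,
    entropy_coeff=0.01,
    max_grad_norm=0.5,
)
```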
Use Cases¤
PPO is recommended for:
- Complex tasks: When REINFORCE is too unstable
- Continuous control: Robotics, physics simulations
- Games: Atari, board games, video games
- Large models: When you can afford the value network memory
For memory-constrained settings, consider GRPO.
Related Documentation¤
- RL Training Guide - Comprehensive RL training guide
- REINFORCE Trainer - Simpler baseline algorithm
- GRPO Trainer - Memory-efficient alternative
- DPO Trainer - Preference-based learning