PPO Trainer¤

Module: artifex.generative_models.training.rl.ppo

The PPO (Proximal Policy Optimization) Trainer provides stable policy gradient training through clipped surrogate objectives and Generalized Advantage Estimation (GAE).

Overview¤

PPO is a state-of-the-art policy gradient method that maintains training stability through:

  • Clipped Surrogate Loss: Prevents large policy updates
  • Generalized Advantage Estimation: Balances bias-variance in advantage computation
  • Value Function Learning: Learns state values for advantage estimation
  • Entropy Bonus: Encourages exploration

Quick Start¤

from artifex.generative_models.training import PPOConfig, PPOTrainer
from flax import nnx
import optax

# Create an actor-critic model (user-defined; see Model Requirements below)
model = ActorCriticModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(3e-4), wrt=nnx.Param)

# Configure PPO
config = PPOConfig(
    gamma=0.99,
    gae_lambda=0.95,
    clip_param=0.2,
    vf_coeff=0.5,
    entropy_coeff=0.01,
    max_grad_norm=0.5,
)

trainer = PPOTrainer(model, optimizer, config)

# Build a training batch from a collected trajectory
advantages = trainer.compute_gae(rewards, values, dones)  # values has shape (T+1,)
returns = advantages + values[:-1]  # value targets for the critic

batch = {
    "states": observations,
    "actions": actions,
    "old_log_probs": old_log_probs,
    "returns": returns,
    "advantages": advantages,
}
loss, metrics = trainer.train_step(batch)

Configuration¤

artifex.generative_models.training.rl.configs.PPOConfig dataclass ¤

PPOConfig(
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    clip_param: float = 0.2,
    vf_coeff: float = 0.5,
    entropy_coeff: float = 0.01,
    max_grad_norm: float = 0.5,
)

Configuration for Proximal Policy Optimization.

Implements PPO with clipped surrogate objective and GAE.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| gamma | float | Discount factor for computing returns. Default 0.99. |
| gae_lambda | float | Lambda for Generalized Advantage Estimation. Default 0.95. |
| clip_param | float | Clipping parameter for surrogate objective. Default 0.2. |
| vf_coeff | float | Coefficient for value function loss. Default 0.5. |
| entropy_coeff | float | Coefficient for entropy bonus. Default 0.01. |
| max_grad_norm | float | Maximum gradient norm for clipping. Default 0.5. |

gamma class-attribute instance-attribute ¤

gamma: float = 0.99

gae_lambda class-attribute instance-attribute ¤

gae_lambda: float = 0.95

clip_param class-attribute instance-attribute ¤

clip_param: float = 0.2

vf_coeff class-attribute instance-attribute ¤

vf_coeff: float = 0.5

entropy_coeff class-attribute instance-attribute ¤

entropy_coeff: float = 0.01

max_grad_norm class-attribute instance-attribute ¤

max_grad_norm: float = 0.5

Configuration Options¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| gamma | float | 0.99 | Discount factor for GAE |
| gae_lambda | float | 0.95 | Lambda for GAE (bias-variance trade-off) |
| clip_param | float | 0.2 | Clipping parameter epsilon |
| vf_coeff | float | 0.5 | Value function loss coefficient |
| entropy_coeff | float | 0.01 | Entropy bonus coefficient |
| max_grad_norm | float | 0.5 | Maximum gradient norm for clipping |

Algorithm¤

Clipped Surrogate Objective¤

PPO uses a clipped surrogate objective to prevent large policy updates:

\[\mathcal{L}^{CLIP} = \mathbb{E}\left[\min\left(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t\right)\right]\]

Where \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\) is the probability ratio.
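
The objective can be sketched directly in JAX. The snippet below is an illustrative stand-alone computation, not the trainer's internal implementation; it assumes per-step action log probabilities and advantages as matching 1-D arrays:

import jax.numpy as jnp

def clipped_surrogate_loss(log_probs, old_log_probs, advantages, clip_param=0.2):
    # Probability ratio r_t(theta) = pi_theta / pi_theta_old, computed in log space
    ratio = jnp.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    # Negated because optimizers minimize; the surrogate objective itself is maximized
    return -jnp.mean(jnp.minimum(unclipped, clipped))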

Generalized Advantage Estimation¤

GAE computes advantages using TD residuals:

\[A_t^{GAE} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}\]

Where \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) is the TD residual.

Lambda parameter:

  • \(\lambda = 0\): TD(0) - low variance, high bias
  • \(\lambda = 1\): Monte Carlo - high variance, low bias
  • \(\lambda = 0.95\): Good balance (default)
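
The GAE sum above can be computed with a backward recursion over the trajectory. The function below is an illustrative reference implementation (a plain Python loop, not the trainer's compute_gae), assuming rewards and dones of shape (T,) and values of shape (T+1,):

import jax.numpy as jnp

def gae_reference(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    # TD residuals: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    deltas = rewards + gamma * values[1:] * (1.0 - dones) - values[:-1]
    advantages = []
    gae = 0.0
    for delta, done in zip(deltas[::-1], dones[::-1]):
        # Accumulate (gamma * lambda)-discounted residuals, resetting at episode ends
        gae = delta + gamma * gae_lambda * (1.0 - done) * gae
        advantages.append(gae)
    return jnp.stack(advantages[::-1])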

Value Function Loss¤

\[\mathcal{L}^{VF} = (V_\theta(s_t) - V_t^{target})^2\]

Full Objective¤

\[\mathcal{L} = -\mathcal{L}^{CLIP} + c_1 \mathcal{L}^{VF} - c_2 H(\pi)\]
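
The sketch below shows how the three terms combine under the PPOConfig coefficients (c1 = vf_coeff, c2 = entropy_coeff), assuming the policy term is already negated as in the clipped-surrogate sketch above; it is not the trainer's exact code:

def total_loss(policy_loss, value_loss, entropy, vf_coeff=0.5, entropy_coeff=0.01):
    # policy_loss is the already-negated clipped objective (-L^CLIP), so the
    # total, policy_loss + c1 * L^VF - c2 * H(pi), is minimized directly.
    return policy_loss + vf_coeff * value_loss - entropy_coeff * entropy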

API Reference¤

artifex.generative_models.training.rl.ppo.PPOTrainer ¤

PPOTrainer(
    model: Module,
    optimizer: Optimizer,
    config: PPOConfig | None = None,
)

Proximal Policy Optimization trainer.

Implements PPO with:

  • Clipped surrogate objective for stable policy updates
  • GAE for advantage estimation
  • Value function fitting
  • Entropy bonus for exploration
  • Gradient clipping

The model must be an Actor-Critic that returns (action_logits, value).

Attributes:

| Name | Description |
| --- | --- |
| model | Actor-Critic network. |
| optimizer | Flax NNX optimizer. |
| config | PPO configuration. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | Actor-Critic network that returns (logits, value). | required |
| optimizer | Optimizer | Flax NNX optimizer for the model. | required |
| config | PPOConfig \| None | PPO configuration. Uses defaults if not provided. | None |

model instance-attribute ¤

model = model

optimizer instance-attribute ¤

optimizer = optimizer

config instance-attribute ¤

config = config if config is not None else PPOConfig()

compute_gae ¤

compute_gae(
    rewards: Array, values: Array, dones: Array
) -> Array

Compute Generalized Advantage Estimation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rewards | Array | Rewards with shape (T,). | required |
| values | Array | Values with shape (T+1,), including next state value. | required |
| dones | Array | Done flags with shape (T,). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Advantages with shape (T,). |
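
For example, with the shapes from the table above (the zero arrays are placeholders for real rollout data, and trainer is the PPOTrainer constructed in Quick Start):

import jax.numpy as jnp

T = 128
rewards = jnp.zeros((T,))     # per-step rewards from a rollout
values = jnp.zeros((T + 1,))  # critic values, including the bootstrap value for the final state
dones = jnp.zeros((T,))       # 1.0 where an episode terminated

advantages = trainer.compute_gae(rewards, values, dones)  # shape (T,)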

compute_clipped_loss ¤

compute_clipped_loss(
    log_probs: Array,
    old_log_probs: Array,
    advantages: Array,
) -> Array

Compute clipped surrogate policy loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_probs | Array | Current policy log probabilities. | required |
| old_log_probs | Array | Old policy log probabilities. | required |
| advantages | Array | Advantage estimates. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Clipped surrogate loss. |
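
The log probabilities here are those of the actions actually taken. How they are gathered from the model's (batch, num_actions) output is an assumption in the sketch below, not part of the documented API:

import jax.numpy as jnp

# Gather per-action log probs for the taken actions (illustrative gathering step)
action_log_probs = jnp.take_along_axis(
    all_log_probs, actions[:, None], axis=-1
).squeeze(-1)

policy_loss = trainer.compute_clipped_loss(action_log_probs, old_log_probs, advantages)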

compute_value_loss ¤

compute_value_loss(values: Array, returns: Array) -> Array

Compute value function loss (MSE).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| values | Array | Predicted values. | required |
| returns | Array | Target returns. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Value function loss. |

compute_entropy ¤

compute_entropy(log_probs: Array) -> Array

Compute policy entropy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_probs | Array | Log probabilities with shape (..., num_actions). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Mean entropy. |
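
For reference, the mean entropy of a categorical policy given full per-action log probabilities is the standard quantity sketched below (a stand-alone illustration, not necessarily the trainer's exact code):

import jax.numpy as jnp

def mean_entropy(log_probs):
    # H(pi) = -sum_a pi(a) * log pi(a), averaged over the leading (batch) dimensions
    probs = jnp.exp(log_probs)
    return -jnp.sum(probs * log_probs, axis=-1).mean()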

train_step ¤

train_step(
    batch: dict[str, Array],
) -> tuple[Array, dict[str, Any]]

Perform a single PPO training step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | dict[str, Array] | Dictionary with keys "states" (batch of states), "actions" (actions taken), "old_log_probs" (log probs from the old policy), "returns" (target returns), and "advantages" (advantage estimates). | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, dict[str, Any]] | Tuple of (loss, metrics_dict). |
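
Because the clipped objective bounds how far the policy can move per update, PPO is commonly run for several epochs over the same trajectory batch. The loop below is a usage sketch continuing from the batch built in Quick Start; the epoch count of 4 is an illustrative choice, not an API requirement:

for epoch in range(4):
    loss, metrics = trainer.train_step(batch)
    print(f"epoch {epoch}: loss={float(loss):.4f}, "
          f"policy_loss={float(metrics['policy_loss']):.4f}")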

Training Metrics¤

| Metric | Description |
| --- | --- |
| policy_loss | Clipped surrogate policy loss |
| value_loss | Value function MSE loss |
| entropy | Policy entropy (exploration measure) |

Model Requirements¤

PPO requires an actor-critic model that outputs both per-action log probabilities and state-value estimates:

class ActorCriticModel(nnx.Module):
    def __call__(self, observations) -> tuple[jax.Array, jax.Array]:
        """Forward pass returning (log_probs, values).

        Args:
            observations: State observations.

        Returns:
            Tuple of:
                - log_probs: Action log probabilities, shape (batch, num_actions)
                - values: State value estimates, shape (batch,)
        """
        ...
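
A minimal concrete implementation of this interface might look like the sketch below. It assumes a discrete-action task; obs_dim, num_actions, and the hidden width are illustrative placeholders, and a shared trunk with separate heads is just one reasonable design choice:

import jax
import jax.numpy as jnp
from flax import nnx

class ActorCriticModel(nnx.Module):
    def __init__(self, obs_dim: int = 8, num_actions: int = 4, hidden: int = 64,
                 *, rngs: nnx.Rngs):
        # Shared trunk feeding separate policy and value heads
        self.trunk = nnx.Linear(obs_dim, hidden, rngs=rngs)
        self.policy_head = nnx.Linear(hidden, num_actions, rngs=rngs)
        self.value_head = nnx.Linear(hidden, 1, rngs=rngs)

    def __call__(self, observations) -> tuple[jax.Array, jax.Array]:
        h = jnp.tanh(self.trunk(observations))
        log_probs = jax.nn.log_softmax(self.policy_head(h), axis=-1)  # (batch, num_actions)
        values = self.value_head(h).squeeze(-1)                       # (batch,)
        return log_probs, values

With defaults like these, the Quick Start call ActorCriticModel(rngs=nnx.Rngs(0)) works unchanged; in practice obs_dim and num_actions must match your environment.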

Hyperparameter Guidelines¤

Clip Parameter (epsilon)¤

  • 0.1-0.2: Standard range, 0.2 is most common
  • Lower values → more conservative updates
  • Higher values → larger policy changes allowed

GAE Lambda¤

  • 0.95: Good default for most tasks
  • 0.99: Lower bias, higher variance (longer-horizon tasks)
  • 0.9: Higher bias, lower variance (shorter-horizon tasks)

Value Function Coefficient¤

  • 0.5: Standard choice
  • Higher values → more emphasis on accurate value estimation
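
As a hedged illustration of these guidelines, a configuration for a longer-horizon task that favors conservative updates might nudge the defaults as follows (values are illustrative, not tuned recommendations):

config = PPOConfig(
    gamma=0.99,
    gae_lambda=0.99,   # lower bias, higher variance (longer horizons)
    clip_param=0.1,    # more conservative policy updates
    vf_coeff=0.5,
    entropy_coeff=0.01,
    max_grad_norm=0.5,
)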

Use Cases¤

PPO is recommended for:

  • Complex tasks: When REINFORCE is too unstable
  • Continuous control: Robotics, physics simulations
  • Games: Atari, board games, video games
  • Large models: When you can afford the value network memory

For memory-constrained settings, consider GRPO.