REINFORCE Trainer¤
Module: artifex.generative_models.training.rl.reinforce
The REINFORCE Trainer implements the basic policy gradient algorithm, with variance reduction through return normalization and an entropy bonus for exploration.
Overview¤
REINFORCE is the simplest policy gradient algorithm, computing gradient updates based on discounted returns. This implementation includes:
- Discounted Returns: Efficient backward pass computation
- Return Normalization: Variance reduction for stable training
- Entropy Bonus: Encourages exploration and prevents premature convergence
Quick Start¤
```python
from artifex.generative_models.training import (
    REINFORCEConfig,
    REINFORCETrainer,
)
from flax import nnx
import optax

# Create model and optimizer (PolicyModel is your own policy network
# that outputs action logits)
model = PolicyModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Configure REINFORCE
config = REINFORCEConfig(
    gamma=0.99,
    normalize_returns=True,
    entropy_coeff=0.01,
)
trainer = REINFORCETrainer(model, optimizer, config)

# Training step on a collected trajectory
trajectory = {
    "states": states,
    "actions": actions,
    "rewards": rewards,
}
loss, metrics = trainer.train_step(trajectory)
```
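For episodic tasks, the trajectory passed to `train_step` is usually one complete episode. Below is a minimal sketch of such a collection loop; it assumes a Gymnasium environment and that `PolicyModel` maps a batch of states to action logits (the environment and the sampling code are illustrative, not part of artifex).

```python
import gymnasium as gym
import jax
import jax.numpy as jnp

env = gym.make("CartPole-v1")
rng = jax.random.PRNGKey(0)

for _ in range(500):
    states, actions, rewards = [], [], []
    obs, _ = env.reset()
    done = False
    while not done:
        rng, key = jax.random.split(rng)
        # Assumed: the policy maps a (1, obs_dim) batch to (1, n_actions) logits.
        logits = model(jnp.asarray(obs)[None])
        action = int(jax.random.categorical(key, logits[0]))
        next_obs, reward, terminated, truncated, _ = env.step(action)
        states.append(obs)
        actions.append(action)
        rewards.append(float(reward))
        obs, done = next_obs, terminated or truncated

    trajectory = {
        "states": jnp.asarray(states),
        "actions": jnp.asarray(actions),
        "rewards": jnp.asarray(rewards),
    }
    loss, metrics = trainer.train_step(trajectory)
```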
Configuration¤
artifex.generative_models.training.rl.configs.REINFORCEConfig (dataclass)¤
```python
REINFORCEConfig(
    gamma: float = 0.99,
    normalize_returns: bool = True,
    entropy_coeff: float = 0.01,
)
```
Configuration for REINFORCE policy gradient algorithm.
Attributes:

| Name | Type | Description |
|---|---|---|
| `gamma` | `float` | Discount factor for computing returns. Default 0.99. |
| `normalize_returns` | `bool` | Whether to normalize returns for variance reduction. |
| `entropy_coeff` | `float` | Coefficient for entropy bonus to encourage exploration. |
Configuration Options¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| `gamma` | `float` | `0.99` | Discount factor for computing returns |
| `normalize_returns` | `bool` | `True` | Normalize returns for variance reduction |
| `entropy_coeff` | `float` | `0.01` | Coefficient for entropy bonus |
Algorithm¤
REINFORCE computes the policy gradient:

\[
\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, G_t\right]
\]

Where \(G_t = \sum_{k=t}^{T} \gamma^{k-t} r_k\) is the discounted return from time step \(t\).
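The Overview notes that returns are computed with an efficient backward pass. A standalone sketch of that computation in JAX, illustrating the formula above rather than the trainer's actual `compute_returns`:

```python
import jax.numpy as jnp

def discounted_returns(rewards: jnp.ndarray, gamma: float = 0.99) -> jnp.ndarray:
    """Compute G_t = sum_{k>=t} gamma^(k-t) r_k with a single backward pass."""
    returns = jnp.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(rewards.shape[0])):
        running = rewards[t] + gamma * running  # G_t = r_t + gamma * G_{t+1}
        returns = returns.at[t].set(running)
    return returns
```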
Variance Reduction¤
With `normalize_returns=True`, returns are standardized before they are used to weight the log-probabilities:

\[
\hat{G}_t = \frac{G_t - \mu_G}{\sigma_G + \epsilon}
\]

where \(\mu_G\) and \(\sigma_G\) are the batch mean and standard deviation of the returns and \(\epsilon\) is a small constant for numerical stability.
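A standalone sketch of this normalization, assuming standard per-batch statistics:

```python
import jax.numpy as jnp

def normalize(returns: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    # Standardize returns so gradient magnitudes stay roughly constant across batches.
    return (returns - returns.mean()) / (returns.std() + eps)
```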
Entropy Bonus¤
The entropy bonus encourages exploration and prevents the policy from collapsing prematurely; it is subtracted from the loss, weighted by `entropy_coeff`:

\[
\mathcal{L} = -\mathbb{E}\left[\log \pi_\theta(a \mid s)\, G\right] - c_{\text{ent}}\, H(\pi_\theta)
\]

Where \(H(\pi) = -\sum_a \pi(a \mid s) \log \pi(a \mid s)\) is the policy entropy and \(c_{\text{ent}}\) is `entropy_coeff`.
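A sketch of the entropy term, assuming a categorical policy over the model's action logits:

```python
import jax
import jax.numpy as jnp

def categorical_entropy(logits: jnp.ndarray) -> jnp.ndarray:
    """Mean entropy H(pi) over a batch of categorical action distributions."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # H(pi) = -sum_a pi(a|s) log pi(a|s), averaged over the batch
    return -(jnp.exp(log_probs) * log_probs).sum(axis=-1).mean()
```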
API Reference¤
artifex.generative_models.training.rl.reinforce.REINFORCETrainer¤
```python
REINFORCETrainer(
    model: Module,
    optimizer: Optimizer,
    config: REINFORCEConfig | None = None,
)
```
REINFORCE policy gradient trainer.
Implements the REINFORCE algorithm with variance reduction techniques:

- Discounted returns for credit assignment
- Return normalization to stabilize gradients
- Entropy bonus to encourage exploration
Attributes:

| Name | Type | Description |
|---|---|---|
| `model` | `Module` | Policy network (must output action logits). |
| `optimizer` | `Optimizer` | Flax NNX optimizer. |
| `config` | `REINFORCEConfig` | REINFORCE configuration. |
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Policy network that outputs action logits. | *required* |
| `optimizer` | `Optimizer` | Flax NNX optimizer for the model. | *required* |
| `config` | `REINFORCEConfig \| None` | REINFORCE configuration. Uses defaults if not provided. | `None` |
compute_returns¤

Compute discounted returns from a reward sequence (see Algorithm above).

normalize_returns¤

Normalize returns for variance reduction (see Variance Reduction above).

compute_loss¤
Compute REINFORCE loss.
Loss = -E[log(pi(a|s)) * R] - entropy_coeff * H(pi)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `states` | `Array` | Batch of states with shape (batch_size, ...). | *required* |
| `actions` | `Array` | Batch of actions taken with shape (batch_size,). | *required* |
| `returns` | `Array` | Discounted returns with shape (batch_size,). | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of (loss, metrics_dict). |
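For intuition, a standalone sketch of this loss for a categorical policy over logits; it illustrates the formula above and is not artifex's actual `compute_loss`:

```python
import jax
import jax.numpy as jnp

def reinforce_loss(
    logits: jnp.ndarray,      # (batch, n_actions) action logits from the policy
    actions: jnp.ndarray,     # (batch,) integer actions that were taken
    returns: jnp.ndarray,     # (batch,) discounted (and possibly normalized) returns
    entropy_coeff: float = 0.01,
) -> jnp.ndarray:
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # log pi(a_t | s_t) for the actions actually taken
    taken = jnp.take_along_axis(log_probs, actions[:, None], axis=-1).squeeze(-1)
    policy_loss = -(taken * returns).mean()
    entropy = -(jnp.exp(log_probs) * log_probs).sum(axis=-1).mean()
    # Loss = -E[log(pi(a|s)) * R] - entropy_coeff * H(pi)
    return policy_loss - entropy_coeff * entropy
```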
train_step¤
Perform a single training step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trajectory` | `dict[str, Array]` | Dictionary containing `"states"` (batch of states), `"actions"` (actions taken), and `"rewards"` (rewards received). | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of (loss, metrics_dict). |
Training Metrics¤
| Metric | Description |
|---|---|
| `policy_loss` | Policy gradient loss (negated for minimization) |
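A typical way to read these values after a step; only `policy_loss` is documented above, so treat any other keys as implementation-dependent:

```python
loss, metrics = trainer.train_step(trajectory)
print(f"loss={float(loss):.4f}  policy_loss={float(metrics['policy_loss']):.4f}")
```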
Use Cases¤
REINFORCE is best suited for:
- Simple baselines: Quick experiments before more sophisticated methods
- Low-dimensional action spaces: Works well when action space is small
- Research: Understanding policy gradient fundamentals
For more stable training, consider PPO or GRPO.
Related Documentation¤
- RL Training Guide - Comprehensive RL training guide
- PPO Trainer - More stable policy gradient training
- GRPO Trainer - Memory-efficient critic-free RL