
REINFORCE Trainer¤

Module: artifex.generative_models.training.rl.reinforce

The REINFORCE Trainer implements the basic policy gradient algorithm, with variance reduction through return normalization and an entropy bonus to encourage exploration.

Overview¤

REINFORCE is the simplest policy gradient algorithm: it weights the gradient of each action's log-probability by the discounted return that follows it. This implementation includes:

  • Discounted Returns: Efficient backward pass computation
  • Return Normalization: Variance reduction for stable training
  • Entropy Bonus: Encourages exploration and prevents premature convergence

Quick Start¤

from artifex.generative_models.training import (
    REINFORCEConfig,
    REINFORCETrainer,
)
from flax import nnx
import optax

# Create model and optimizer (PolicyModel stands in for your policy
# network, which must output action logits)
model = PolicyModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Configure REINFORCE
config = REINFORCEConfig(
    gamma=0.99,
    normalize_returns=True,
    entropy_coeff=0.01,
)

trainer = REINFORCETrainer(model, optimizer, config)

# Training step on one collected trajectory (arrays gathered from a
# rollout; see the sketch below)
trajectory = {
    "states": states,
    "actions": actions,
    "rewards": rewards,
}
loss, metrics = trainer.train_step(trajectory)
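
The arrays in the trajectory can come from any rollout loop. A minimal sketch with synthetic data, assuming a hypothetical episode length of 5, an 8-dimensional observation, and 4 discrete actions, shows the expected shapes:

import numpy as np
import jax.numpy as jnp

# Synthetic single-episode data; only the shapes matter here:
# states (T, obs_dim), actions (T,), rewards (T,).
T, obs_dim, num_actions = 5, 8, 4
rng = np.random.default_rng(0)

trajectory = {
    "states": jnp.asarray(rng.normal(size=(T, obs_dim)), dtype=jnp.float32),
    "actions": jnp.asarray(rng.integers(0, num_actions, size=(T,))),
    "rewards": jnp.asarray(rng.normal(size=(T,)), dtype=jnp.float32),
}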

Configuration¤

artifex.generative_models.training.rl.configs.REINFORCEConfig dataclass ¤

REINFORCEConfig(
    gamma: float = 0.99,
    normalize_returns: bool = True,
    entropy_coeff: float = 0.01,
)

Configuration for REINFORCE policy gradient algorithm.

Attributes:

Name               Type   Description
gamma              float  Discount factor for computing returns. Default 0.99.
normalize_returns  bool   Whether to normalize returns for variance reduction.
entropy_coeff      float  Coefficient for entropy bonus to encourage exploration.

gamma class-attribute instance-attribute ¤

gamma: float = 0.99

normalize_returns class-attribute instance-attribute ¤

normalize_returns: bool = True

entropy_coeff class-attribute instance-attribute ¤

entropy_coeff: float = 0.01

Configuration Options¤

Parameter          Type   Default  Description
gamma              float  0.99     Discount factor for computing returns
normalize_returns  bool   True     Normalize returns for variance reduction
entropy_coeff      float  0.01     Coefficient for entropy bonus

Algorithm¤

REINFORCE computes the policy gradient:

\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right]\]

Where \(G_t = \sum_{k=t}^{T} \gamma^{k-t} r_k\) is the discounted return from time step \(t\).
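
As a standalone sketch (not the trainer's internal code), the returns can be computed with a single backward recurrence \(G_t = r_t + \gamma G_{t+1}\), for example via jax.lax.scan:

import jax
import jax.numpy as jnp

def discounted_returns(rewards: jnp.ndarray, gamma: float = 0.99) -> jnp.ndarray:
    # G_t = r_t + gamma * G_{t+1}, computed from the last step backwards.
    def step(carry, reward):
        g = reward + gamma * carry
        return g, g

    # reverse=True runs the recurrence from t = T-1 down to 0 and keeps the
    # stacked outputs aligned with the original time order.
    _, returns = jax.lax.scan(step, jnp.zeros((), rewards.dtype), rewards, reverse=True)
    return returns

print(discounted_returns(jnp.array([1.0, 0.0, 1.0])))  # ~[1.9801, 0.99, 1.0]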

Variance Reduction¤

With normalize_returns=True, returns are standardized to zero mean and unit variance before being used in the loss:

\[\hat{G}_t = \frac{G_t - \mu_G}{\sigma_G + \epsilon}\]
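
A sketch of this standardization in JAX, with \(\epsilon = 10^{-8}\) assumed here purely to avoid division by zero (the trainer's actual epsilon may differ):

import jax.numpy as jnp

def standardize(returns: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    # Zero mean, unit variance; eps guards against constant returns.
    return (returns - returns.mean()) / (returns.std() + eps)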

Entropy Bonus¤

The entropy bonus, weighted by entropy_coeff (\(\lambda_H\) below), penalizes near-deterministic policies and keeps exploration alive. The full loss is:

\[\mathcal{L} = -\mathbb{E}[\log \pi(a|s) \cdot G] - \lambda_H H(\pi)\]

Where \(H(\pi) = -\sum \pi(a|s) \log \pi(a|s)\) is the policy entropy.
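
The two terms combine into a single scalar loss. A sketch for a discrete policy whose model outputs per-action logits (function and variable names here are illustrative, not the trainer's internals):

import jax
import jax.numpy as jnp

def reinforce_loss(logits, actions, returns, entropy_coeff=0.01):
    log_probs = jax.nn.log_softmax(logits)                # (batch, num_actions)
    taken = jnp.take_along_axis(
        log_probs, actions[:, None], axis=1
    ).squeeze(-1)                                         # log pi(a_t | s_t)
    policy_loss = -jnp.mean(taken * returns)              # -E[log pi(a|s) * G]
    entropy = -jnp.mean(jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1))
    return policy_loss - entropy_coeff * entropy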

API Reference¤

artifex.generative_models.training.rl.reinforce.REINFORCETrainer ¤

REINFORCETrainer(
    model: Module,
    optimizer: Optimizer,
    config: REINFORCEConfig | None = None,
)

REINFORCE policy gradient trainer.

Implements the REINFORCE algorithm with variance reduction techniques:

  • Discounted returns for credit assignment
  • Return normalization to stabilize gradients
  • Entropy bonus to encourage exploration

Attributes:

Name       Description
model      Policy network (must output action logits).
optimizer  Flax NNX optimizer.
config     REINFORCE configuration.

Parameters:

Name       Type                    Description                                               Default
model      Module                  Policy network that outputs action logits.                required
optimizer  Optimizer               Flax NNX optimizer for the model.                         required
config     REINFORCEConfig | None  REINFORCE configuration. Uses defaults if not provided.   None

model instance-attribute ¤

model = model

optimizer instance-attribute ¤

optimizer = optimizer

config instance-attribute ¤

config = config if config is not None else REINFORCEConfig()

compute_returns ¤

compute_returns(rewards: Array) -> Array

Compute discounted returns from rewards.

Parameters:

Name     Type   Description                        Default
rewards  Array  Array of rewards with shape (T,).  required

Returns:

Type   Description
Array  Array of discounted returns with shape (T,).
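
For example, with the default gamma of 0.99 a hypothetical three-step reward sequence gives:

import jax.numpy as jnp

returns = trainer.compute_returns(jnp.array([1.0, 0.0, 1.0]))
# G_2 = 1.0, G_1 = 0.99 * 1.0 = 0.99, G_0 = 1.0 + 0.99 * 0.99 = 1.9801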

normalize_returns ¤

normalize_returns(returns: Array) -> Array

Normalize returns to zero mean and unit variance.

Parameters:

Name     Type   Description                      Default
returns  Array  Array of returns to normalize.   required

Returns:

Type   Description
Array  Normalized returns.

compute_loss ¤

compute_loss(
    states: Array, actions: Array, returns: Array
) -> tuple[Array, dict[str, Any]]

Compute REINFORCE loss.

Loss = -E[log(pi(a|s)) * R] - entropy_coeff * H(pi)

Parameters:

Name     Type   Description                                       Default
states   Array  Batch of states with shape (batch_size, ...).     required
actions  Array  Batch of actions taken with shape (batch_size,).  required
returns  Array  Discounted returns with shape (batch_size,).      required

Returns:

Type                          Description
tuple[Array, dict[str, Any]]  Tuple of (loss, metrics_dict).
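
The loss is normally computed inside train_step, but it can also be called directly on precomputed returns. A sketch for a single-episode batch (batch_size equal to the episode length T), following the shapes documented above:

loss, metrics = trainer.compute_loss(
    states=trajectory["states"],                             # (batch_size, ...)
    actions=trajectory["actions"],                           # (batch_size,)
    returns=trainer.compute_returns(trajectory["rewards"]),  # (batch_size,)
)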

train_step ¤

train_step(
    trajectory: dict[str, Array],
) -> tuple[Array, dict[str, Any]]

Perform a single training step.

Parameters:

Name        Type              Description                                   Default
trajectory  dict[str, Array]  Dictionary with "states" (batch of states),   required
                              "actions" (actions taken), and "rewards"
                              (rewards received).

Returns:

Type                          Description
tuple[Array, dict[str, Any]]  Tuple of (loss, metrics_dict).
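
A minimal training loop then just repeats the step over freshly collected episodes; collect_episode below is a hypothetical stand-in for whatever rollout code builds the trajectory dict:

for episode in range(1000):
    trajectory = collect_episode(env, model)   # hypothetical rollout helper
    loss, metrics = trainer.train_step(trajectory)
    if episode % 100 == 0:
        print(f"episode {episode}: loss={float(loss):.4f}")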

Training Metrics¤

Metric       Description
policy_loss  Policy gradient loss (negated for minimization)

Use Cases¤

REINFORCE is best suited for:

  • Simple baselines: Quick experiments before more sophisticated methods
  • Low-dimensional action spaces: Works well when the action space is small
  • Research: Understanding policy gradient fundamentals

For more stable training, consider PPO or GRPO.