Skip to content

Conversation

TrailChai
Copy link

This commit introduces a new example for Reinforcement Learning from Human Feedback (RLHF).

It includes:

  • `examples/rl/rlhf_dummy_demo.py`: A Python script demonstrating a simple RLHF loop with a dummy environment, a policy model, and a reward model, using Keras with the JAX backend.
  • `examples/rl/md/rlhf_dummy_demo.md`: A Markdown guide explaining the RLHF concept and the implementation details of the demo script.
  • `examples/rl/README.md`: A new README for the RL examples section, now including the RLHF demo.

Note: The Python demo script (`rlhf_dummy_demo.py`) currently experiences timeout issues during the training loop in the development environment, even with significantly reduced computational load. This is documented in the guide and README. The code serves as a structural example of implementing the RLHF components.

This commit introduces a new example for Reinforcement Learning from Human Feedback (RLHF).

It includes:
- \`examples/rl/rlhf_dummy_demo.py\`: A Python script demonstrating a simple RLHF loop with a dummy environment, a policy model, and a reward model, using Keras with the JAX backend.
- \`examples/rl/md/rlhf_dummy_demo.md\`: A Markdown guide explaining the RLHF concept and the implementation details of the demo script.
- \`examples/rl/README.md\`: A new README for the RL examples section, now including the RLHF demo.

Note: The Python demo script (\`rlhf_dummy_demo.py\`) currently experiences timeout issues during the training loop in the development environment, even with significantly reduced computational load. This is documented in the guide and README. The code serves as a structural example of implementing the RLHF components.
Copy link
Collaborator

@divyashreepathihalli divyashreepathihalli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Yasir! just keep the .py file for now. once it is approved we can generate the .ipynb files

policy_model_params["non_trainable"],
state_input
)
actual_predictions_tensor = predictions_tuple[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we assuming batch size is 1?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The code as written assumes a batch size of 1 for all model inputs and gradient calculations

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is that?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a simplification for the demo's clarity and to manage complexity, especially since REINFORCE-style updates can be done on single trajectories. A more advanced setup would definitely use batching across multiple episodes or from a replay buffer for stability and efficiency.

Does that make sense in the context of this simplified demo?

episode_policy_losses.append(current_policy_loss)
policy_grads_step = policy_grads_dict_step["trainable"]
# Accumulate policy gradients
for i, grad in enumerate(policy_grads_step):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For potentially improving performance
policy_grads_accum = jax.tree_map(lambda acc, new: acc + new if new is not None else acc, policy_grads_accum, policy_grads_step)

Copy link
Author

@TrailChai TrailChai Jun 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I have refactored policy gradients and reward gradients accumulations using jax.tree_map.

actual_predictions_tensor = predictions_tuple[0]
action_probs = actual_predictions_tensor[0]
log_prob = jnp.log(action_probs[action] + 1e-7)
return -log_prob * predicted_reward_value_stopped
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this predicted_reward_value is just R(s,a), then it's using the immediate predicted reward.
it's a very naive and generally ineffective way to train a policy. Thee log_prob should be multiplied by the cumulative discounted future reward (Return, G_t).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commit c212593 addresses this.

TrailChai and others added 7 commits June 3, 2025 11:35
…rewards.

This commit refines the RLHF demo example (`examples/rl/rlhf_dummy_demo.py`)
to use discounted cumulative actual rewards (G_t) for policy gradient
calculations, aligning it with the REINFORCE algorithm.

Changes include:
- Added a `calculate_discounted_returns` helper function.
- Modified the `rlhf_training_loop` to collect trajectories (states,
  actions, rewards) and compute G_t for each step at the end of an episode.
- Updated the policy loss function to use these G_t values instead of
  immediate predicted rewards.
- The reward model training logic remains focused on predicting immediate
  rewards based on simulated human feedback (environment reward in this demo).
- Updated the corresponding RLHF guide (`examples/rl/md/rlhf_dummy_demo.md`)
  to explain these changes and provide updated code snippets.

The timeout issues with the script in the development environment persist,
but the code now better reflects a standard policy gradient approach.
@divyashreepathihalli
Copy link
Collaborator

the .md files are automatically generated. SO you might want to move the explanation content part to .py

TrailChai added 2 commits June 6, 2025 23:06
Moving the explanation piece to .py from .md
@TrailChai
Copy link
Author

I have deleted the .md files and added relevant documentation pieces to .py file

@divyashreepathihalli
Copy link
Collaborator

/gemini review

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new example demonstrating a simplified Reinforcement Learning from Human Feedback (RLHF) loop using Keras with a JAX backend. The example includes a Python script with an embedded markdown guide.

My review focuses on improving clarity, correctness, and maintainability. Key findings include:

  • Inconsistent terminology (RLAIF vs. RLHF) that should be standardized.
  • A critical discrepancy between the documentation and the code regarding the policy update algorithm.
  • A bug in the logging logic that prevents any training progress from being displayed with the current settings.
  • A minor performance issue in a helper function and a formatting error in the documentation.

I've provided specific suggestions to address these points.

Comment on lines +11 to +12
# Reinforcement Learning from AI Feedback(RLAIF) - Demo Guide
This guide explains the concept of Reinforcement Learning from AI Feedback (RLAIF) and walks through the components of the accompanying script `rlhf_demo.py`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The guide's main title and introduction use "RLAIF" (Reinforcement Learning from AI Feedback), but the PR and the rest of the guide refer to "RLHF" (Reinforcement Learning from Human Feedback). This is inconsistent and confusing. Please use "RLHF" consistently. The metadata in lines 2 and 6 should also be updated to reflect this change.

Suggested change
# Reinforcement Learning from AI Feedback(RLAIF) - Demo Guide
This guide explains the concept of Reinforcement Learning from AI Feedback (RLAIF) and walks through the components of the accompanying script `rlhf_demo.py`.
# Reinforcement Learning from Human Feedback (RLHF) - Demo Guide
This guide explains the concept of Reinforcement Learning from Human Feedback (RLHF) and walks through the components of the accompanying script `rlhf_demo.py`.

total_reward_loss_avg += mean_episode_reward_loss
loss_count_avg +=1

if (episode + 1) % 100 == 0 and loss_count_avg > 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The training progress is logged only when (episode + 1) % 100 == 0. However, the training loop is configured to run for only 10 episodes (both in the function default and the __main__ block call). As a result, no progress will be printed during the run. The logging interval should be adjusted to be useful for a short demo run.

Suggested change
if (episode + 1) % 100 == 0 and loss_count_avg > 0:
if (episode + 1) % 5 == 0 and loss_count_avg > 0:

Comment on lines +273 to +280
4. **Policy Model Training (REINFORCE-style):**
* **At the end of each episode:**
* The `calculate_discounted_returns` function is called with the list of `true_reward`s collected during the episode to compute the discounted cumulative reward (G_t) for each step.
* These returns are typically normalized (subtract mean, divide by standard deviation) to stabilize training.
* The code then iterates through each step `t` of the collected trajectory.
* For each step, the `policy_loss_fn` is called. Its loss is calculated as `-log_prob(action_t) * G_t`. This means the update encourages actions that led to higher overall discounted future rewards.
* Gradients for the policy model are computed for each step and accumulated across the entire episode.
* **Gradient Application:** The accumulated policy gradients are averaged over the number of steps and applied to the policy model using its optimizer. This update rule aims to increase the probability of actions that lead to good long-term outcomes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The documentation here describes a REINFORCE-style policy update where discounted returns (G_t) are calculated at the end of an episode from the true_rewards. However, the actual implementation in rlhf_training_loop performs an update at each step using the predicted_reward_value from the reward model as the signal for the policy update (see lines 192-208). The calculate_discounted_returns function is defined but never called.

This is a significant discrepancy that misleads users about how the algorithm is implemented. The documentation should be updated to accurately reflect the code's actor-critic-like implementation, where the reward model's output is used as the advantage signal for the policy update.

Comment on lines +47 to +53
def calculate_discounted_returns(rewards, gamma=0.99):
returns = []
cumulative_return = 0
for r in reversed(rewards):
cumulative_return = r + gamma * cumulative_return
returns.insert(0, cumulative_return)
return jnp.array(returns)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The implementation of calculate_discounted_returns uses returns.insert(0, ...) inside a loop. Inserting at the beginning of a Python list is an O(n) operation, making the whole function O(n^2) where n is the number of rewards. While this function is not currently used, the PR description mentions timeout issues, so performance is a valid concern.

A more performant O(n) approach is to append to the list and then reverse it once at the end.

Suggested change
def calculate_discounted_returns(rewards, gamma=0.99):
returns = []
cumulative_return = 0
for r in reversed(rewards):
cumulative_return = r + gamma * cumulative_return
returns.insert(0, cumulative_return)
return jnp.array(returns)
def calculate_discounted_returns(rewards, gamma=0.99):
# A more performant O(n) approach is to append to the list and then reverse it.
returns = []
cumulative_return = 0
for r in reversed(rewards):
cumulative_return = r + gamma * cumulative_return
returns.append(cumulative_return)
return jnp.array(returns[::-1])

* **Gradient Application:** The accumulated policy gradients are averaged over the number of steps and applied to the policy model using its optimizer. This update rule aims to increase the probability of actions that lead to good long-term outcomes.
5. **Logging:** Average policy and reward losses for the episode are printed periodically.
The core idea of RLHF is still present: we have a reward model that *could* be trained from human preferences. However, the policy update mechanism has shifted. Instead of using the reward model's output directly as the advantage signal for each step (as in the previous version of the script), the policy now learns from the actual discounted returns experienced in the episode, which is a more standard RL approach when actual rewards (or good proxies like `true_reward` here) are available for the full trajectory. In a full RLHF system, `episode_true_rewards` might themselves be replaced or augmented by the reward model's predictions if no dense "true" reward exists.
8. **Logging:** Periodically, average losses are printed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The numbering in this list of key parts of the training loop is incorrect. It jumps from point 5 to point 8. This should be corrected to 6 for clarity.

Suggested change
8. **Logging:** Periodically, average losses are printed.
6. **Logging:** Periodically, average losses are printed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants