-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add RLHF guide and dummy demo with Keras/JAX #2117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
This commit introduces a new example for Reinforcement Learning from Human Feedback (RLHF). It includes: - \`examples/rl/rlhf_dummy_demo.py\`: A Python script demonstrating a simple RLHF loop with a dummy environment, a policy model, and a reward model, using Keras with the JAX backend. - \`examples/rl/md/rlhf_dummy_demo.md\`: A Markdown guide explaining the RLHF concept and the implementation details of the demo script. - \`examples/rl/README.md\`: A new README for the RL examples section, now including the RLHF demo. Note: The Python demo script (\`rlhf_dummy_demo.py\`) currently experiences timeout issues during the training loop in the development environment, even with significantly reduced computational load. This is documented in the guide and README. The code serves as a structural example of implementing the RLHF components.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Yasir! just keep the .py file for now. once it is approved we can generate the .ipynb files
examples/rl/rlhf_dummy_demo.py
Outdated
policy_model_params["non_trainable"], | ||
state_input | ||
) | ||
actual_predictions_tensor = predictions_tuple[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we assuming batch size is 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. The code as written assumes a batch size of 1 for all model inputs and gradient calculations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was a simplification for the demo's clarity and to manage complexity, especially since REINFORCE-style updates can be done on single trajectories. A more advanced setup would definitely use batching across multiple episodes or from a replay buffer for stability and efficiency.
Does that make sense in the context of this simplified demo?
examples/rl/rlhf_dummy_demo.py
Outdated
episode_policy_losses.append(current_policy_loss) | ||
policy_grads_step = policy_grads_dict_step["trainable"] | ||
# Accumulate policy gradients | ||
for i, grad in enumerate(policy_grads_step): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For potentially improving performance
policy_grads_accum = jax.tree_map(lambda acc, new: acc + new if new is not None else acc, policy_grads_accum, policy_grads_step)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I have refactored policy gradients and reward gradients accumulations using jax.tree_map
.
examples/rl/md/rlhf_dummy_demo.md
Outdated
actual_predictions_tensor = predictions_tuple[0] | ||
action_probs = actual_predictions_tensor[0] | ||
log_prob = jnp.log(action_probs[action] + 1e-7) | ||
return -log_prob * predicted_reward_value_stopped |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this predicted_reward_value is just R(s,a), then it's using the immediate predicted reward.
it's a very naive and generally ineffective way to train a policy. Thee log_prob should be multiplied by the cumulative discounted future reward (Return, G_t).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Commit c212593 addresses this.
…rewards. This commit refines the RLHF demo example (`examples/rl/rlhf_dummy_demo.py`) to use discounted cumulative actual rewards (G_t) for policy gradient calculations, aligning it with the REINFORCE algorithm. Changes include: - Added a `calculate_discounted_returns` helper function. - Modified the `rlhf_training_loop` to collect trajectories (states, actions, rewards) and compute G_t for each step at the end of an episode. - Updated the policy loss function to use these G_t values instead of immediate predicted rewards. - The reward model training logic remains focused on predicting immediate rewards based on simulated human feedback (environment reward in this demo). - Updated the corresponding RLHF guide (`examples/rl/md/rlhf_dummy_demo.md`) to explain these changes and provide updated code snippets. The timeout issues with the script in the development environment persist, but the code now better reflects a standard policy gradient approach.
the .md files are automatically generated. SO you might want to move the explanation content part to .py |
Moving the explanation piece to .py from .md
I have deleted the .md files and added relevant documentation pieces to .py file |
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new example demonstrating a simplified Reinforcement Learning from Human Feedback (RLHF) loop using Keras with a JAX backend. The example includes a Python script with an embedded markdown guide.
My review focuses on improving clarity, correctness, and maintainability. Key findings include:
- Inconsistent terminology (RLAIF vs. RLHF) that should be standardized.
- A critical discrepancy between the documentation and the code regarding the policy update algorithm.
- A bug in the logging logic that prevents any training progress from being displayed with the current settings.
- A minor performance issue in a helper function and a formatting error in the documentation.
I've provided specific suggestions to address these points.
# Reinforcement Learning from AI Feedback(RLAIF) - Demo Guide | ||
This guide explains the concept of Reinforcement Learning from AI Feedback (RLAIF) and walks through the components of the accompanying script `rlhf_demo.py`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The guide's main title and introduction use "RLAIF" (Reinforcement Learning from AI Feedback), but the PR and the rest of the guide refer to "RLHF" (Reinforcement Learning from Human Feedback). This is inconsistent and confusing. Please use "RLHF" consistently. The metadata in lines 2 and 6 should also be updated to reflect this change.
# Reinforcement Learning from AI Feedback(RLAIF) - Demo Guide | |
This guide explains the concept of Reinforcement Learning from AI Feedback (RLAIF) and walks through the components of the accompanying script `rlhf_demo.py`. | |
# Reinforcement Learning from Human Feedback (RLHF) - Demo Guide | |
This guide explains the concept of Reinforcement Learning from Human Feedback (RLHF) and walks through the components of the accompanying script `rlhf_demo.py`. |
total_reward_loss_avg += mean_episode_reward_loss | ||
loss_count_avg +=1 | ||
|
||
if (episode + 1) % 100 == 0 and loss_count_avg > 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The training progress is logged only when (episode + 1) % 100 == 0
. However, the training loop is configured to run for only 10 episodes (both in the function default and the __main__
block call). As a result, no progress will be printed during the run. The logging interval should be adjusted to be useful for a short demo run.
if (episode + 1) % 100 == 0 and loss_count_avg > 0: | |
if (episode + 1) % 5 == 0 and loss_count_avg > 0: |
4. **Policy Model Training (REINFORCE-style):** | ||
* **At the end of each episode:** | ||
* The `calculate_discounted_returns` function is called with the list of `true_reward`s collected during the episode to compute the discounted cumulative reward (G_t) for each step. | ||
* These returns are typically normalized (subtract mean, divide by standard deviation) to stabilize training. | ||
* The code then iterates through each step `t` of the collected trajectory. | ||
* For each step, the `policy_loss_fn` is called. Its loss is calculated as `-log_prob(action_t) * G_t`. This means the update encourages actions that led to higher overall discounted future rewards. | ||
* Gradients for the policy model are computed for each step and accumulated across the entire episode. | ||
* **Gradient Application:** The accumulated policy gradients are averaged over the number of steps and applied to the policy model using its optimizer. This update rule aims to increase the probability of actions that lead to good long-term outcomes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation here describes a REINFORCE-style policy update where discounted returns (G_t
) are calculated at the end of an episode from the true_reward
s. However, the actual implementation in rlhf_training_loop
performs an update at each step using the predicted_reward_value
from the reward model as the signal for the policy update (see lines 192-208). The calculate_discounted_returns
function is defined but never called.
This is a significant discrepancy that misleads users about how the algorithm is implemented. The documentation should be updated to accurately reflect the code's actor-critic-like implementation, where the reward model's output is used as the advantage signal for the policy update.
def calculate_discounted_returns(rewards, gamma=0.99): | ||
returns = [] | ||
cumulative_return = 0 | ||
for r in reversed(rewards): | ||
cumulative_return = r + gamma * cumulative_return | ||
returns.insert(0, cumulative_return) | ||
return jnp.array(returns) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation of calculate_discounted_returns
uses returns.insert(0, ...)
inside a loop. Inserting at the beginning of a Python list is an O(n) operation, making the whole function O(n^2) where n is the number of rewards. While this function is not currently used, the PR description mentions timeout issues, so performance is a valid concern.
A more performant O(n) approach is to append to the list and then reverse it once at the end.
def calculate_discounted_returns(rewards, gamma=0.99): | |
returns = [] | |
cumulative_return = 0 | |
for r in reversed(rewards): | |
cumulative_return = r + gamma * cumulative_return | |
returns.insert(0, cumulative_return) | |
return jnp.array(returns) | |
def calculate_discounted_returns(rewards, gamma=0.99): | |
# A more performant O(n) approach is to append to the list and then reverse it. | |
returns = [] | |
cumulative_return = 0 | |
for r in reversed(rewards): | |
cumulative_return = r + gamma * cumulative_return | |
returns.append(cumulative_return) | |
return jnp.array(returns[::-1]) |
* **Gradient Application:** The accumulated policy gradients are averaged over the number of steps and applied to the policy model using its optimizer. This update rule aims to increase the probability of actions that lead to good long-term outcomes. | ||
5. **Logging:** Average policy and reward losses for the episode are printed periodically. | ||
The core idea of RLHF is still present: we have a reward model that *could* be trained from human preferences. However, the policy update mechanism has shifted. Instead of using the reward model's output directly as the advantage signal for each step (as in the previous version of the script), the policy now learns from the actual discounted returns experienced in the episode, which is a more standard RL approach when actual rewards (or good proxies like `true_reward` here) are available for the full trajectory. In a full RLHF system, `episode_true_rewards` might themselves be replaced or augmented by the reward model's predictions if no dense "true" reward exists. | ||
8. **Logging:** Periodically, average losses are printed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This commit introduces a new example for Reinforcement Learning from Human Feedback (RLHF).
It includes:
Note: The Python demo script (`rlhf_dummy_demo.py`) currently experiences timeout issues during the training loop in the development environment, even with significantly reduced computational load. This is documented in the guide and README. The code serves as a structural example of implementing the RLHF components.