GRPO Trainer#1020

Open
michaelbenayoun wants to merge 102 commits into main from grpo
Conversation

@michaelbenayoun
Member

@michaelbenayoun michaelbenayoun commented Nov 4, 2025

What does this PR do?

This PR adds partial support for GRPO.

It was broken down into smaller PRs.

It adds the NeuronGRPOTrainer with a set of optimizations and modifications for the Torch XLA backend used to run things on Trainium instances. There are still core missing features:

  • Integration with vLLM: a custom CPU vLLM hack is used for now. The plan is to work on the vLLM part in another PR.
  • Weight synchronization between NeuronGRPOTrainer and vLLM
  • Tensor parallelism
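To give a sense of what the CPU stand-in can look like while vLLM integration is pending, here is a minimal mock-client sketch. The class name, the `generate` signature, and the `update_named_param` method are assumptions for illustration only, not the interface shipped in this PR:

```python
class MockVLLMClient:
    """Returns canned completions so a GRPO loop can run without a vLLM server."""

    def __init__(self, canned_token_ids):
        # Fixed token ids to return for every prompt.
        self.canned_token_ids = list(canned_token_ids)

    def generate(self, prompts, n=1, max_tokens=16):
        # One group of n completions per prompt, mirroring GRPO's
        # group-relative sampling layout.
        completion = self.canned_token_ids[:max_tokens]
        return [[list(completion) for _ in range(n)] for _ in prompts]

    def update_named_param(self, name, weights):
        # Weight synchronization is a no-op in the mock; the real client
        # would push updated policy weights to the inference server.
        pass
```

A mock like this keeps the trainer's control flow testable while the real weight-synchronization path is still unimplemented.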

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

Copilot AI left a comment


Pull request overview

This PR adds partial support for GRPO (Group Relative Policy Optimization) training on Neuron (Trainium) devices through the new NeuronGRPOTrainer class. The implementation includes XLA-specific optimizations and modifications to work with the Torch XLA backend, though several core features remain unimplemented (vLLM integration, weight synchronization, tensor parallelism).

Changes:

  • Adds NeuronGRPOTrainer with XLA-optimized implementations for generation, scoring, and loss computation
  • Introduces NeuronGRPOConfig for configuration with experimental flag requirement
  • Implements XLA-friendly utility functions (padding, entropy, statistical operations) in trl_utils.py
  • Adds custom vLLM client implementations with CPU communicator and mock client for testing
  • Updates NeuronTrainer to support _prepare_inputs hook and replaces xm.mark_step() with torch_xla.sync()
  • Modifies LoRA transformation utilities to handle missing weights more gracefully
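To give a sense of why the utilities in trl_utils.py are shaped the way they are: XLA compiles a new graph for every tensor shape it encounters, so padding every batch to one static length keeps the compiled graph stable across steps. A pure-Python sketch of the idea (the function name and signature are illustrative, not the actual API in this PR; the real utilities operate on torch tensors):

```python
def pad_to_fixed_length(seqs, max_len, pad_value=0, left=False):
    """Pad variable-length sequences to a single static length.

    Padding to the batch maximum would produce a different shape per batch
    and trigger XLA recompilation; padding to a fixed max_len avoids that.
    """
    out = []
    for seq in seqs:
        pad = [pad_value] * (max_len - len(seq))
        # Left padding is typical for decoder-only generation inputs.
        out.append(pad + list(seq) if left else list(seq) + pad)
    return out
```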

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 19 comments.

Summary per file:

  • optimum/neuron/trainers/grpo_trainer.py: Core GRPO trainer implementation with XLA optimizations (1414 lines, new file)
  • optimum/neuron/trainers/grpo_config.py: Configuration class with validation and experimental flag (118 lines, new file)
  • optimum/neuron/trainers/trl_utils.py: XLA-optimized utility functions for padding, statistics, and sampling (270 lines)
  • optimum/neuron/trainers/extras/vllm_client.py: Custom vLLM clients for Neuron with CPU communicator and mock implementation (213 lines, new file)
  • optimum/neuron/trainers/transformers.py: Updates to NeuronTrainer for the _prepare_inputs hook and the torch_xla.sync() migration
  • optimum/neuron/trainers/utils.py: Adds the move_inputs_to_device utility and updates XLAPrefetchIterator
  • optimum/neuron/models/training/transformations_utils.py: Converts LoRA weight errors to silent skips for flexibility
  • optimum/neuron/trainers/metrics/collector.py: Refactors get_metric_unit for cleaner logic
  • optimum/neuron/utils/__init__.py: Exports the is_vllm_available function
  • optimum/neuron/__init__.py: Exports NeuronGRPOTrainer and NeuronGRPOConfig
  • .github/actions/install_optimum_neuron/action.yml: Adds training extras to CI installation


Comment on lines 703 to +704
if to_concat_and_duplicate_name is None or to_unfuse_name is None:
raise ValueError(
f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}."
)
continue

Copilot AI Feb 5, 2026


Similar to the previous issue, this converts a hard error into a silent skip. This could hide configuration problems. Consider logging when weights are not found to aid debugging.
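A hedged sketch of the suggested log-and-skip pattern (the lookup function, its arguments, and the logger name are hypothetical, for illustration only):

```python
import logging

logger = logging.getLogger("optimum.neuron")


def find_lora_weights(available_weights, module_fully_qualified_name, param_name):
    """Look up LoRA weights, logging a warning instead of raising when absent.

    Skipping silently can hide configuration problems; a warning preserves
    the flexibility of the skip while leaving a trail for debugging.
    """
    key = (module_fully_qualified_name, param_name)
    if key not in available_weights:
        logger.warning(
            "Could not find LoRA weights for %s with param name %s; skipping.",
            module_fully_qualified_name,
            param_name,
        )
        return None
    return available_weights[key]
```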

Collaborator

@dacorvo dacorvo left a comment


No more blockers, but I will review in more detail tomorrow. Copilot detected some issues that may be worth considering.

@dacorvo dacorvo dismissed their stale review February 5, 2026 18:38

Blocker addressed

Collaborator

@JingyaHuang JingyaHuang left a comment


Nothing to add from my side, if @tengomucho is OK with it as well.

@dacorvo
Collaborator

dacorvo commented Feb 12, 2026

@michaelbenayoun can you rebase or merge the main branch? It looks like you are using the non-working sanity changes that @tengomucho fixed in #1077. As a consequence, the training tests are not launched.

Collaborator

@tengomucho tengomucho left a comment


It is hard for me to follow exactly what this allows, so I will have to trust you on this.
I just ask for the sanity checks to be run first, as mentioned by @dacorvo.
Generally speaking, if you add a code feature, it would be better to add a test for it, don't you think?


self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
xm.mark_step()
torch_xla.sync()
Collaborator


I think mark_step is deprecated, so this is a good change, but don't you think it would be better to change all occurrences of this?

Member Author


You mean in the whole library?

It still exists in tests, but in the library the only occurrence is in `optimum/neuron/generation/utils.py`; I will update it.
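For reference, the migration is mechanical. A hedged sketch of a compatibility shim, assuming `torch_xla.sync()` is available on recent torch_xla releases as the replacement for the deprecated `xm.mark_step()` (the helper name and fallback behavior are my own, not part of this PR):

```python
def xla_step_sync():
    """Flush pending XLA lazy-tensor operations, preferring the new API.

    Falls back to the deprecated xm.mark_step() on older torch_xla
    releases, and to a no-op when torch_xla is not installed (e.g. on CPU
    test environments). Returns a string naming the path taken.
    """
    try:
        import torch_xla
        torch_xla.sync()  # preferred API on recent torch_xla releases
        return "torch_xla.sync"
    except ImportError:
        return "no-op (torch_xla unavailable)"
    except AttributeError:
        # Older torch_xla without the top-level sync(); use the legacy call.
        import torch_xla.core.xla_model as xm
        xm.mark_step()
        return "xm.mark_step"
```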

@michaelbenayoun
Member Author

michaelbenayoun commented Feb 13, 2026

It is hard for me to follow exactly what this allows, so I will have to trust you on this. I just ask for the sanity checks to be run first, as mentioned by @dacorvo. Generally speaking, if you add a code feature, it would be better to add a test for it, don't you think?

For the tests, we will add them once the feature is fully covered, because for now the code runs with a CPU vLLM environment.
Also, we have added some tests in the sub-PRs.

@github-actions

github-actions bot commented Mar 1, 2026

This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions bot added the Stale label Mar 1, 2026
@michaelbenayoun
Member Author

For now, not merging this, given the recent progress in Torch Native.

@github-actions github-actions bot removed the Stale label Mar 4, 2026