Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs remain available for 30 days after the last update.
Pull request overview
This PR adds partial support for GRPO (Group Relative Policy Optimization) training on Neuron (Trainium) devices through the new NeuronGRPOTrainer class. The implementation includes XLA-specific optimizations and modifications to work with the Torch XLA backend, though several core features remain unimplemented (vLLM integration, weight synchronization, tensor parallelism).
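For context, the group-relative advantage at the core of GRPO fits in a few lines. The function below is an illustrative pure-Python sketch, not code from this PR; the name and the `eps` default are hypothetical:

```python
import statistics

def group_relative_advantages(rewards, group_size, eps=1e-4):
    # GRPO normalizes each completion's reward against the mean and standard
    # deviation of its own group (the completions sampled for the same prompt),
    # instead of using a learned value function as a baseline.
    advantages = []
    for i in range(0, len(rewards), group_size):
        group = rewards[i:i + group_size]
        mean = statistics.fmean(group)
        std = statistics.pstdev(group)
        # eps keeps the division stable when all rewards in a group are equal.
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages
```

By construction, the advantages within each group are centered around zero, which is what makes the policy update "group relative".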
Changes:
- Adds `NeuronGRPOTrainer` with XLA-optimized implementations for generation, scoring, and loss computation
- Introduces `NeuronGRPOConfig` for configuration, with an experimental flag requirement
- Implements XLA-friendly utility functions (padding, entropy, statistical operations) in `trl_utils.py`
- Adds custom vLLM client implementations with a CPU communicator and a mock client for testing
- Updates `NeuronTrainer` to support a `_prepare_inputs` hook and replaces `xm.mark_step()` with `torch_xla.sync()`
- Modifies LoRA transformation utilities to handle missing weights more gracefully
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 19 comments.
Show a summary per file
| File | Description |
|---|---|
| optimum/neuron/trainers/grpo_trainer.py | Core GRPO trainer implementation with XLA optimizations (1414 lines, new file) |
| optimum/neuron/trainers/grpo_config.py | Configuration class with validation and experimental flag (118 lines, new file) |
| optimum/neuron/trainers/trl_utils.py | XLA-optimized utility functions for padding, statistics, and sampling (270 lines) |
| optimum/neuron/trainers/extras/vllm_client.py | Custom vLLM clients for Neuron with CPU communicator and mock implementation (213 lines, new file) |
| optimum/neuron/trainers/transformers.py | Updates to NeuronTrainer for _prepare_inputs hook and torch_xla.sync() migration |
| optimum/neuron/trainers/utils.py | Adds move_inputs_to_device utility and updates XLAPrefetchIterator |
| optimum/neuron/models/training/transformations_utils.py | Converts LoRA weight errors to silent skips for flexibility |
| optimum/neuron/trainers/metrics/collector.py | Refactors get_metric_unit for cleaner logic |
| optimum/neuron/utils/init.py | Exports is_vllm_available function |
| optimum/neuron/init.py | Exports NeuronGRPOTrainer and NeuronGRPOConfig |
| .github/actions/install_optimum_neuron/action.yml | Adds training extras to CI installation |
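To illustrate the mock-client idea behind `extras/vllm_client.py` (the class and method names below are hypothetical, not the file's actual API): a stand-in client returns deterministic placeholder completions so the trainer can be exercised without a running vLLM server.

```python
class MockVLLMClient:
    """Hypothetical stand-in for a vLLM server client, for testing only."""

    def generate(self, prompts, n=1, max_tokens=16):
        # Return one group of n deterministic placeholder completions per prompt,
        # mirroring the grouped-sampling shape GRPO expects.
        return [[f"<mock completion {i} for: {p}>" for i in range(n)] for p in prompts]

    def update_named_params(self, state_dict):
        # Weight synchronization is a no-op in the mock.
        return None
```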
```diff
  if to_concat_and_duplicate_name is None or to_unfuse_name is None:
-     raise ValueError(
-         f"Could not find LoRA weights for {module_fully_qualified_name} with param name {param_name}."
-     )
+     continue
```
Similar to the previous issue, this converts a hard error into a silent skip. This could hide configuration problems. Consider logging when weights are not found to aid debugging.
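A middle ground between the hard error and the silent skip could look like the hypothetical helper below (the function name and lookup structure are illustrative, not the repo's actual code): log a warning and skip, so the run continues but configuration problems stay visible.

```python
import logging

logger = logging.getLogger(__name__)

def resolve_lora_weights(weight_map, module_name, param_name):
    # Look up the LoRA weight names for a module; on a miss, warn and skip
    # instead of raising, keeping the problem visible in the logs.
    names = weight_map.get((module_name, param_name))
    if names is None:
        logger.warning(
            "Could not find LoRA weights for %s with param name %s; skipping.",
            module_name, param_name,
        )
    return names
```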
dacorvo
left a comment
No more blockers, but I will review in more detail tomorrow. Copilot detected some issues that may be worth considering.
…a weight conversion
JingyaHuang
left a comment
nothing to add from my side, if @tengomucho is ok with it as well.
@michaelbenayoun can you rebase or merge the main branch? It looks like you are using the non-working sanity changes that @tengomucho fixed in #1077. As a consequence, the training tests are not launched.
tengomucho
left a comment
It is hard for me to follow exactly what this allows, so I will have to trust you on this.
I just ask for the sanity checks to be run first, as mentioned by @dacorvo.
Generally speaking, if you add a code feature, it would be better to add a test for it, don't you think?
```diff
  self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
- xm.mark_step()
+ torch_xla.sync()
```
I think mark_step is deprecated, so this is a good change, but don't you think it would be better to change all occurrences of this?
You mean in the whole library?
In tests it still exists, but in the library the only occurrence is in `optimum/neuron/generation/utils.py`; I will update it.
For the tests we will do it once it's fully covered, because for now the code runs with a CPU vLLM env.
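For reference, the migration under discussion can be wrapped in a small compatibility shim. This is a sketch under the assumption that only recent torch_xla releases expose the top-level `torch_xla.sync()`, while older ones only have the deprecated `xm.mark_step()`:

```python
def xla_step_sync():
    # Cut an XLA graph step, preferring the newer torch_xla.sync() API and
    # falling back to the deprecated xm.mark_step() on older releases.
    try:
        import torch_xla
    except ImportError:
        return  # not running on an XLA backend; nothing to sync
    if hasattr(torch_xla, "sync"):
        torch_xla.sync()
    else:
        import torch_xla.core.xla_model as xm
        xm.mark_step()  # deprecated fallback for older torch_xla
```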
This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.
For now, not merging given the recent progress in Torch Native.
What does this PR do?
This PR adds partial support for GRPO.
It was broken down into smaller PRs:
- optimum/neuron/accelerate#1042

It adds the `NeuronGRPOTrainer` with a set of optimizations and modifications for the Torch XLA backend used to run things on Trainium instances. There are still core missing features:
- `NeuronGRPOTrainer` <-> vLLM