[trainer] fix: update TorchTitanEngine for latest torchtitan API #6231
acisseJZhong wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request updates the torchtitan engine implementation and its associated E2E test script. Key changes include refactoring model specification and attention backend handling, removing expert tensor parallelism configurations, and adding a placeholder loss function for Trainer initialization. The model flavor derivation logic was also enhanced to support callable configurations and different layer attribute names. I have no feedback to provide as there were no review comments to assess.
acisseJZhong force-pushed from ba5beaf to 451c9dc
Align verl's TorchTitanEngine with torchtitan HEAD, fixing several breaking API changes and the attn_type bug reported in verl-project#6182.

torchtitan API updates:
- Remove `expert_tensor_parallel_degree` from ParallelismConfig (removed upstream)
- Remove `etp` from ParallelDims constructor (removed upstream)
- Remove `maybe_enable_amp` context manager (removed upstream; `train_context()` handles mixed precision)
- Add `loss=CrossEntropyLoss.Config()` to Trainer.Config (BaseLoss is now abstract; see the sketch after this list)

attn_type fixes (verl-project#6182):
- Pass `attn_backend=` to `model_registry()` instead of the dead post-hoc override
- Fix the `attn_type` lookup to use `self.engine_config.attn_type` instead of the wrong path

derive_torchtitan_name_and_flavor fixes:
- Handle config factories (callables) in addition to config objects
- Fall back to `len(config.layers)` when the `n_layers` attribute doesn't exist

test script fixes:
- Use `NUM_GPUS` for the rollout TP size instead of a hardcoded 8
- Fix the misleading default experiment name
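For concreteness, a minimal sketch of how the loss placeholder and the `attn_backend=` pass-through might look in the engine. The import paths, the `build_model_spec` helper, and any arguments of `model_registry()` beyond `attn_backend=` are assumptions based on this description, not lines from the diff:

```python
# Sketch only: import paths are assumptions and may not match torchtitan HEAD.
from torchtitan.components.loss import CrossEntropyLoss       # assumed path
from torchtitan.protocols.train_spec import model_registry    # assumed path
from torchtitan.train import Trainer                          # assumed path

# BaseLoss is now abstract upstream, so Trainer.Config requires a concrete
# loss config. verl computes its own RL losses, so this is only a placeholder
# that lets Trainer initialization succeed.
trainer_config = Trainer.Config(
    loss=CrossEntropyLoss.Config(),
    # other fields (parallelism, checkpointing, model) omitted in this sketch
)


def build_model_spec(model_name, engine_config):
    # Hypothetical helper. attn_type must be threaded through at registry
    # time; overriding the backend on the already-built model was dead code
    # (verl-project#6182).
    return model_registry(
        model_name,                            # e.g. "qwen3"
        attn_backend=engine_config.attn_type,  # e.g. "flex" for FlexAttention
    )
```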
acisseJZhong force-pushed from 451c9dc to 48abba3
Summary
- torchtitan API: remove `expert_tensor_parallel_degree`/`etp`/`maybe_enable_amp`; add a concrete `CrossEntropyLoss.Config` for the now-abstract `BaseLoss`
- Fix `attn_type` being silently ignored ([trainer] bug: TorchtitanEngine silently ignores attn_type="flex" — no clear BKM for which torchtitan version to use #6182): pass `attn_backend=` to `model_registry()` and fix the wrong attribute path in `prepare_model_inputs`
- Update `derive_torchtitan_name_and_flavor` to handle config factories (callables) and fall back to `len(layers)` for layer-count matching (see the sketch after this list)
- Test script: use `NUM_GPUS` for the rollout TP size instead of a hardcoded 8; fix the misleading experiment name

Fixes #6182
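A sketch of the layer-count matching described in the `derive_torchtitan_name_and_flavor` bullet; the real function's signature isn't shown in this PR, so `match_flavor`, the shape of the `flavors` table, and the HF attribute used are illustrative assumptions:

```python
def match_flavor(hf_config, flavors):
    """Illustrative core of the flavor-matching step.

    `flavors` is assumed to map flavor names to torchtitan model configs,
    which upstream may register either as config objects or as config
    factories (callables), per this PR.
    """
    for flavor_name, config in flavors.items():
        if callable(config):
            # Config factories must be instantiated before inspection.
            config = config()
        # Some configs expose `n_layers`; others only a `layers` container,
        # hence the len(config.layers) fallback from this PR.
        n_layers = getattr(config, "n_layers", None)
        if n_layers is None:
            n_layers = len(config.layers)
        if n_layers == hf_config.num_hidden_layers:
            return flavor_name
    raise ValueError("no torchtitan flavor matches the HF layer count")
```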
Test plan
- Verified `derive_torchtitan_name_and_flavor` correctly resolves the Qwen3-0.6B flavor (sketched as a test below)
- Ran `tests/special_e2e/run_ppo_trainer_torchtitan.sh` with torchtitan HEAD
- Confirmed `attn_type=flex` (FlexAttention) is used as configured
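A hedged sketch of the first test-plan item as a pytest-style check; the helper's module path, signature, return shape, and the exact Qwen3-0.6B flavor label are all assumptions:

```python
from transformers import AutoConfig

# Assumed module path; the PR does not show where the helper lives in verl.
from verl.workers.engine.torchtitan import derive_torchtitan_name_and_flavor


def test_derive_qwen3_flavor():
    hf_config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
    # Assumed return shape: (torchtitan model name, flavor string).
    name, flavor = derive_torchtitan_name_and_flavor(hf_config)
    assert name == "qwen3"
    assert "0.6B" in flavor  # exact flavor label is an assumption
```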