[AMD] feat: true on policy for triton backend#20108
[AMD] feat: true on policy for triton backend#20108XinyuJiangCMU wants to merge 2 commits intosgl-project:mainfrom
Conversation
…ining Ensure bit-wise identical log_probs between SGLang's decode/extend paths and HF training, enabling true on-policy RL (GRPO/PPO) without logprob drift. Key changes: 1. triton_backend.py: add `_forward_decode_unified` - when `--enable-deterministic-inference` is set, route decode through the same unified 1-stage Triton kernel used in extend. This aligns the sequential reduction order across decode (rollout) and extend (training), producing bit-wise identical attention outputs and therefore identical log_probs. 2. sampler.py: use fp32 log_softmax for rl_on_policy_target mode. Previous bfloat16 casts introduced a systematic ~5e-4 logprob drift vs the fp32 training-side log_softmax. 3. layernorm.py: centralize rl_on_policy_target RMSNorm behavior. Auto-enable `cast_x_before_out_mul=True` (matching HF's weight*x order), override to `forward_native` (skip fused kernels), and add residual in orig_dtype (bf16) before upcasting to match HF's layernorm residual path. 4. qwen2.py: split fused gate_up matmul into separate gate/up projections for rl_on_policy_target mode, matching HF's Qwen2/3MLP accumulation order. Remove fp32 embed_tokens and fp32 final norm overrides (now handled by layernorm.py centrally). 5. qwen3.py: remove scattered fp32 norm_kwargs and redundant dtype casts (now handled centrally in layernorm.py). Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com>
Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com>
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a series of precision-focused adjustments across several core components to enable true on-policy Reinforcement Learning (RL) for text models on AMD GPUs using the Triton attention backend. The primary goal is to achieve bit-wise identical log-probabilities between the SGLang rollout (inference) and FSDP training (extend) phases, which is crucial for stable and accurate on-policy RL. These changes involve aligning numerical operations with Hugging Face's implementations, particularly concerning attention, rotary embeddings, layer normalization, and sampler behavior. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a comprehensive set of changes to achieve bit-wise numerical identity for on-policy reinforcement learning on AMD GPUs with the Triton backend. The modifications span across attention, rotary embeddings, layer normalization, and model-specific MLP implementations, ensuring that inference-time computations precisely match the training-time equivalents. The changes are well-commented and appear to be well-reasoned. I have one concern regarding a change in the embedding layer's data type that might affect numerical precision.
Note: Security Review did not run due to the size of the PR.
| @@ -280,11 +290,6 @@ def __init__( | |||
| quant_config=quant_config, | |||
| use_attn_tp_group=is_dp_attention_enabled(), | |||
| prefix=add_prefix("embed_tokens", prefix), | |||
There was a problem hiding this comment.
The params_dtype argument for VocabParallelEmbedding has been removed. Previously, it was set to torch.float32 when rl_on_policy_target was enabled, which is often desirable for matching training precision. By removing this, the embedding weights will likely default to the model's dtype (e.g., bfloat16), which could introduce numerical differences compared to a training setup that uses float32 for embeddings. This seems to contradict the PR's goal of achieving bit-wise identity with the training process. Was this removal intentional? If not, this could be a potential bug.
prefix=add_prefix("embed_tokens", prefix),
params_dtype=(
torch.float32
if get_global_server_args().rl_on_policy_target is not None
else None
),
Motivation
Support true on-policy RL for text models on AMD GPU with triton attention backend.
For bit-wise identical log-probs between sglang rollout (decode) and FSDP training (extend), the inference and training must execute exactly the same numerical operations.
Modifications
When
get_global_server_args().rl_on_policy_target is not Noneand using triton attention backend on AMD:triton_backend.py: In deterministic mode, route decode through
_forward_decode_unified()which uses the same 1-stageextend_attention_fwd_unifiedkernel as the extend path — ensures bit-wise identical attention computation between decode and extend.rotary_embedding/base.py:
torch.compileon_apply_rotary_emb_wrapped(compiled version produces different numerics)_compute_inv_freqwas movinginv_freqto CUDA, but_compute_cos_sin_cachestill builtton CPU. Added separate CPU path forrl_on_policy_targetto match HF's RoPE numerics exactly.layernorm.py:
cast_x_before_out_mul=Trueto match HF RMSNorm (cast x to orig_dtype before multiplying with weight)forward_nativeto match HF behaviorforward_aiter/forward_hip: whenis_batch_invariant_mode_enabled()andrl_on_policy_target is not None, fall back toforward_nativesampler.py: Use fp32
log_softmaxinstead of bf16 to match training-side log-prob computation.qwen2.py / qwen3.py: Split fused gate/up matmul into separate
F.linearcalls to match HF accumulation order; gate additional norm/dtype kwargs onattention_backend != "triton"to preserve original non-triton paths.Accuracy Tests
Verified using
run_simple_amd_triton.pyfrom the miles RL training framework, on Qwen3-0.6B (single AMD GPU, triton attention backend,rl_on_policy_target=fsdp):train_rollout_logprob_abs_diff: 0.0
Checklist
Collaborators
@XinyuJiangCMU
@JessicaJiang-123