
[AMD] feat: true on policy for triton backend #20108

Open
XinyuJiangCMU wants to merge 2 commits into sgl-project:main from XinyuJiangCMU:pr/on-policy-numerical-fixes

Conversation


@XinyuJiangCMU XinyuJiangCMU commented Mar 8, 2026

Motivation

Support true on-policy RL for text models on AMD GPU with triton attention backend.

For bit-wise identical log-probs between sglang rollout (decode) and FSDP training (extend), the inference and training must execute exactly the same numerical operations.

Modifications

When get_global_server_args().rl_on_policy_target is not None and the triton attention backend is used on AMD:

  1. triton_backend.py: In deterministic mode, route decode through _forward_decode_unified() which uses the same 1-stage extend_attention_fwd_unified kernel as the extend path — ensures bit-wise identical attention computation between decode and extend.

  2. rotary_embedding/base.py:

    • Disable torch.compile on _apply_rotary_emb_wrapped (compiled version produces different numerics)
    • Fix device mismatch bug: _compute_inv_freq was moving inv_freq to CUDA, but _compute_cos_sin_cache still built t on CPU. Added separate CPU path for rl_on_policy_target to match HF's RoPE numerics exactly.
  3. layernorm.py:

    • Auto-set cast_x_before_out_mul=True to match HF RMSNorm (cast x to orig_dtype before multiplying with weight)
    • Use bf16 residual addition (not fp32) in forward_native to match HF behavior
    • In forward_aiter / forward_hip: when is_batch_invariant_mode_enabled() and rl_on_policy_target is not None, fall back to forward_native
  4. sampler.py: Use fp32 log_softmax instead of bf16 to match training-side log-prob computation.

  5. qwen2.py / qwen3.py: Split fused gate/up matmul into separate F.linear calls to match HF accumulation order; gate additional norm/dtype kwargs on attention_backend != "triton" to preserve original non-triton paths.
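The gate/up split in item 5 matters because floating-point addition is not associative: a fused matmul and two separate F.linear calls sum the same products in different orders, so the low-order bits can differ. A minimal pure-Python illustration of the effect (not sglang code):

```python
# Floating-point addition is not associative, so summing identical terms
# in a different order (as fused vs. split matmuls do) can change the result.
a, b, c = 1e16, -1e16, 1.0

fused_order = (a + b) + c   # the large terms cancel first, keeping c
split_order = a + (b + c)   # c is absorbed into -1e16 and lost

print(fused_order, split_order)  # 1.0 0.0
```

Bit-wise identity therefore requires matching the accumulation order exactly, not just the precision.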

Accuracy Tests

Verified using run_simple_amd_triton.py from the miles RL training framework, on Qwen3-0.6B (single AMD GPU, triton attention backend, rl_on_policy_target=fsdp):

train_rollout_logprob_abs_diff: 0.0

[Figure: logprob_diff plot]

Checklist

Collaborators

@XinyuJiangCMU
@JessicaJiang-123

XinyuJiangCMU and others added 2 commits March 8, 2026 03:34

Ensure bit-wise identical log_probs between SGLang's decode/extend paths
and HF training, enabling true on-policy RL (GRPO/PPO) without logprob drift.

Key changes:

1. triton_backend.py: add `_forward_decode_unified` - when
   `--enable-deterministic-inference` is set, route decode through the same
   unified 1-stage Triton kernel used in extend. This aligns the sequential
   reduction order across decode (rollout) and extend (training), producing
   bit-wise identical attention outputs and therefore identical log_probs.

2. sampler.py: use fp32 log_softmax for rl_on_policy_target mode. Previous
   bfloat16 casts introduced a systematic ~5e-4 logprob drift vs the fp32
   training-side log_softmax.

3. layernorm.py: centralize rl_on_policy_target RMSNorm behavior. Auto-enable
   `cast_x_before_out_mul=True` (matching HF's weight*x order), override to
   `forward_native` (skip fused kernels), and add residual in orig_dtype
   (bf16) before upcasting to match HF's layernorm residual path.

4. qwen2.py: split fused gate_up matmul into separate gate/up projections
   for rl_on_policy_target mode, matching HF's Qwen2/3MLP accumulation order.
   Remove fp32 embed_tokens and fp32 final norm overrides (now handled by
   layernorm.py centrally).

5. qwen3.py: remove scattered fp32 norm_kwargs and redundant dtype casts
   (now handled centrally in layernorm.py).
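The ~5e-4 drift described in item 2 can be reproduced in plain Python by emulating bfloat16 rounding on the log_softmax intermediates (a sketch with made-up logits; the real code operates on torch tensors):

```python
import math
import struct

def bf16(x: float) -> float:
    # Emulate bfloat16 by truncating a float32 encoding to its top 16 bits
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def log_softmax(logits, cast=lambda v: v):
    # Numerically stable log_softmax; `cast` rounds each intermediate value
    m = max(logits)
    z = cast(math.log(sum(cast(math.exp(cast(x - m))) for x in logits)))
    return [cast(x - m) - z for x in logits]

logits = [2.0, 1.0, 0.1]
fp32_lp = log_softmax(logits)
bf16_lp = log_softmax([bf16(x) for x in logits], cast=bf16)
drift = max(abs(a - b) for a, b in zip(fp32_lp, bf16_lp))
print(drift > 0.0)  # True: bf16 intermediates shift the log-probs
```

Keeping the entire log_softmax in fp32 removes this systematic rounding error, which is why the sampler change aligns rollout with the fp32 training-side computation.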

Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com>
Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly.

This pull request introduces a series of precision-focused adjustments across several core components to enable true on-policy Reinforcement Learning (RL) for text models on AMD GPUs using the Triton attention backend. The primary goal is to achieve bit-wise identical log-probabilities between the SGLang rollout (inference) and FSDP training (extend) phases, which is crucial for stable and accurate on-policy RL. These changes involve aligning numerical operations with Hugging Face's implementations, particularly concerning attention, rotary embeddings, layer normalization, and sampler behavior.

Highlights

  • Triton Attention Backend: Implemented a unified 1-stage kernel for decode attention in deterministic mode to ensure bit-wise identical log-probabilities between rollout (decode) and training (extend) paths.
  • Rotary Embedding Numerical Stability: Disabled torch.compile for rotary embedding application and ensured cos_sin_cache computation on CPU to precisely match Hugging Face's RoPE numerics, preventing numerical drift.
  • Layer Normalization Adjustments: Modified RMSNorm to automatically set cast_x_before_out_mul=True, use bf16 residual addition, and conditionally fall back to forward_native to align with Hugging Face's RMSNorm behavior for on-policy RL.
  • Sampler Log-Softmax Precision: Updated the sampler to use FP32 precision for log_softmax calculations in on-policy RL mode, matching the training-side computation and eliminating systematic log-probability drift.
  • Qwen MLP Fused Operations: Refactored Qwen2 and Qwen3 MLP's fused gate/up projections into separate F.linear calls to match Hugging Face's accumulation order, resolving log-probability discrepancies in on-policy training.
  • Conditional Norm and Dtype Handling: Introduced conditional logic for norm parameters and bfloat16 casting in Qwen2 and Qwen3 models, specifically for non-Triton on-policy backends, to maintain numerical precision.
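The cast_x_before_out_mul distinction in the layer-normalization highlight can be seen with a pure-Python RMSNorm sketch (illustrative values, bfloat16 emulated by bit truncation; this is not the actual sglang kernel):

```python
import math
import struct

def bf16(x: float) -> float:
    # Emulate bfloat16 by truncating a float32 encoding to its top 16 bits
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def rmsnorm(xs, ws, cast_x_before_out_mul, eps=1e-6):
    # Variance is accumulated in high precision in both variants
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in xs) / len(xs) + eps)
    normed = [v * inv_rms for v in xs]
    if cast_x_before_out_mul:
        # HF order: round x back to bf16 first, then multiply by the weight
        return [bf16(bf16(n) * w) for n, w in zip(normed, ws)]
    # Fused-kernel order: multiply in high precision, round once at the end
    return [bf16(n * w) for n, w in zip(normed, ws)]

xs = [0.1234, -0.9876, 0.5555, 1.7777]
ws = [1.01, 0.99, 1.02, 0.98]
print(rmsnorm(xs, ws, True) == rmsnorm(xs, ws, False))  # False: orders differ
```

The extra intermediate rounding in the HF order produces slightly different outputs, so matching it is required for bit-wise identical log-probs.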


Changelog
  • python/sglang/srt/layers/attention/triton_backend.py
    • Introduced a new private method _forward_decode_unified to handle decode attention using a unified 1-stage kernel.
    • Modified forward_decode to conditionally call _forward_decode_unified when self.enable_deterministic is true, ensuring bit-wise alignment with the extend path.
  • python/sglang/srt/layers/layernorm.py
    • Updated __init__ to automatically set cast_x_before_out_mul=True if rl_on_policy_target is enabled.
    • Modified forward_aiter and forward_hip to fall back to forward_native when is_batch_invariant_mode_enabled() and rl_on_policy_target is active or residual is present.
    • Adjusted forward_native to conditionally add residual in orig_dtype (bf16) before upcasting to fp32 when rl_on_policy_target is active and fp32_residual is false.
    • Modified forward_native to cast the weight to orig_dtype before multiplication when rl_on_policy_target is active and cast_x_before_out_mul is true.
  • python/sglang/srt/layers/rotary_embedding/base.py
    • Removed torch.compile from _apply_rotary_emb_wrapped to prevent numerical differences.
    • Modified _compute_cos_sin_cache to compute t and inv_freq on CPU when rl_on_policy_target is enabled, then move to CUDA, to match HF's RoPE numerics.
  • python/sglang/srt/layers/sampler.py
    • Changed the log_softmax computation for rl_on_policy_target to use fp32 instead of bf16 for logits_div_temperature.
  • python/sglang/srt/models/qwen2.py
    • Imported torch.nn.functional as F.
    • Modified the forward method in Qwen2MLP to split the fused gate_up_proj into separate F.linear calls for gate and up when rl_on_policy_target is enabled.
    • Removed the conditional params_dtype setting for embed_tokens initialization.
    • Adjusted norm_kwargs for the final RMSNorm to be conditional on rl_on_policy_target AND attention_backend != "triton".
  • python/sglang/srt/models/qwen3.py
    • Modified __init__ to conditionally set _qk_norm_kwargs for q_norm and k_norm based on rl_on_policy_target and attention_backend != "triton".
    • Added conditional bfloat16 casting for hidden_states, q, and k in the forward method, specifically for non-Triton on-policy backends.
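The CPU-side cos/sin cache mentioned for rotary_embedding/base.py follows the standard RoPE recipe. A minimal pure-Python sketch of that computation (the real code builds torch tensors on CPU and then moves them to the GPU):

```python
import math

def compute_cos_sin_cache(head_dim, max_pos, base=10000.0):
    # inv_freq[i] = base^(-2i/head_dim), the standard RoPE frequencies;
    # computing the whole cache on the CPU keeps the rollout numerics
    # identical to an HF training stack that also builds it on the CPU.
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
    cos = [[math.cos(pos * f) for f in inv_freq] for pos in range(max_pos)]
    sin = [[math.sin(pos * f) for f in inv_freq] for pos in range(max_pos)]
    return cos, sin

cos, sin = compute_cos_sin_cache(head_dim=8, max_pos=4)
print(cos[0], sin[0])  # position 0: cosines are all 1.0, sines all 0.0
```

Building t on CPU while inv_freq sat on CUDA (the bug fixed here) would mix devices and, even after a fix-up, could round differently from the CPU-only HF path.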
Activity
  • No human activity (comments, reviews, or progress updates) has been recorded on this pull request yet.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a comprehensive set of changes to achieve bit-wise numerical identity for on-policy reinforcement learning on AMD GPUs with the Triton backend. The modifications span across attention, rotary embeddings, layer normalization, and model-specific MLP implementations, ensuring that inference-time computations precisely match the training-time equivalents. The changes are well-commented and appear to be well-reasoned. I have one concern regarding a change in the embedding layer's data type that might affect numerical precision.

Note: Security Review did not run due to the size of the PR.

@@ -280,11 +290,6 @@ def __init__(
quant_config=quant_config,
use_attn_tp_group=is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
Severity: high

The params_dtype argument for VocabParallelEmbedding has been removed. Previously, it was set to torch.float32 when rl_on_policy_target was enabled, which is often desirable for matching training precision. By removing this, the embedding weights will likely default to the model's dtype (e.g., bfloat16), which could introduce numerical differences compared to a training setup that uses float32 for embeddings. This seems to contradict the PR's goal of achieving bit-wise identity with the training process. Was this removal intentional? If not, this could be a potential bug.

                prefix=add_prefix("embed_tokens", prefix),
                params_dtype=(
                    torch.float32
                    if get_global_server_args().rl_on_policy_target is not None
                    else None
                ),
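For context on the reviewer's concern: storing embedding weights in bfloat16 rather than float32 discards mantissa bits, so lookups can no longer be bit-identical to an fp32 training setup. A toy illustration using bit-level bf16 emulation (not the actual VocabParallelEmbedding code):

```python
import struct

def bf16(x: float) -> float:
    # bfloat16 keeps float32's 8-bit exponent but only 7 mantissa bits,
    # i.e. the top 16 bits of the float32 encoding
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

w = 0.123456789  # a typical embedding weight value
print(w == bf16(w))  # False: the bf16 copy differs in the low bits
```

Whether this matters for the PR's goal depends on which dtype the FSDP training side actually uses for its embedding table.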
