Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions python/sglang/srt/layers/attention/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,14 @@ def forward_decode(
layer.v_scale,
)

# In deterministic mode, use the unified 1-stage kernel so that decode
# and extend share the same sequential reduction order, enabling bit-wise
# alignment of log_probs between rollout (decode) and training (extend).
if self.enable_deterministic:
return self._forward_decode_unified(
q, o, layer, forward_batch, logits_soft_cap, sinks
)

if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
kv_indptr = self.forward_metadata.window_kv_indptr
kv_indices = self.forward_metadata.window_kv_indices
Expand Down Expand Up @@ -1109,6 +1117,83 @@ def forward_decode(
)
return o

    def _forward_decode_unified(
        self,
        q: torch.Tensor,
        o: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        logits_soft_cap: float,
        sinks: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Decode attention using the unified 1-stage kernel for bit-wise alignment
        with the extend (training) path.

        During decode, each sequence contributes exactly 1 query token. We treat
        this as a degenerate extend call where:
        - qo_indptr = [0, 1, 2, ..., bs] (one Q token per sequence)
        - kv_indptr / kv_indices already built by init_forward_metadata for decode
        - prefix_lens = full KV length per sequence (all KV slots are "prefix";
          the current token was written to the cache before this call)
        - max_len_extend = 1

        This ensures the same sequential reduction order as the unified kernel
        used in extend and in the training forward pass, achieving bit-wise
        identical attention outputs and therefore identical log_probs.

        Args:
            q: Query tensor; viewed below as (tokens, tp_q_head_num, qk_head_dim).
            o: Output tensor; viewed as (tokens, tp_q_head_num, v_head_dim),
                passed to the kernel and returned.
            layer: Attention layer providing head counts, scaling, soft-cap and
                sliding-window configuration.
            forward_batch: Batch metadata (batch size, seq lens, KV pool).
            logits_soft_cap: Soft cap value forwarded to the kernel's logit_cap.
            sinks: Optional attention-sink tensor forwarded to the kernel.

        Returns:
            The tensor ``o``.
        """
        bs = forward_batch.batch_size

        # One Q token per sequence.
        qo_indptr = torch.arange(bs + 1, dtype=torch.int32, device=self.device)

        # Reuse the kv_indptr / kv_indices already prepared for decode.
        # These already include the token that was just written to the KV cache.
        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            kv_indptr = self.forward_metadata.window_kv_indptr
            kv_indices = self.forward_metadata.window_kv_indices
            sliding_window_size = layer.sliding_window_size
            # Per-sequence number of KV slots visible inside the window.
            window_kv_lens = kv_indptr[1 : bs + 1] - kv_indptr[:bs]
            # NOTE(review): the trailing ``- 1`` assumes seq_lens already counts
            # the just-decoded token while the window KV also contains it —
            # confirm against init_forward_metadata's window bookkeeping.
            window_start_pos = forward_batch.seq_lens[:bs] - window_kv_lens - 1
        else:
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices
            sliding_window_size = -1
            window_start_pos = None

        # De-scale factors for a scaled KV cache (presumably FP8 — confirm);
        # 1.0 means the cache values are used as-is.
        if layer.k_scale is not None and layer.v_scale is not None:
            k_descale = layer.k_scale_float
            v_descale = layer.v_scale_float
        else:
            k_descale = 1.0
            v_descale = 1.0

        # All KV slots are prefix (no new extend tokens beyond the one already
        # written to cache). prefix_lens equals the per-sequence KV length.
        prefix_lens = (kv_indptr[1 : bs + 1] - kv_indptr[:bs]).to(torch.int32)

        self.extend_attention_fwd_unified(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            k_descale,
            v_descale,
            qo_indptr,
            kv_indptr,
            kv_indices,
            prefix_lens,
            max_len_extend=1,
            sm_scale=layer.scaling,
            logit_cap=logits_soft_cap,
            is_causal=True,
            sliding_window_size=sliding_window_size,
            sinks=sinks,
            window_start_pos=window_start_pos,
            xai_temperature_len=layer.xai_temperature_len,
        )
        return o


class TritonMultiStepDraftBackend:
"""
Expand Down
60 changes: 53 additions & 7 deletions python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def __init__(
) -> None:
super().__init__()
self.has_weight = has_weight
# When rl_on_policy_target is set, match HF's RMSNorm behavior:
# cast x to orig_dtype BEFORE multiplying with weight (like HF does).
if get_global_server_args().rl_on_policy_target is not None and not cast_x_before_out_mul:
cast_x_before_out_mul = True
self.cast_x_before_out_mul = cast_x_before_out_mul
self.fp32_residual = fp32_residual
self.override_orig_dtype = override_orig_dtype
Expand Down Expand Up @@ -169,6 +173,17 @@ def forward_aiter(
residual: Optional[torch.Tensor] = None,
post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if is_batch_invariant_mode_enabled():
if (
residual is not None
or get_global_server_args().rl_on_policy_target is not None
):
return self.forward_native(x, residual, post_residual_addition)
return rms_norm_batch_invariant(
x,
self.weight.data,
self.variance_epsilon,
)
if residual is not None:
residual_out = torch.empty_like(x)
output = torch.empty_like(x)
Expand All @@ -195,6 +210,18 @@ def forward_hip(
if not _has_vllm_rms_norm:
return self.forward_native(x, residual, post_residual_addition)

if is_batch_invariant_mode_enabled():
if (
residual is not None
or get_global_server_args().rl_on_policy_target is not None
):
return self.forward_native(x, residual, post_residual_addition)
return rms_norm_batch_invariant(
x,
self.weight.data,
self.variance_epsilon,
)

if not x.is_contiguous():
# NOTE: Remove this if aiter kernel supports discontinuous input
x = x.contiguous()
Expand All @@ -220,15 +247,29 @@ def forward_native(
if not x.is_contiguous():
x = x.contiguous()
orig_dtype = self.override_orig_dtype or x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
if post_residual_addition is not None:
x = x + post_residual_addition.to(torch.float32)
if self.fp32_residual:
if (
not self.fp32_residual
and get_global_server_args().rl_on_policy_target is not None
):
# Match HF's behavior: add residual in orig_dtype (bf16) BEFORE
# upcasting to fp32 for norm. HF adds residual + attn_out in
# bf16, then passes to layernorm. SGLang's default upcasts both
# to fp32 before adding, which is more precise but produces
# different results from the training side.
x = x.to(orig_dtype) + residual.to(orig_dtype)
residual = x.clone()
x = x.to(torch.float32)
else:
residual = x.to(orig_dtype)
x = x.to(torch.float32) + residual.to(torch.float32)
if post_residual_addition is not None:
x = x + post_residual_addition.to(torch.float32)
if self.fp32_residual:
residual = x.clone()
else:
residual = x.to(orig_dtype)
else:
x = x.to(torch.float32)

hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
Expand All @@ -252,7 +293,12 @@ def forward_native(
x = x * torch.rsqrt(variance + self.variance_epsilon)

if self.cast_x_before_out_mul:
x = self.weight * x.to(orig_dtype)
if get_global_server_args().rl_on_policy_target is not None:
# Match HF: cast weight to orig_dtype too (weight may be fp32
# if caller set weight_dtype=torch.float32, but HF uses bf16).
x = self.weight.to(orig_dtype) * x.to(orig_dtype)
else:
x = self.weight * x.to(orig_dtype)
else:
x = (x * self.weight).to(orig_dtype)

Expand Down
33 changes: 22 additions & 11 deletions python/sglang/srt/layers/rotary_embedding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ def __init__(

if get_global_server_args().rl_on_policy_target is not None:
self._forward_method = self.forward_native
self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
apply_rotary_emb
)
# NOTE: Do NOT torch.compile _apply_rotary_emb_wrapped — it can
# change numerical behavior and cause misalignment with HF's RoPE.
self.position_cos, self.position_sin = None, None

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
Expand All @@ -115,19 +114,31 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
/ self.rotary_dim
)
)
if get_global_server_args().rl_on_policy_target is not None:
inv_freq = inv_freq.cuda()
return inv_freq

def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)

freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if get_global_server_args().rl_on_policy_target is not None:
# Compute entirely on CPU to match HF's numerical behavior exactly.
# GPU float32 ops can produce slightly different results from CPU,
# causing bf16 rounding differences after cast.
t = torch.arange(
self.max_position_embeddings, dtype=torch.float, device="cpu"
)
inv_freq_cpu = inv_freq.cpu() if inv_freq.is_cuda else inv_freq
freqs = torch.einsum("i,j -> ij", t, inv_freq_cpu)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if torch.cuda.is_available():
cache = cache.to(torch.device("cuda"))
else:
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache

def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ def forward(
# In RL on-policy mode, we use log_softmax to compute logprobs to match the trainer.
logprobs_via_logsoftmax_kernel = None
if self.rl_on_policy_target is not None:
# TODO: use more inplace ops to save memory
logits_div_temperature = (
logits.bfloat16().div(sampling_info.temperatures).bfloat16()
)
# Use fp32 log_softmax to match training-side computation exactly.
# The training side computes log_softmax in fp32; previous bf16
# casts here introduced a systematic ~5e-4 drift in logprob_abs_diff.
logits_div_temperature = logits.div(sampling_info.temperatures)
logprobs_via_logsoftmax_kernel = torch.log_softmax(
logits_div_temperature, dim=-1
)
Expand Down
37 changes: 23 additions & 14 deletions python/sglang/srt/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from sglang.srt.distributed import (
Expand Down Expand Up @@ -92,9 +93,18 @@ def __init__(

    def forward(self, x):
        """Qwen2 MLP: SwiGLU-style gate/up projection, activation, down projection.

        In RL on-policy mode the merged gate_up matmul is replaced by two
        separate matmuls to reproduce HF's accumulation order bit-for-bit.
        """
        if get_global_server_args().rl_on_policy_target is not None:
            # Split into separate gate and up matmuls to match HF's MLP which
            # does gate_proj(x) and up_proj(x) separately. A single merged
            # matmul produces different floating-point results due to different
            # accumulation order, causing logprob drift in on-policy training.
            # NOTE(review): slicing by ``output_size // 2`` assumes the local
            # weight holds the full merged output; under tensor parallelism the
            # weight is typically output_size_per_partition rows — confirm this
            # path is correct (or guarded) when TP > 1.
            intermediate_size = self.gate_up_proj.output_size // 2
            gate_weight = self.gate_up_proj.weight[:intermediate_size]
            up_weight = self.gate_up_proj.weight[intermediate_size:]
            # NOTE(review): F.linear is called without a bias, so any bias on
            # gate_up_proj is dropped here — Qwen2 MLPs usually have none, but
            # verify for this config.
            gate = F.linear(x, gate_weight)
            up = F.linear(x, up_weight)
            gate_up = torch.cat([gate, up], dim=-1)
        else:
            gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
Expand Down Expand Up @@ -280,11 +290,6 @@ def __init__(
quant_config=quant_config,
use_attn_tp_group=is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The params_dtype argument for VocabParallelEmbedding has been removed. Previously, it was set to torch.float32 when rl_on_policy_target was enabled, which is often desirable for matching training precision. By removing this, the embedding weights will likely default to the model's dtype (e.g., bfloat16), which could introduce numerical differences compared to a training setup that uses float32 for embeddings. This seems to contradict the PR's goal of achieving bit-wise identity with the training process. Was this removal intentional? If not, this could be a potential bug.

                prefix=add_prefix("embed_tokens", prefix),
                params_dtype=(
                    torch.float32
                    if get_global_server_args().rl_on_policy_target is not None
                    else None
                ),

params_dtype=(
torch.float32
if get_global_server_args().rl_on_policy_target is not None
else None
),
)
else:
self.embed_tokens = PPMissingLayer()
Expand All @@ -305,18 +310,22 @@ def __init__(
prefix=add_prefix("layers", prefix),
)
if self.pp_group.is_last_rank:
norm_kwargs = (
dict(
# For non-triton on-policy backends, preserve fp32 final norm.
_server_args = get_global_server_args()
if (
_server_args.rl_on_policy_target is not None
and _server_args.attention_backend != "triton"
):
_norm_kwargs = dict(
weight_dtype=torch.float32,
cast_x_before_out_mul=True,
override_orig_dtype=torch.float32,
fp32_residual=True,
)
if get_global_server_args().rl_on_policy_target is not None
else {}
)
else:
_norm_kwargs = {}
self.norm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
config.hidden_size, eps=config.rms_norm_eps, **_norm_kwargs
)
else:
self.norm = PPMissingLayer(return_tuple=True)
Expand Down
Loading
Loading