Commit 378fc20

Fuse nvfp4_quant for attention/mlp layers
Signed-off-by: Wanli Jiang <[email protected]>
1 parent e5caba2 commit 378fc20

1 file changed: +37 -1 lines changed


tensorrt_llm/_torch/models/modeling_nemotron_h.py

Lines changed: 37 additions & 1 deletion
@@ -23,7 +23,7 @@
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
-from tensorrt_llm._torch.utils import ActivationType, relu2
+from tensorrt_llm._torch.utils import ActivationType, Fp4QuantizedTensor, relu2
 
 from ..attention_backend import AttentionMetadata
 from ..distributed import AllReduce
@@ -318,12 +318,29 @@ def __init__(
         self.layer_idx = layer_idx
         self.layer_type = layer_type
 
+        # Check if NVFP4 is enabled for this layer
+        quant_config = model_config.get_quant_config()
+        self.is_nvfp4 = (quant_config is not None
+                         and quant_config.layer_quant_mode.has_nvfp4())
+
+        # Determine if this layer can use fused RMSNorm + Add + Quantize
+        # Mamba layers (M) have BF16 in_proj (excluded from FP4), cannot be fused
+        # MLP (-) and Attention (*) layers have FP4 first linear, can be fused
+        # MoE (E) layers need BF16 for gate and have different scales for shared/routed,
+        # so input-side fusion doesn't provide net benefit (would require dequantization)
+        self.is_nvfp4_fusable = self.is_nvfp4 and layer_type in ["-", "*"]
+
         self.norm = RMSNorm(
             hidden_size=config.hidden_size,
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype,
         )
 
+        # Enable NVFP4 mode on RMSNorm for fusable layers
+        # This allows the fused_add_rms_norm_quant kernel to be used
+        if self.is_nvfp4_fusable:
+            self.norm.is_nvfp4 = True
+
         if layer_type == "M":
             self.mixer = Mamba2Mixer(d_model=config.hidden_size,
                                      d_state=config.ssm_state_size,
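
To make the gating in the comments above concrete, here is a toy snippet. It is not model code: the pattern string is hypothetical, and the one-character codes simply follow the comments in this diff (M = Mamba, - = MLP, * = Attention, E = MoE).

# Only MLP ("-") and Attention ("*") layers take the fused RMSNorm + Add + Quantize
# path; Mamba ("M") and MoE ("E") layers keep the unfused norm.
pattern = "M-M*E-M-M*E-"        # hypothetical hybrid layer pattern
is_nvfp4 = True                 # stands in for quant_config.layer_quant_mode.has_nvfp4()
fusable = [is_nvfp4 and code in ("-", "*") for code in pattern]
for idx, (code, fused) in enumerate(zip(pattern, fusable)):
    print(f"layer {idx:2d} ({code}): fused norm = {fused}")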
@@ -351,6 +368,17 @@ def __init__(
         else:
             raise ValueError(f"{layer_type} is not supported")
 
+        # Cache reference to the module containing input_scale for NVFP4 fusion
+        # This avoids repeated hasattr/getattr lookups in forward()
+        self._nvfp4_input_scale_source = None
+        if self.is_nvfp4_fusable:
+            if hasattr(self.mixer, 'up_proj'):
+                # MLP layers (-): first linear is up_proj
+                self._nvfp4_input_scale_source = self.mixer.up_proj
+            elif hasattr(self.mixer, 'qkv_proj'):
+                # Attention layers (*): first linear is qkv_proj
+                self._nvfp4_input_scale_source = self.mixer.qkv_proj
+
     def forward(
         self,
         position_ids: torch.IntTensor,
@@ -363,6 +391,14 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
+        # Set up NVFP4 fusion if this layer is fusable
+        # This enables fused RMSNorm + Add + Quantize kernel
+        if self._nvfp4_input_scale_source is not None:
+            input_scale = getattr(self._nvfp4_input_scale_source, 'input_scale',
+                                  None)
+            if input_scale is not None:
+                self.norm.nvfp4_scale = input_scale
+
         if moe_separate_outputs is not None:
             # Previous layer was MOE - use fused add+add+rmsnorm
             routed, shared = moe_separate_outputs
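
Putting the pieces together: the forward-time hook above only hands the norm the first linear's input_scale; the actual work happens inside the fused kernel, so the normalized activations no longer need a separate quantize pass before the first FP4 GEMM of an attention/MLP block. The snippet below is a minimal, self-contained PyTorch reference for the math such a fused residual-add + RMSNorm + NVFP4-quantize step covers. It is a sketch under assumptions, not the TensorRT-LLM fused_add_rms_norm_quant kernel: it fake-quantizes in float with 16-element blocks on the E2M1 value grid, and it omits the calibrated per-tensor input_scale, which in the real path governs how FP8 block scales and packed FP4 data are produced for the downstream GEMM.

import torch

# Magnitudes representable in FP4 E2M1 (the value format used by NVFP4).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain RMSNorm: scale by the reciprocal RMS, then by the learned weight.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

def nvfp4_fake_quant_ref(x: torch.Tensor, block: int = 16) -> torch.Tensor:
    # Fake-quantize to the E2M1 grid with one scale per 16-element block, then
    # dequantize again so the result stays an ordinary float tensor. The real
    # kernel instead emits packed FP4 data plus FP8 block scales derived from
    # the calibrated input_scale.
    xb = x.reshape(*x.shape[:-1], -1, block)
    block_scale = (xb.abs().amax(dim=-1, keepdim=True) / E2M1_VALUES[-1]).clamp_min(1e-12)
    scaled = xb / block_scale
    idx = (scaled.abs().unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1)
    return (E2M1_VALUES[idx] * scaled.sign() * block_scale).reshape_as(x)

def fused_add_rms_norm_quant_ref(hidden, residual, weight, eps=1e-5):
    # The three steps the fused kernel folds into one pass:
    # residual add -> RMSNorm -> NVFP4 quantization of the normed activations.
    pre_norm = hidden + residual
    return nvfp4_fake_quant_ref(rmsnorm_ref(pre_norm, weight, eps)), pre_norm

h, r = torch.randn(2, 64), torch.randn(2, 64)
q, new_residual = fused_add_rms_norm_quant_ref(h, r, torch.ones(64))
print(q.shape, new_residual.shape)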
