 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
-from tensorrt_llm._torch.utils import ActivationType, relu2
+from tensorrt_llm._torch.utils import ActivationType, Fp4QuantizedTensor, relu2

 from ..attention_backend import AttentionMetadata
 from ..distributed import AllReduce
@@ -318,12 +318,29 @@ def __init__(
         self.layer_idx = layer_idx
         self.layer_type = layer_type

+        # Check whether NVFP4 quantization is enabled for this layer.
+        quant_config = model_config.get_quant_config()
+        self.is_nvfp4 = (quant_config is not None
+                         and quant_config.layer_quant_mode.has_nvfp4())
+
+        # Determine whether this layer can use the fused RMSNorm + Add + Quantize path.
+        # Mamba layers (M) have a BF16 in_proj (excluded from FP4), so they cannot be fused.
+        # MLP (-) and attention (*) layers have an FP4 first linear, so they can be fused.
+        # MoE (E) layers need BF16 for the gate and use different scales for the shared and
+        # routed experts, so input-side fusion yields no net benefit (it would require
+        # dequantization).
+        self.is_nvfp4_fusable = self.is_nvfp4 and layer_type in ["-", "*"]
+
         self.norm = RMSNorm(
             hidden_size=config.hidden_size,
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype,
         )

+        # Enable NVFP4 mode on RMSNorm for fusable layers.
+        # This allows the fused_add_rms_norm_quant kernel to be used.
+        if self.is_nvfp4_fusable:
+            self.norm.is_nvfp4 = True
+
         if layer_type == "M":
             self.mixer = Mamba2Mixer(d_model=config.hidden_size,
                                      d_state=config.ssm_state_size,
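For orientation, here is a minimal sketch of how the two attributes set above could gate a fused quantizing path inside the norm. It is hypothetical: the real `RMSNorm` lives in TensorRT-LLM's modules and its fused path is a CUDA kernel, so this stand-in only marks where that branch would go.

import torch
from torch import nn


class NormSketch(nn.Module):
    """Hypothetical stand-in for the real RMSNorm, showing how the
    is_nvfp4 / nvfp4_scale attributes could gate a fused path."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        self.is_nvfp4 = False    # set by the decoder layer when fusable
        self.nvfp4_scale = None  # consumer linear's input_scale, set in forward()

    def forward(self, x: torch.Tensor, residual: torch.Tensor = None):
        if residual is not None:
            x = x + residual  # the "Add" part of the fusion
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = normed * self.weight
        if self.is_nvfp4 and self.nvfp4_scale is not None:
            # A real implementation would call a single fused
            # add+rmsnorm+quant kernel and return an Fp4QuantizedTensor;
            # this placeholder only marks the branch.
            pass
        return normed, x  # (normalized output, updated residual)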
@@ -351,6 +368,17 @@ def __init__(
         else:
             raise ValueError(f"{layer_type} is not supported")

+        # Cache a reference to the module that holds input_scale for NVFP4 fusion.
+        # This avoids repeated hasattr/getattr lookups in forward().
+        self._nvfp4_input_scale_source = None
+        if self.is_nvfp4_fusable:
+            if hasattr(self.mixer, 'up_proj'):
+                # MLP layers (-): the first linear is up_proj
+                self._nvfp4_input_scale_source = self.mixer.up_proj
+            elif hasattr(self.mixer, 'qkv_proj'):
+                # Attention layers (*): the first linear is qkv_proj
+                self._nvfp4_input_scale_source = self.mixer.qkv_proj
+
     def forward(
         self,
         position_ids: torch.IntTensor,
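Note that what gets cached is the owning module, not the scale tensor itself: the `getattr(..., 'input_scale', None)` re-read in `forward()` below suggests the scale may not exist yet at construction time (e.g. before checkpoint weights are loaded). A small illustrative sketch of that pattern, with a hypothetical `Consumer` class standing in for a quantized linear:

import torch
from torch import nn


class Consumer(nn.Module):
    """Hypothetical linear-like module whose input_scale appears only
    after checkpoint loading."""

    def __init__(self):
        super().__init__()
        self.input_scale = None  # populated later by the weight loader


scale_source = Consumer()  # cached once, as in __init__ above

# Before loading: the lookup succeeds but yields None, so fusion stays off.
assert getattr(scale_source, "input_scale", None) is None

# After loading, the same cached module reference sees the fresh tensor.
scale_source.input_scale = torch.tensor(1.0)
assert getattr(scale_source, "input_scale", None) is not None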
@@ -363,6 +391,14 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

+        # Set up NVFP4 fusion if this layer is fusable.
+        # This enables the fused RMSNorm + Add + Quantize kernel.
+        if self._nvfp4_input_scale_source is not None:
+            input_scale = getattr(self._nvfp4_input_scale_source,
+                                  'input_scale', None)
+            if input_scale is not None:
+                self.norm.nvfp4_scale = input_scale
+
         if moe_separate_outputs is not None:
             # The previous layer was MoE - use the fused add+add+rmsnorm path
             routed, shared = moe_separate_outputs
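Taken together, the intended per-step dataflow looks roughly like the sketch below. It is illustrative only: it assumes a norm that accepts an optional residual and returns `(output, new_residual)`, and that a norm primed with `nvfp4_scale` can hand the mixer's first linear an already-quantized activation (e.g. an `Fp4QuantizedTensor`) instead of a BF16 tensor.

def decoder_layer_step(layer, hidden_states, residual):
    # 1. Re-read the consumer's input_scale each step; the cached module
    #    reference is stable, but the scale tensor may appear only after
    #    checkpoint loading.
    if layer._nvfp4_input_scale_source is not None:
        scale = getattr(layer._nvfp4_input_scale_source, "input_scale", None)
        if scale is not None:
            layer.norm.nvfp4_scale = scale

    # 2. One fused kernel: residual add + RMSNorm (+ NVFP4 quantize when
    #    nvfp4_scale is set), instead of three separate ops.
    normed, residual = layer.norm(hidden_states, residual)

    # 3. The mixer's first linear (up_proj / qkv_proj) consumes the
    #    quantized activation directly, skipping its own quantize step.
    out = layer.mixer(normed)
    return out, residual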