@@ -1,6 +1,8 @@
 import logging
+import math
 from typing import Optional, Tuple, Union
 
+import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -16,6 +18,35 @@
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
+# FP8 E4M3 max representable magnitude. Softmax output is bounded to [0, 1],
+# so 1/448 saturates exactly at 1.0 and is data-independent (no calibration needed).
+_FP8_E4M3_MAX = 448.0
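+# Worked example, assuming TRT's q = x / scale quantization convention: with
+# scale = 1.0 / _FP8_E4M3_MAX, the softmax peak x = 1.0 quantizes to
+# 1.0 / (1.0 / 448.0) = 448.0, landing exactly on the E4M3 maximum.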
+
+
+def _maybe_set_fp8_softmax(
+    ctx: ConversionContext,
+    name: str,
+    attention_layer: trt.IAttention,
+) -> bool:
+    """Set FP8 softmax normalization quantization on the IAttention layer if the current
+    node was annotated with a softmax FP8 scale by the fp8_attention_softmax lowering pass.
+
+    Returns True if FP8 normalization was configured (caller must set decomposable=False)."""
+    if ctx.current_node is None:
+        return False
+    scale_val = ctx.current_node.meta.get("_fp8_softmax_scale")
+    if scale_val is None:
+        return False
+    scale_tensor = get_trt_tensor(
+        ctx,
+        torch.tensor(scale_val, dtype=torch.float32),
+        name + "_softmax_fp8_scale",
+        dtype=torch.float32,
+    )
+    attention_layer.normalization_quantize_to_type = trt.DataType.FP8
+    attention_layer.normalization_quantize_scale = scale_tensor
+    return True
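+# Illustrative only: the fp8_attention_softmax lowering pass is assumed to
+# annotate the SDPA node roughly as
+#     node.meta["_fp8_softmax_scale"] = 1.0 / _FP8_E4M3_MAX
+# which _maybe_set_fp8_softmax reads back via ctx.current_node.meta.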
+
 
 def tril(
     ctx: ConversionContext,
@@ -164,6 +195,18 @@ def scaled_dot_product_attention(
     Returns:
         TRTTensor: Attention output tensor with shape [batch, heads, seq_len, head_dim]
     """
+    # When FP8 softmax normalization is active (modelopt FP8 MHA pattern), TRT's
+    # FP8 MHA fusion requires the Q/DQ output to feed IAttention via a single
+    # same-dtype Mul; any HALF<->FLOAT cast inserted by the default dynamic
+    # 1/sqrt(D) computation breaks the fusion. Use a static same-dtype scalar
+    # scale computed from the concrete head_dim.
+    fp8_norm_active = (
+        ctx.current_node is not None
+        and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
+    )
+    if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
+        scale = 1.0 / math.sqrt(query.shape[-1])
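+        # Example (hypothetical head_dim): for head_dim = 64 this folds to the
+        # compile-time constant 1.0 / math.sqrt(64) = 0.125, applied as a single
+        # same-dtype Mul instead of a runtime shape/sqrt/cast chain.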
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -256,7 +299,8 @@ def scaled_dot_product_attention(
 
     if mask_tensor is not None:
         attention_layer.mask = mask_tensor
-    attention_layer.decomposable = True
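+    # Per _maybe_set_fp8_softmax's contract, the layer must not be marked
+    # decomposable when FP8 softmax normalization has been configured.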
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm
     attention_output = attention_layer.get_output(0)
     return attention_output
 
@@ -284,6 +328,13 @@ def scaled_dot_product_flash_attention(
     Optional[TRTTensor],
     Optional[TRTTensor],
 ]:
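+    # Same FP8-MHA static-scale handling as in scaled_dot_product_attention above.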
+    fp8_norm_active = (
+        ctx.current_node is not None
+        and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
+    )
+    if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
+        scale = 1.0 / math.sqrt(query.shape[-1])
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -314,7 +365,8 @@ def scaled_dot_product_flash_attention(
     )
     assert attention_layer is not None, "attention layer is None"
 
-    attention_layer.decomposable = True
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm
 
     attention_output = attention_layer.get_output(0)
     return attention_output, None, None, None, 0.0, 0.0, None, None, None
@@ -334,6 +386,13 @@ def scaled_dot_product_efficient_attention(
     is_causal: bool = False,
     scale: Optional[float] = None,
 ) -> Tuple[TRTTensor, Optional[TRTTensor], Optional[TRTTensor], Optional[TRTTensor]]:
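+    # Same FP8-MHA static-scale handling as in scaled_dot_product_attention above.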
+    fp8_norm_active = (
+        ctx.current_node is not None
+        and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
+    )
+    if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
+        scale = 1.0 / math.sqrt(query.shape[-1])
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -450,7 +509,8 @@ def scaled_dot_product_efficient_attention(
     if mask_tensor is not None:
         attention_layer.mask = mask_tensor
 
-    attention_layer.decomposable = True
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm
 
     attention_output = attention_layer.get_output(0)
     return attention_output, None, None, None