
Commit 6c27bf5

feat: Adding the scale normalize flag in IAttention layer using a pass to annotate target nodes
1 parent 9200829

8 files changed: 451 additions & 129 deletions


py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 3 additions & 0 deletions

@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
+from typing import Optional

 import torch
+import torch.fx
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.types import TRTNetwork

@@ -25,6 +27,7 @@ class ConversionContext:
     requires_native_multidevice: bool = False
     weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict)
     cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list)
+    current_node: Optional[torch.fx.Node] = field(default=None)

     def record_weight(self, name: str, weight: torch.Tensor) -> None:
         """

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 0 deletions

@@ -841,6 +841,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
             self.ctx.requires_native_multidevice = True
             _LOGGER.debug(f"{target} requires native multi-device support")

+        self.ctx.current_node = self._cur_node
         if calling_convention is CallingConvention.LEGACY:
             return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
         else:
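Taken together, these two changes thread the FX node currently being converted through ConversionContext so converter implementations can read per-node annotations. A minimal consumer-side sketch, assuming a torch_tensorrt checkout with this commit applied (the helper name read_sdpa_annotation is illustrative, not part of the commit):

from typing import Optional

from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext


def read_sdpa_annotation(ctx: ConversionContext) -> Optional[float]:
    # TRTInterpreter.call_function assigns ctx.current_node = self._cur_node
    # immediately before dispatching to a converter, so inside a converter
    # this always refers to the node being converted.
    node = ctx.current_node
    if node is None:
        return None
    # Lowering passes communicate with converters through node.meta.
    return node.meta.get("_fp8_softmax_scale")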

py/torch_tensorrt/dynamo/conversion/impl/attention.py

Lines changed: 63 additions & 3 deletions

@@ -1,6 +1,8 @@
 import logging
+import math
 from typing import Optional, Tuple, Union

+import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -16,6 +18,35 @@

 _LOGGER: logging.Logger = logging.getLogger(__name__)

+# FP8 E4M3 max representable magnitude. Softmax output is bounded to [0, 1],
+# so 1/448 saturates exactly at 1.0 and is data-independent (no calibration needed).
+_FP8_E4M3_MAX = 448.0
+
+
+def _maybe_set_fp8_softmax(
+    ctx: ConversionContext,
+    name: str,
+    attention_layer: trt.IAttention,
+) -> bool:
+    """Set FP8 softmax normalization quantization on the IAttention layer if the current
+    node was annotated with a softmax FP8 scale by the annotate_fp8_sdpa lowering pass.
+
+    Returns True if FP8 normalization was configured (caller must set decomposable=False)."""
+    if ctx.current_node is None:
+        return False
+    scale_val = ctx.current_node.meta.get("_fp8_softmax_scale")
+    if scale_val is None:
+        return False
+    scale_tensor = get_trt_tensor(
+        ctx,
+        torch.tensor(scale_val, dtype=torch.float32),
+        name + "_softmax_fp8_scale",
+        dtype=torch.float32,
+    )
+    attention_layer.normalization_quantize_to_type = trt.DataType.FP8
+    attention_layer.normalization_quantize_scale = scale_tensor
+    return True
+

 def tril(
     ctx: ConversionContext,
@@ -164,6 +195,18 @@ def scaled_dot_product_attention(
     Returns:
         TRTTensor: Attention output tensor with shape [batch, heads, seq_len, head_dim]
     """
+    # When FP8 softmax normalization is active (modelopt FP8 MHA pattern), TRT's
+    # FP8 MHA fusion requires the Q/DQ output to feed IAttention via a single
+    # same-dtype Mul; any HALF<->FLOAT cast inserted by the default dynamic
+    # 1/sqrt(D) computation breaks the fusion. Use a static same-dtype scalar
+    # scale computed from the concrete head_dim.
+    fp8_norm_active = (
+        ctx.current_node is not None
+        and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
+    )
+    if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
+        scale = 1.0 / math.sqrt(query.shape[-1])
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -256,7 +299,8 @@ def scaled_dot_product_attention(

     if mask_tensor is not None:
         attention_layer.mask = mask_tensor
-    attention_layer.decomposable = True
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm
     attention_output = attention_layer.get_output(0)
     return attention_output

@@ -284,6 +328,13 @@ def scaled_dot_product_flash_attention(
     Optional[TRTTensor],
     Optional[TRTTensor],
 ]:
+    fp8_norm_active = (
+        ctx.current_node is not None
+        and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
+    )
+    if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
+        scale = 1.0 / math.sqrt(query.shape[-1])
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -314,7 +365,8 @@
     )
     assert attention_layer is not None, "attention layer is None"

-    attention_layer.decomposable = True
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm

     attention_output = attention_layer.get_output(0)
     return attention_output, None, None, None, 0.0, 0.0, None, None, None
@@ -334,6 +386,13 @@
     is_causal: bool = False,
     scale: Optional[float] = None,
 ) -> Tuple[TRTTensor, Optional[TRTTensor], Optional[TRTTensor], Optional[TRTTensor]]:
+    fp8_norm_active = (
+        ctx.current_node is not None
+        and ctx.current_node.meta.get("_fp8_softmax_scale") is not None
+    )
+    if fp8_norm_active and scale is None and isinstance(query.shape[-1], int):
+        scale = 1.0 / math.sqrt(query.shape[-1])
+
     if scale is None:
         # 1 / math.sqrt(query.size(-1))
         q_dim = impl.shape.shape(ctx, target, source_ir, f"{name}_shape_q", query, -1)
@@ -450,7 +509,8 @@ def scaled_dot_product_efficient_attention(
     if mask_tensor is not None:
         attention_layer.mask = mask_tensor

-    attention_layer.decomposable = True
+    fp8_norm = _maybe_set_fp8_softmax(ctx, name, attention_layer)
+    attention_layer.decomposable = not fp8_norm

     attention_output = attention_layer.get_output(0)
     return attention_output, None, None, None
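The 448 constant and the 1/448 softmax scale can be sanity-checked with plain PyTorch, independent of TensorRT. A small self-contained check (illustrative, not part of the commit):

import torch

# FP8 E4M3's largest representable magnitude is 448.
assert torch.finfo(torch.float8_e4m3fn).max == 448.0

scale = 1.0 / 448.0
probs = torch.softmax(torch.randn(4, 128), dim=-1)  # softmax output lies in [0, 1]

# Quantize: divide by the scale (i.e. multiply by 448), cast to FP8, then invert.
q = (probs / scale).to(torch.float8_e4m3fn)
deq = q.to(torch.float32) * scale

# A probability of exactly 1.0 maps to 448, the format max, with no clipping,
# and the scale is derived from the format alone; no calibration data is needed.
print((probs - deq).abs().max().item())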

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
     trace_intermediate_node_outputs,
 )

+from .annotate_fp8_sdpa import annotate_fp8_sdpa
 from .complex_graph_rewrite import complex_graph_detection
 from .constant_folding import constant_fold
 from .force_causal_efficient_attention import force_causal_efficient_attention
@@ -41,6 +42,7 @@
     remove_num_users_is_0_nodes,
     complex_graph_detection,
     force_causal_efficient_attention,
+    annotate_fp8_sdpa,
 ]

 if not is_tegra_platform():
py/torch_tensorrt/dynamo/lowering/passes/annotate_fp8_sdpa.py

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+
+logger = logging.getLogger(__name__)
+
+# FP8 E4M3 max. Softmax output is bounded to [0, 1], so 1/448 saturates at 1.0 exactly
+# and is data-independent (no calibration required for the softmax output scale).
+_FP8_E4M3_SOFTMAX_SCALE = 1.0 / 448.0
+
+_SDPA_TARGETS = {
+    torch.ops.aten.scaled_dot_product_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_cudnn_attention.default,
+}
+
+
+def _is_fp8_quantize_op(node: torch.fx.Node) -> bool:
+    """Return True when node is a tensorrt.quantize_op with FP8 dtype (exponent_bits=4)."""
+    if node.op != "call_function":
+        return False
+    try:
+        if node.target != torch.ops.tensorrt.quantize_op.default:
+            return False
+    except AttributeError:
+        return False
+    # args: (input, amax, num_bits, exponent_bits, ...)
+    args = node.args
+    return len(args) >= 4 and args[2] == 8 and args[3] == 4
+
+
+def annotate_fp8_sdpa(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Annotate SDPA nodes whose Q, K, V inputs are all FP8-quantized.
+
+    Detects the pattern emitted by modelopt when an attention module is
+    registered via ``register_attention_for_kv_quant``, which wraps the
+    Q, K, V arguments to ``F.scaled_dot_product_attention`` with
+    ``q_bmm_quantizer``, ``k_bmm_quantizer``, ``v_bmm_quantizer``:
+
+        q_fp8 = quantize_op(q, amax_q, num_bits=8, exponent_bits=4, ...)
+        k_fp8 = quantize_op(k, amax_k, num_bits=8, exponent_bits=4, ...)
+        v_fp8 = quantize_op(v, amax_v, num_bits=8, exponent_bits=4, ...)
+        out = scaled_dot_product_attention(q_fp8, k_fp8, v_fp8, ...)
+
+    When all three inputs match this pattern the pass sets
+    ``node.meta["_fp8_softmax_scale"] = 1/448`` on the SDPA node so the
+    attention converter can set ``IAttention.normalization_quantize_to_type
+    = FP8`` and ``IAttention.normalization_quantize_scale``, which TRT
+    requires to fuse into the ``_gemm_mha_v2`` FP8 MHA kernel.
+    """
+    changed = False
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in _SDPA_TARGETS:
+            continue
+        if len(node.args) < 3:
+            continue
+        q_node, k_node, v_node = node.args[0], node.args[1], node.args[2]
+        if not all(
+            isinstance(n, torch.fx.Node) and _is_fp8_quantize_op(n)
+            for n in (q_node, k_node, v_node)
+        ):
+            continue
+        node.meta["_fp8_softmax_scale"] = _FP8_E4M3_SOFTMAX_SCALE
+        changed = True
+        logger.debug(
+            f"Annotated SDPA node {node.name} with FP8 softmax scale "
+            f"{_FP8_E4M3_SOFTMAX_SCALE} (Q/K/V inputs are FP8-quantized)"
+        )
+
+    if changed:
+        logger.debug("FP8 SDPA softmax annotation complete")
+    return gm
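To see where the annotation lands, the sketch below exports a toy attention module and stamps the same meta key by hand. It is illustrative only: it skips the quantize_op check the real pass performs (building modelopt's FP8-quantized graph is out of scope here) and assumes torch.export keeps SDPA as aten.scaled_dot_product_attention.default, as the default pre-decomposition export IR does:

import torch
import torch.nn.functional as F


class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


q = torch.randn(1, 8, 16, 64)
ep = torch.export.export(ToyAttention(), (q, q, q))

for node in ep.graph.nodes:
    if (
        node.op == "call_function"
        and node.target == torch.ops.aten.scaled_dot_product_attention.default
    ):
        # annotate_fp8_sdpa would first verify that Q, K, V are all FP8
        # quantize_op nodes before stamping this key.
        node.meta["_fp8_softmax_scale"] = 1.0 / 448.0
        print(f"annotated {node.name}")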

pyproject.toml

Lines changed: 4 additions & 1 deletion

@@ -102,7 +102,10 @@ docs = [
     "pillow",
 ]

-quantization = ["nvidia-modelopt[hf]>=0.43.0"]
+quantization = [
+    "nvidia-modelopt[hf] @ git+https://github.com/NVIDIA/Model-Optimizer.git@main",
+    "transformers>=5.5.4",
+]

 [project.urls]
 Homepage = "https://pytorch.org/tensorrt"
