From 066452f1dcceac399be9f2f1b7fc1fafdbf8320c Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 2 Aug 2025 16:01:30 -0700 Subject: [PATCH 1/9] add sliding window support for Gemma3 --- tools/llm/torchtrt_ext/register_sdpa.py | 5 +- tools/llm/torchtrt_ext/sdpa_converter.py | 64 +++++++++++++++++++++++- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 90a00a5798..b04a7e5b3a 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -1,6 +1,7 @@ import copy import logging import operator +from re import I from typing import Callable, Sequence, Tuple import torch @@ -89,7 +90,9 @@ def replace_variants_of_sdpa( logger.warning( f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." ) - modified_input_args = (query, key, value, None, dropout_p, True) + # TODO: lan to figure out why is_causal is always False in google/gemma-3-1b-it, as in the config file it should be every 5 sliding window layer followed by a full attention layer + # also to figure out why the attn_mask passed in from transformers is not working + modified_input_args = (query, key, value, None, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 47083c7b48..9779294b77 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -27,7 +27,53 @@ def tril( name: str, row: TRTTensor, col: TRTTensor, + sliding_window_size: Optional[int] = None, ) -> TRTTensor: + + row_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 + ) + col_arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 + ) + row_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1 + ) + col_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0 + ) + # sub will return the following mask tensor: + # [[0, -1, -2, -3], + # [1, 0, -1, -2], + # [2, 1, 0, -1], + # [3, 2, 1, 0]] + mask = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor + ) + ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0) + if sliding_window_size is None: + # return the following lower triangular mask includes the main diagonal: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ■ ■ ■ ■ ⬚ [ True, True, True, True, False], + # 4 ■ ■ ■ ■ ■ [ True, True, True, True, True]]]]) + return ge_0_mask + + lt_window_mask = impl.elementwise.lt( + ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size + ) + mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask + ) + # return the following mask if sliding_window_size is 3: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, 
False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ⬚ ■ ■ ■ ⬚ [False, True, True, True, False], + # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) + return mask + row_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 ) @@ -66,7 +112,7 @@ def scaled_dot_product_attention( # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN - is_causal = True + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) query_dtype = query.dtype @@ -136,7 +182,21 @@ def scaled_dot_product_attention( S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + if is_causal: + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + else: + # hard code the sliding window size to 512 for now + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S, 512) + # TODO: lan to figure out why attn_mask passed in from transformers is not working + # tried both 2d and 4d, but both are not working, hence the following code is commented out + # assert len(attn_mask.shape) in [2, 4], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}" + # if len(attn_mask.shape) == 4: + # if attn_mask.shape[0] != 1: + # attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1) + # if attn_mask.shape[1] != 1: + # attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1) + # attn_mask = impl.squeeze.squeeze(ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)) + # tril_tensor = attn_mask temp_mask = impl.unary.logical_not( ctx, target, source_ir, name + "_logical_not", tril_tensor From a58d17bd1ff743d071c8852a435fbe3a021249cd Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 2 Aug 2025 16:19:45 -0700 Subject: [PATCH 2/9] test --- tools/llm/torchtrt_ext/register_sdpa.py | 1 - tools/llm/torchtrt_ext/sdpa_converter.py | 19 ------------------- 2 files changed, 20 deletions(-) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index b04a7e5b3a..d8472cba55 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -1,7 +1,6 @@ import copy import logging import operator -from re import I from typing import Callable, Sequence, Tuple import torch diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 9779294b77..8f4ba4e32f 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -74,25 +74,6 @@ def tril( # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) return mask - row_arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 - ) - row_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] - ) - - col_arange_tensor = impl.arange.arange( - ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 - ) - col_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] - ) - - mask = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor - ) - return mask - 
@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter( torch.nn.functional.scaled_dot_product_attention, From a65f0f1ae5792470feb02d71163ffeb8cf2bd91b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 14 Aug 2025 20:19:06 -0700 Subject: [PATCH 3/9] add test case --- tools/llm/run_llm.py | 2 +- tools/llm/test_trt_sdpa.py | 68 +++++++++++++++ tools/llm/torchtrt_ext/register_sdpa.py | 6 +- tools/llm/torchtrt_ext/sdpa_converter.py | 102 ++++++++++++++--------- tools/llm/utils.py | 1 - 5 files changed, 136 insertions(+), 43 deletions(-) create mode 100644 tools/llm/test_trt_sdpa.py diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 7e50b515c2..5647a10a7a 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -116,7 +116,7 @@ def compile_torchtrt(model, input_ids, args): use_fp32_acc=use_fp32_acc, device=DEVICE, disable_tf32=True, - use_python_runtime=True, + use_python_runtime=False, debug=args.debug, offload_module_to_cpu=True, min_block_size=args.min_block_size, diff --git a/tools/llm/test_trt_sdpa.py b/tools/llm/test_trt_sdpa.py new file mode 100644 index 0000000000..28827c0184 --- /dev/null +++ b/tools/llm/test_trt_sdpa.py @@ -0,0 +1,68 @@ +import torch +import torch_tensorrt +from torch.export import Dim +from torchtrt_ext import register_sdpa + + +class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + + def forward(self, query, key, value, attn_mask): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, + enable_math=False, + enable_mem_efficient=True, + ): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask, 0.0, False, scale=0.0625 + ) + + +dtype = torch.float32 + +dyn_dim = Dim("dyn_dim", min=3, max=32) + +query = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() +key = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() +value = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() +attn_mask = torch.ones((13, 13), dtype=torch.bool).tril(diagonal=0).cuda() +inputs = (query, key, value, attn_mask) + +model = SimpleNetwork().eval().cuda() +output_pyt = model(*inputs) +exp_program = torch.export.export( + model, + inputs, + strict=False, + dynamic_shapes={ + "query": {2: dyn_dim}, + "key": {2: dyn_dim}, + "value": {2: dyn_dim}, + "attn_mask": {0: dyn_dim, 1: dyn_dim}, + }, +) +DEBUG_LOGGING_DIR = "./debug_logs" +with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=DEBUG_LOGGING_DIR, + capture_fx_graph_after=["complex_graph_detection"], + save_engine_profile=True, + profile_format="trex", + engine_builder_monitor=True, +): + trt_model = torch_tensorrt.dynamo.compile( + exp_program, + inputs=inputs, + enabled_precisions={dtype}, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + truncate_double=True, + use_python_runtime=False, + ) + outputs_trt = trt_model(*inputs) + breakpoint() + assert torch.allclose(output_pyt, outputs_trt, rtol=1e-2, atol=1e-2) + +print("Done") diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index d8472cba55..c5f1f68665 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -89,9 +89,9 @@ def replace_variants_of_sdpa( logger.warning( f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." 
) - # TODO: lan to figure out why is_causal is always False in google/gemma-3-1b-it, as in the config file it should be every 5 sliding window layer followed by a full attention layer - # also to figure out why the attn_mask passed in from transformers is not working - modified_input_args = (query, key, value, None, dropout_p, is_causal) + # TODO: lan to figure out why the attn_mask passed in from transformers is not working + # modified_input_args = (query, key, value, None, dropout_p, True) + modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale with gm.graph.inserting_after(node): diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 8f4ba4e32f..fdb04022da 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -161,51 +161,77 @@ def scaled_dot_product_attention( L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2) if S < 0: S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) - # generate the mask tensor if is_causal: tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) else: - # hard code the sliding window size to 512 for now - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S, 512) # TODO: lan to figure out why attn_mask passed in from transformers is not working - # tried both 2d and 4d, but both are not working, hence the following code is commented out - # assert len(attn_mask.shape) in [2, 4], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}" - # if len(attn_mask.shape) == 4: - # if attn_mask.shape[0] != 1: - # attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1) - # if attn_mask.shape[1] != 1: - # attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1) - # attn_mask = impl.squeeze.squeeze(ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)) - # tril_tensor = attn_mask - - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) + # tried both 2d and 4d, but both are not working + assert len(attn_mask.shape) in [ + 2, + 4, + ], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}" + if len(attn_mask.shape) == 4: + if attn_mask.shape[0] != 1: + attn_mask = impl.slice.slice_op( + ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1 + ) + if attn_mask.shape[1] != 1: + attn_mask = impl.slice.slice_op( + ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1 + ) + attn_mask = impl.squeeze.squeeze( + ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1) + ) + tril_tensor = attn_mask - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. 
- # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) + # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this + attn_bias_via_where = True + if attn_bias_via_where: + attn_bias = impl.condition.where( + ctx, + target, + source_ir, + name + "_where", + torch.tensor(0.0, dtype=torch.float32).cuda(), + torch.tensor(-float("inf"), dtype=torch.float32).cuda(), + tril_tensor, + ) + else: + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + temp_mask = cast_trt_tensor( + ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir + ) + temp_mask = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_-inf", temp_mask, float("-inf") + ) + attn_bias = temp_mask - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. + # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir + ) + + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) scaled_add_attn_bias = impl.elementwise.add( ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias diff --git a/tools/llm/utils.py b/tools/llm/utils.py index 2c3434b0ed..c56aa9b490 100644 --- a/tools/llm/utils.py +++ b/tools/llm/utils.py @@ -179,7 +179,6 @@ def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_tok num_tokens_generated = 0 kv_cache = get_zeroed_dynamic_cache_inputs(model) last_position_id = position_ids[-1, -1].item() - breakpoint() while num_tokens_generated < num_output_tokens: is_generate = False if input_seq.shape[1] > 1 else True position_ids = ( From 47abe2c352f4d5e117ffa23f12a9bfac01934089 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 14 Aug 2025 20:22:14 -0700 Subject: [PATCH 4/9] test --- tools/llm/torchtrt_ext/sdpa_converter.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index fdb04022da..7d467942bd 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -201,13 +201,6 @@ def scaled_dot_product_attention( temp_mask = impl.unary.logical_not( ctx, target, source_ir, name + "_logical_not", tril_tensor ) - temp_mask = cast_trt_tensor( - ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir - ) - temp_mask = impl.elementwise.mul( - ctx, target, 
source_ir, name + "_mul_-inf", temp_mask, float("-inf") - ) - attn_bias = temp_mask # This need_mask determines if we want to use the causal mask or not # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. From 779e17477b7aba20375ac9e14d23f1824b8a774b Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 18 Aug 2025 12:35:15 -0700 Subject: [PATCH 5/9] resolve the attn_mask nan issue --- tools/llm/test_trt_sdpa.py | 1 + tools/llm/torchtrt_ext/sdpa_converter.py | 92 ++++++++++++------------ 2 files changed, 49 insertions(+), 44 deletions(-) diff --git a/tools/llm/test_trt_sdpa.py b/tools/llm/test_trt_sdpa.py index 28827c0184..5691b206df 100644 --- a/tools/llm/test_trt_sdpa.py +++ b/tools/llm/test_trt_sdpa.py @@ -13,6 +13,7 @@ def forward(self, query, key, value, attn_mask): enable_flash=False, enable_math=False, enable_mem_efficient=True, + enable_cudnn=False, ): return torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask, 0.0, False, scale=0.0625 diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index 7d467942bd..c793851471 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -162,11 +162,7 @@ def scaled_dot_product_attention( if S < 0: S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) # generate the mask tensor - if is_causal: - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - else: - # TODO: lan to figure out why attn_mask passed in from transformers is not working - # tried both 2d and 4d, but both are not working + if not is_causal: assert len(attn_mask.shape) in [ 2, 4, @@ -183,48 +179,56 @@ def scaled_dot_product_attention( attn_mask = impl.squeeze.squeeze( ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1) ) - tril_tensor = attn_mask - - # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this - attn_bias_via_where = True - if attn_bias_via_where: - attn_bias = impl.condition.where( - ctx, - target, - source_ir, - name + "_where", - torch.tensor(0.0, dtype=torch.float32).cuda(), - torch.tensor(-float("inf"), dtype=torch.float32).cuda(), - tril_tensor, - ) + attn_bias = attn_mask else: - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) + tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this + attn_bias_via_where = True + if attn_bias_via_where: + attn_bias = impl.condition.where( + ctx, + target, + source_ir, + name + "_where", + torch.tensor(0.0, dtype=torch.float32).cuda(), + torch.tensor(-float("inf"), dtype=torch.float32).cuda(), + tril_tensor, + ) + else: + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. 
- # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir - ) + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. + # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq( + ctx, target, source_ir, name + "_eq", L, S + ) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, + temp_mask, + query_dtype, + name + "_casted_bool", + target, + source_ir, + ) - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) scaled_add_attn_bias = impl.elementwise.add( ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias From 78d96385cca604766700aa18672545591f6280dc Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 17 Jul 2025 11:13:45 -0700 Subject: [PATCH 6/9] Index converter dynamic cases fix --- .../dynamo/conversion/aten_ops_converters.py | 4 ++- tests/py/dynamo/conversion/test_index_aten.py | 26 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index f1a7f9a8fc..6d2f8768ab 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -392,7 +392,9 @@ def index_dtype_validator( @dynamo_tensorrt_converter( - torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator + torch.ops.aten.index.Tensor, + capability_validator=index_dtype_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 8e21f945dc..fc4a70b1ff 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -168,7 +168,31 @@ def forward(self, input): dtype=torch.float32, ), ] - self.run_test_with_dynamic_shape(TestModule(), input_specs) + self.run_test_with_dynamic_shape( + TestModule(), input_specs, use_dynamo_tracer=True + ) + + +class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase): + def test_index_input_non_dynamic_index_dynamic(self): + class TestIndexWithRuntimeIndex(torch.nn.Module): + def forward(self, x): + mask = x > 0 + idx = torch.nonzero(mask, as_tuple=True) + return torch.ops.aten.index.Tensor(x, idx) + + input_specs = [ + Input( + min_shape=(2, 2), + opt_shape=(2, 2), + max_shape=(8, 8), + dtype=torch.float32, + ), + ] + # In this case the index args[1] gets itself converted to a List of TRTTensors with use_dynamo_tracer=True + self.run_test_with_dynamic_shape( + 
TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True + ) if __name__ == "__main__": From 6d0d87ce8e5ad15a61dda1ccad4b398fc02f5178 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 31 Jul 2025 15:53:21 -0700 Subject: [PATCH 7/9] support for boolean indices --- .../dynamo/conversion/aten_ops_converters.py | 6 +- .../dynamo/conversion/impl/select.py | 62 ++++++++++++++++++- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 6d2f8768ab..591f5878a5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -386,7 +386,11 @@ def index_dtype_validator( for ind in index: if ind is not None: val = ind.meta.get("val") - if val is not None and val.dtype not in (torch.int32, torch.int64): + if val is not None and val.dtype not in ( + torch.int32, + torch.int64, + torch.bool, + ): return False return True diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index fe6ade2e68..7640035959 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -53,6 +53,65 @@ def select( return layer.get_output(0) +def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: + if isinstance(tensor, (TRTTensor)): + val = tensor.meta.get("val") + if val is not None and val.dtype is torch.bool: + return True + return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool + + +def expand_boolean_indices( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], +) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]: + for i, ind in enumerate(indices): + if ind is not None and is_boolean_tensor(ind): + _LOGGER.debug( + f"Boolean index detected at position {i}, converting with nonzero()" + ) + + mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}") + + nonzero_layer = ctx.net.add_non_zero(mask_tensor) + set_layer_name( + nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir + ) + nonzero_indices = nonzero_layer.get_output(0) + + # nonzero returns shape [N, dims], we need to extract dim i + if len(indices) == 1: + # x[mask] — 1D mask + squeeze_layer = ctx.net.add_shuffle(nonzero_indices) + squeeze_layer.reshape_dims = (-1,) + set_layer_name( + squeeze_layer, + target, + name + f"_bool_nonzero_squeeze_{i}", + source_ir, + ) + squeezed_index = squeeze_layer.get_output(0) + ind = squeezed_index + else: + # Advanced multi-axis mask: extract index i from shape [N, D] + gather_axis = 1 # dim index + gather_layer = ctx.net.add_gather( + nonzero_indices, + get_trt_tensor(ctx, i, name + f"_dim_index_{i}"), + gather_axis, + ) + set_layer_name( + gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir + ) + extracted_index = gather_layer.get_output(0) + ind = extracted_index + return indices + + def index( ctx: ConversionContext, target: Target, @@ -63,8 +122,6 @@ def index( ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] - # check if the input is dynamic - dynamic_shape = has_dynamic_shape(input.shape) # is_numpy is a flag to specify if all the indices are numpy or torchTensor. 
# If any is not this flag will be set to False _LOGGER.debug( @@ -78,6 +135,7 @@ def index( # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None + indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices) for i, ind in enumerate(indices): if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") From 8b120f74e15412097404ff62e1774881fd142576 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 17 Jul 2025 11:13:45 -0700 Subject: [PATCH 8/9] Index converter dynamic cases fix --- .../dynamo/conversion/impl/select.py | 3 +- tools/llm/test_trt_sdpa.py | 69 ------------ tools/llm/torchtrt_ext/register_sdpa.py | 5 - tools/llm/torchtrt_ext/sdpa_converter.py | 101 ++++++------------ 4 files changed, 37 insertions(+), 141 deletions(-) delete mode 100644 tools/llm/test_trt_sdpa.py diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 7640035959..3e323c538f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -19,7 +19,6 @@ from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.fx.converters.converter_utils import ( - has_dynamic_shape, set_layer_name, ) from torch_tensorrt.fx.types import TRTTensor @@ -55,6 +54,8 @@ def select( def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: if isinstance(tensor, (TRTTensor)): + if getattr(tensor, "meta", None) is None: + return tensor.dtype == torch.bool val = tensor.meta.get("val") if val is not None and val.dtype is torch.bool: return True diff --git a/tools/llm/test_trt_sdpa.py b/tools/llm/test_trt_sdpa.py deleted file mode 100644 index 5691b206df..0000000000 --- a/tools/llm/test_trt_sdpa.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch_tensorrt -from torch.export import Dim -from torchtrt_ext import register_sdpa - - -class SimpleNetwork(torch.nn.Module): - def __init__(self): - super(SimpleNetwork, self).__init__() - - def forward(self, query, key, value, attn_mask): - with torch.backends.cuda.sdp_kernel( - enable_flash=False, - enable_math=False, - enable_mem_efficient=True, - enable_cudnn=False, - ): - return torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask, 0.0, False, scale=0.0625 - ) - - -dtype = torch.float32 - -dyn_dim = Dim("dyn_dim", min=3, max=32) - -query = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() -key = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() -value = torch.randn((1, 4, 13, 256), dtype=dtype).cuda() -attn_mask = torch.ones((13, 13), dtype=torch.bool).tril(diagonal=0).cuda() -inputs = (query, key, value, attn_mask) - -model = SimpleNetwork().eval().cuda() -output_pyt = model(*inputs) -exp_program = torch.export.export( - model, - inputs, - strict=False, - dynamic_shapes={ - "query": {2: dyn_dim}, - "key": {2: dyn_dim}, - "value": {2: dyn_dim}, - "attn_mask": {0: dyn_dim, 1: dyn_dim}, - }, -) -DEBUG_LOGGING_DIR = "./debug_logs" -with torch_tensorrt.dynamo.Debugger( - "graphs", - logging_dir=DEBUG_LOGGING_DIR, - capture_fx_graph_after=["complex_graph_detection"], - save_engine_profile=True, - profile_format="trex", - engine_builder_monitor=True, -): - trt_model = torch_tensorrt.dynamo.compile( - exp_program, - inputs=inputs, - enabled_precisions={dtype}, - min_block_size=1, - cache_built_engines=False, - reuse_cached_engines=False, - 
truncate_double=True, - use_python_runtime=False, - ) - outputs_trt = trt_model(*inputs) - breakpoint() - assert torch.allclose(output_pyt, outputs_trt, rtol=1e-2, atol=1e-2) - -print("Done") diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index c5f1f68665..bf63801276 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -86,11 +86,6 @@ def replace_variants_of_sdpa( f"Unexpected number of arguments for {node.target} in the graph" ) - logger.warning( - f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." - ) - # TODO: lan to figure out why the attn_mask passed in from transformers is not working - # modified_input_args = (query, key, value, None, dropout_p, True) modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal) # Create a new node with torch.nn.functional.scaled_dot_product_attention # The input args is (query, key, value, is_causal). kwargs has scale diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index c793851471..03a960edc4 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -162,73 +162,42 @@ def scaled_dot_product_attention( if S < 0: S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) # generate the mask tensor - if not is_causal: - assert len(attn_mask.shape) in [ - 2, - 4, - ], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}" - if len(attn_mask.shape) == 4: - if attn_mask.shape[0] != 1: - attn_mask = impl.slice.slice_op( - ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1 - ) - if attn_mask.shape[1] != 1: - attn_mask = impl.slice.slice_op( - ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1 - ) - attn_mask = impl.squeeze.squeeze( - ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1) - ) - attn_bias = attn_mask - else: + if is_causal: tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) - # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this - attn_bias_via_where = True - if attn_bias_via_where: - attn_bias = impl.condition.where( - ctx, - target, - source_ir, - name + "_where", - torch.tensor(0.0, dtype=torch.float32).cuda(), - torch.tensor(-float("inf"), dtype=torch.float32).cuda(), - tril_tensor, - ) - else: - temp_mask = impl.unary.logical_not( - ctx, target, source_ir, name + "_logical_not", tril_tensor - ) - - # This need_mask determines if we want to use the causal mask or not - # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. - # So need_mask will be all False values in this case. 
- # TODO: Implement more general case where L != 1 and S != L - need_mask = impl.elementwise.eq( - ctx, target, source_ir, name + "_eq", L, S - ) - temp_mask = impl.elementwise.logical_and( - ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask - ) - temp_mask_casted = cast_trt_tensor( - ctx, - temp_mask, - query_dtype, - name + "_casted_bool", - target, - source_ir, - ) - - one_minus_temp_mask = impl.elementwise.sub( - ctx, - target, - source_ir, - name + "_one_minus_temp_mask", - 1.0, - temp_mask_casted, - ) - attn_bias = impl.unary.log( - ctx, target, source_ir, name + "_log", one_minus_temp_mask - ) + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + + # This need_mask determines if we want to use the causal mask or not + # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask. + # So need_mask will be all False values in this case. + # TODO: Implement more general case where L != 1 and S != L + need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S) + temp_mask = impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask + ) + temp_mask_casted = cast_trt_tensor( + ctx, + temp_mask, + query_dtype, + name + "_casted_bool", + target, + source_ir, + ) + + one_minus_temp_mask = impl.elementwise.sub( + ctx, + target, + source_ir, + name + "_one_minus_temp_mask", + 1.0, + temp_mask_casted, + ) + attn_bias = impl.unary.log( + ctx, target, source_ir, name + "_log", one_minus_temp_mask + ) + else: + attn_bias = attn_mask scaled_add_attn_bias = impl.elementwise.add( ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias From bd28290e0fc298cf73716ff58d6681c8e76ac4ce Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 21 Aug 2025 15:04:24 -0700 Subject: [PATCH 9/9] get index change in --- py/torch_tensorrt/dynamo/conversion/impl/select.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 0d1e63e35c..b657c19a88 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -54,7 +54,7 @@ def select( def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: if isinstance(tensor, (TRTTensor)): if getattr(tensor, "meta", None) is None: - return tensor.dtype == torch.bool + return tensor.dtype == trt.DataType.BOOL val = tensor.meta.get("val") if val is not None and val.dtype is torch.bool: return True
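
Note on patch 1: the banded mask that tril() builds from the difference of the row and
column arange tensors has a simple dense equivalent. The sketch below reproduces the same
arithmetic in plain PyTorch for a static shape (the function name, the 5x5 size, and the
window of 3 are illustrative; the converter itself assembles the mask from TensorRT
arange/sub/ge/lt layers so that it also works with dynamic L and S):

import torch
from typing import Optional


def sliding_window_causal_mask(L: int, S: int, window: Optional[int] = None) -> torch.Tensor:
    # Entry (i, j) holds i - j, mirroring the sub of the two arange tensors in tril().
    diff = torch.arange(L).unsqueeze(-1) - torch.arange(S).unsqueeze(0)
    causal = diff >= 0                   # lower triangle, main diagonal included
    if window is None:
        return causal
    return causal & (diff < window)      # additionally drop keys older than the window


# window=3 on a 5x5 grid gives the banded pattern drawn in the patch comment:
# row 3 attends to columns 1..3, row 4 to columns 2..4.
print(sliding_window_causal_mask(5, 5, 3))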
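
Note on patches 3-5: the additive attention bias can be derived from a boolean mask either
as log(1 - blocked) or as a direct where(allowed, 0, -inf) select; the series switches
between the two while chasing the attn_mask NaN. On the final tensor the two forms agree,
which the following plain-PyTorch check illustrates (the 5x5 shape is illustrative):

import torch

allowed = torch.ones(5, 5, dtype=torch.bool).tril()   # True where attention is permitted

# Mark blocked positions with 1.0 and take log(1 - x): 0 where allowed, -inf where blocked.
blocked = (~allowed).float()
bias_log = torch.log(1.0 - blocked)

# Select the bias values directly, as the where-based path in patch 5 does.
bias_where = torch.where(allowed, torch.tensor(0.0), torch.tensor(float("-inf")))

assert torch.equal(bias_log, bias_where)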
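
Note on patches 6-9: supporting boolean indices in the index converter amounts to lowering
x[mask] into integer gather indices, which is what the added expand_boolean_indices() does
through TensorRT's add_non_zero layer. A minimal PyTorch sketch of the semantics it mirrors
(the values are illustrative; the test added in patch 6 exercises the same pattern):

import torch

x = torch.tensor([[1.0, -2.0], [-3.0, 4.0]])
mask = x > 0

# Lower the boolean mask to integer coordinates, then index with them.
idx = torch.nonzero(mask, as_tuple=True)   # (row_indices, col_indices), both int64
assert torch.equal(x[idx], x[mask])        # tensor([1., 4.]) either way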