diff --git a/docsrc/user_guide/mixed_precision.rst b/docsrc/user_guide/mixed_precision.rst
index dca0b033e6..3ea475c94b 100644
--- a/docsrc/user_guide/mixed_precision.rst
+++ b/docsrc/user_guide/mixed_precision.rst
@@ -1,7 +1,7 @@
 .. _mixed_precision:
 
 Compile Mixed Precision models with Torch-TensorRT
-====================================
+===================================================
 .. currentmodule:: torch_tensorrt.dynamo
 
 .. automodule:: torch_tensorrt.dynamo
@@ -9,7 +9,10 @@ Compile Mixed Precision models with Torch-TensorRT
    :undoc-members:
    :show-inheritance:
 
-Consider the following Pytorch model which explicitly casts intermediate layer to run in FP16.
+Explicit Typing
+---------------
+
+Consider the following PyTorch model which explicitly casts an intermediate layer to run in FP16.
 
 .. code-block:: python
 
@@ -54,6 +57,7 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option
 
 .. note:: If you enable ``use_explicit_typing=True``, only torch.float32 is supported in the enabled_precisions.
 
+
 .. code-block:: python
 
     inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
@@ -62,7 +66,7 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option
     with torch_tensorrt.logging.debug():
         trt_gm = torch_tensorrt.dynamo.compile(ep,
             inputs=inputs,
-            use_explicit_typing=True
+            use_explicit_typing=True,
             debug=True)
 
     # Debug log info
@@ -71,4 +75,36 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option
     # Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
     # Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:
 
-Now the ``linear2`` layer runs in FP16 as shown in the above logs.
\ No newline at end of file
+Now the ``linear2`` layer runs in FP16 as shown in the above logs.
+
+
+
+FP32 Accumulation
+-----------------
+
+When ``use_fp32_acc=True`` is set, Torch-TensorRT will attempt to use FP32 accumulation for matmul layers, even if the input and output tensors are in FP16. This is particularly useful for models that are sensitive to numerical errors introduced by lower-precision accumulation.
+
+.. important::
+
+    When enabling ``use_fp32_acc=True``, **explicit typing must be enabled** by setting ``use_explicit_typing=True``. Without ``use_explicit_typing=True``, the accumulation type may not be properly respected, and you may not see the intended numerical benefits.
+
+.. code-block:: python
+
+    inputs = [torch.randn((1, 10), dtype=torch.float16).cuda()]
+    mod = MyModule().eval().cuda()
+    ep = torch.export.export(mod, tuple(inputs))
+    with torch_tensorrt.logging.debug():
+        trt_gm = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=inputs,
+            use_fp32_acc=True,
+            use_explicit_typing=True, # Explicit typing must be enabled
+            debug=True
+        )
+
+    # Debug log info
+    # Layers:
+    # Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ ... ], Outputs: [ ... ], Format/Datatype: Half, Accumulation: Float
+    # ...
+
+For more information on these settings, see the explicit typing examples above.
\ No newline at end of file
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 6006484f19..831a83ed37 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -393,7 +393,9 @@ def index_dtype_validator(
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
+    torch.ops.aten.index.Tensor,
+    capability_validator=index_dtype_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 7d7f4274ff..8a26fcedd0 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -24,6 +24,7 @@
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
+from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -31,7 +32,7 @@
     ConverterRegistry,
     DynamoConverterImplSignature,
 )
-from torch_tensorrt._utils import is_tensorrt_version_supported
+
 from ..types import Shape, TRTDataType, TRTLayer, TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -72,6 +73,9 @@ def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
     # If the provided data is a scalar, return it as is
     elif isinstance(metadata, (int, float, bool)):
         return f"{metadata}@Python-{type(metadata)}"
+    # If the provided data is a SymInt, return it as is
+    elif isinstance(metadata, (torch.SymInt)):
+        return f"{metadata}@SymInt"
     # If the provided data is a sequence, recursively parse it
     elif isinstance(metadata, collections.abc.Sequence):
         formatted_str = "("
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
index 65e4f53328..910b851b3c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
@@ -48,17 +48,25 @@ def matrix_multiply(
     input, other = broadcast(
         ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
-    if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
-        promoted_type = _enums.dtype._from(
-            torch.promote_types(
-                _enums.dtype._from(input.dtype).to(torch.dtype),
-                _enums.dtype._from(other.dtype).to(torch.dtype),
-            )
+    if (
+        ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
+        and ctx.compilation_settings.use_fp32_acc
+    ):
+        input = cast_trt_tensor(ctx, input, torch.float32, f"{name}_input_casted")
+        other = cast_trt_tensor(ctx, other, torch.float32, f"{name}_other_casted")
+
+    matmul_layer = ctx.net.add_matrix_multiply(
+        input, input_matrix_op, other, other_matrix_op
+    )
+    matmul_output = matmul_layer.get_output(0)
+
+    if (
+        ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
+        and ctx.compilation_settings.use_fp32_acc
+    ):
+        matmul_output = cast_trt_tensor(
+            ctx, matmul_output, torch.float16, f"{name}_output_casted"
         )
-        trt_promoted_type = promoted_type.to(trt.DataType)
-        input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
-        other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")
 
-    layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
-    set_layer_name(layer, target, name, source_ir)
-    return layer.get_output(0)
+    set_layer_name(matmul_layer, target, name, source_ir)
+    return matmul_output
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 516c371e48..1fc1b9b420 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -5,7 +5,6 @@
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.utils import is_tegra_platform
 
-from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .complex_graph_rewrite import complex_graph_detection
 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
@@ -24,7 +23,6 @@
     fuse_prims_broadcast,
     replace_max_pool_with_indices,
     remove_assert_nodes,
-    accumulate_fp32_matmul,
     remove_num_users_is_0_nodes,
     complex_graph_detection,
 ]
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
deleted file mode 100644
index 282693d299..0000000000
--- a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import logging
-
-import torch
-from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
-    clean_up_graph_after_modifications,
-)
-
-logger = logging.getLogger(__name__)
-
-
-def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    """
-    Splits all `torch.ops.aten.addmm.default` nodes in the FX graph into separate
-    `add` and `mm` nodes. This is useful for passes that want to insert additional
-    logic (such as FP32 accumulation) specifically around the matrix multiplication
-    operation, rather than the fused addmm.
-
-    Args:
-        gm (torch.fx.GraphModule): The FX graph module to transform.
-
-    Returns:
-        torch.fx.GraphModule: The modified FX graph module with addmm nodes split.
- """ - target = torch.ops.aten.addmm.default - addmm_nodes = [node for node in gm.graph.nodes if node.target == target] - for addmm_node in addmm_nodes: - bias, mat1, mat2 = addmm_node.all_input_nodes - beta = addmm_node.kwargs.get("beta") - alpha = addmm_node.kwargs.get("alpha") - - with gm.graph.inserting_before(addmm_node): - mm_node = gm.graph.call_function( - torch.ops.aten.mm.default, - args=(mat1, mat2), - ) - if alpha: - mm_node = gm.graph.call_function( - torch.ops.aten.mul.Tensor, - args=(mm_node, alpha), - ) - - if beta: - bias = gm.graph.call_function( - torch.ops.aten.mul.Tensor, - args=(bias, beta), - ) - add_node = gm.graph.call_function( - torch.ops.aten.add.Tensor, - args=(bias, mm_node), - ) - - addmm_node.replace_all_uses_with(add_node, propagate_meta=True) - gm.graph.erase_node(addmm_node) - - return gm - - -def accumulate_fp32_matmul( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Add cast to FP32/16 nodes around a matmul layer. This pattern is detected by TensorRT and will enable FP32 accumulation during execution.""" - if settings.use_fp32_acc: - matmul_targets = [ - torch.ops.aten.mm.default, - torch.ops.aten.bmm.default, - torch.ops.aten.matmul.default, - ] - - # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes - split_addmm_nodes(gm) - - matmul_nodes = [ - node for node in gm.graph.nodes if node.target in matmul_targets - ] - for matmul_node in matmul_nodes: - # Prior to the matmul node, insert a cast to the 32-bit float32 node - node_inputs = matmul_node.all_input_nodes - - for node_input in node_inputs: - with gm.graph.inserting_before(matmul_node): - node_32bit = gm.graph.call_function( - torch.ops.aten._to_copy.default, - args=(node_input,), - kwargs={"dtype": torch.float32}, - ) - - # Replace the input to matmul node with new 32-bit cast node - matmul_node.replace_input_with(node_input, node_32bit) - - # Add a cast back to original precision - with gm.graph.inserting_after(matmul_node): - node_orig_precision = gm.graph.call_function( - torch.ops.aten._to_copy.default, - args=(matmul_node,), - kwargs={"dtype": torch.float16}, - ) - matmul_node.replace_all_uses_with( - node_orig_precision, propagate_meta=False - ) - # This is a hack. replace_all_uses_with isn't working here. It complains node_orig_precision is already being used before created. - node_orig_precision.replace_input_with( - node_orig_precision.all_input_nodes[0], matmul_node - ) - - gm = clean_up_graph_after_modifications(gm) - logger.debug( - f"Graph after enabling matmul layers to use FP32 accumulation:\n{gm.graph}" - ) - else: - logger.debug( - "Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings" - ) - - return gm diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 3197d9f7de..258449ad7b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -367,7 +367,8 @@ def compile(self) -> None: enabled_precisions=self.enabled_precisions, **self.additional_settings, ) - deallocate_module(self.original_model, delete_module=False) + if self.additional_settings.get("offload_module_to_cpu", False): + deallocate_module(self.original_model, delete_module=False) if self.enable_weight_streaming: self.set_weight_streaming_ctx(self.weight_streaming_budget)