
feat: Add support for Groot N1.5 model #3736


Open
wants to merge 5 commits into base: main
44 changes: 40 additions & 4 deletions docsrc/user_guide/mixed_precision.rst
@@ -1,15 +1,18 @@
.. _mixed_precision:

Compile Mixed Precision models with Torch-TensorRT
====================================
===================================================
.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
:members:
:undoc-members:
:show-inheritance:

Consider the following Pytorch model which explicitly casts intermediate layer to run in FP16.
Explicit Typing
---------------

Consider the following PyTorch model, which explicitly casts an intermediate layer to run in FP16.

.. code-block:: python

@@ -54,6 +57,7 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option

.. note:: If you enable ``use_explicit_typing=True``, only torch.float32 is supported in the enabled_precisions.


.. code-block:: python

inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
@@ -62,7 +66,7 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option
with torch_tensorrt.logging.debug():
trt_gm = torch_tensorrt.dynamo.compile(ep,
inputs=inputs,
use_explicit_typing=True
use_explicit_typing=True,
debug=True)

# Debug log info
@@ -71,4 +75,36 @@ the compilation setting ``use_explicit_typing=True``. Compiling with this option
# Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
# Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:

Now the ``linear2`` layer runs in FP16 as shown in the above logs.



FP32 Accumulation
-----------------

When ``use_fp32_acc=True`` is set, Torch-TensorRT will attempt to use FP32 accumulation for matmul layers, even if the input and output tensors are in FP16. This is particularly useful for models that are sensitive to numerical errors introduced by lower-precision accumulation.

.. important::

When enabling ``use_fp32_acc=True``, **explicit typing must be enabled** by setting ``use_explicit_typing=True``. Without ``use_explicit_typing=True``, the accumulation type may not be properly respected, and you may not see the intended numerical benefits.

.. code-block:: python

inputs = [torch.randn((1, 10), dtype=torch.float16).cuda()]
mod = MyModule().eval().cuda()
ep = torch.export.export(mod, tuple(inputs))
with torch_tensorrt.logging.debug():
trt_gm = torch_tensorrt.dynamo.compile(
ep,
inputs=inputs,
use_fp32_acc=True,
use_explicit_typing=True, # Explicit typing must be enabled
debug=True
)

# Debug log info
# Layers:
# Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ ... ], Outputs: [ ... ], Format/Datatype: Half, Accumulation: Float
# ...

For more information on these settings, see the explicit typing examples above.
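
A quick standalone PyTorch sketch of why accumulation precision matters (tensor shapes are illustrative, and plain FP16 matmul accumulation behavior depends on the backend):

import torch

# Compare an FP16 matmul against one that accumulates in FP32.
# Large inner dimensions amplify FP16 accumulation error.
a = torch.randn(64, 4096, dtype=torch.float16, device="cuda")
b = torch.randn(4096, 64, dtype=torch.float16, device="cuda")

out_fp16 = a @ b  # accumulation precision here is backend-dependent
out_fp32_acc = (a.float() @ b.float()).half()  # emulates use_fp32_acc: FP32 accumulate, FP16 output

ref = (a.double() @ b.double()).half()  # high-precision reference
print("max error, FP16 path:      ", (out_fp16 - ref).abs().max().item())
print("max error, FP32 accumulate:", (out_fp32_acc - ref).abs().max().item())
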
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -393,7 +393,9 @@ def index_dtype_validator(


@dynamo_tensorrt_converter(
torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
torch.ops.aten.index.Tensor,
capability_validator=index_dtype_validator,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
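
Marking ``torch.ops.aten.index.Tensor`` with ``supports_dynamic_shapes=True`` lets it participate in graphs exported with symbolic dimensions. A minimal sketch of a graph that exercises this converter under a dynamic shape (module and dimension names are illustrative):

import torch
import torch_tensorrt

class IndexModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        return x[idx]  # advanced indexing lowers to aten.index.Tensor

inputs = (torch.randn(8, 16).cuda(), torch.tensor([0, 2, 4]).cuda())
feature = torch.export.Dim("feature", min=8, max=64)  # dynamic second axis
ep = torch.export.export(
    IndexModule().cuda().eval(),
    inputs,
    dynamic_shapes=({1: feature}, None),
)
trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=inputs)
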
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -24,14 +24,15 @@
from torch.fx.node import Argument, Target
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt import _enums
from torch_tensorrt._utils import is_tensorrt_version_supported
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
ConverterRegistry,
DynamoConverterImplSignature,
)
from torch_tensorrt._utils import is_tensorrt_version_supported

from ..types import Shape, TRTDataType, TRTLayer, TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -72,6 +73,9 @@ def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
# If the provided data is a scalar, return it as is
elif isinstance(metadata, (int, float, bool)):
return f"{metadata}@Python-{type(metadata)}"
# If the provided data is a SymInt, format it with a SymInt tag
elif isinstance(metadata, torch.SymInt):
return f"{metadata}@SymInt"
# If the provided data is a sequence, recursively parse it
elif isinstance(metadata, collections.abc.Sequence):
formatted_str = "("
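
For context, symbolic dimensions surface in FX node metadata as ``torch.SymInt`` values, which is what the new branch formats. A small sketch (assumes a recent torch with ``torch.export``) showing where such a value comes from and the tag format used above:

import torch

class AddOne(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

batch = torch.export.Dim("batch", min=1, max=8)
ep = torch.export.export(AddOne(), (torch.randn(4, 3),), dynamic_shapes=({0: batch},))

placeholder = next(n for n in ep.graph.nodes if n.op == "placeholder")
sym_dim = placeholder.meta["val"].shape[0]  # a torch.SymInt for the dynamic axis
print(isinstance(sym_dim, torch.SymInt))    # True
print(f"{sym_dim}@SymInt")                  # e.g. "s0@SymInt", matching the format above
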
32 changes: 20 additions & 12 deletions py/torch_tensorrt/dynamo/conversion/impl/matmul.py
@@ -48,17 +48,25 @@ def matrix_multiply(
input, other = broadcast(
ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
)
if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED):
promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(input.dtype).to(torch.dtype),
_enums.dtype._from(other.dtype).to(torch.dtype),
)
Collaborator review comment on the block below: Can we start noting down these specific type behavior things somewhere, like in the contributor docs?

if (
ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
and ctx.compilation_settings.use_fp32_acc
):
input = cast_trt_tensor(ctx, input, torch.float32, f"{name}_input_casted")
other = cast_trt_tensor(ctx, other, torch.float32, f"{name}_other_casted")

matmul_layer = ctx.net.add_matrix_multiply(
input, input_matrix_op, other, other_matrix_op
)
matmul_output = matmul_layer.get_output(0)

if (
ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
and ctx.compilation_settings.use_fp32_acc
):
matmul_output = cast_trt_tensor(
ctx, matmul_output, torch.float16, f"{name}_output_casted"
)
trt_promoted_type = promoted_type.to(trt.DataType)
input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")

layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
set_layer_name(matmul_layer, target, name, source_ir)
return matmul_output
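
Net effect of the converter change: under strong typing with ``use_fp32_acc=True``, both operands are cast up to FP32, the matmul runs (and therefore accumulates) in FP32, and the result is cast back down to FP16. A rough PyTorch-level equivalent of the emitted subgraph (illustrative, not the converter itself):

import torch

def matmul_fp32_acc(a_fp16: torch.Tensor, b_fp16: torch.Tensor) -> torch.Tensor:
    # Cast inputs up so the multiply-accumulate happens in FP32 ...
    out_fp32 = a_fp16.to(torch.float32) @ b_fp16.to(torch.float32)
    # ... then cast the result back, keeping FP16 at the graph boundary
    return out_fp32.to(torch.float16)
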
@@ -5,7 +5,6 @@
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.utils import is_tegra_platform

from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .complex_graph_rewrite import complex_graph_detection
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
@@ -24,7 +23,6 @@
fuse_prims_broadcast,
replace_max_pool_with_indices,
remove_assert_nodes,
accumulate_fp32_matmul,
remove_num_users_is_0_nodes,
complex_graph_detection,
]
115 changes: 0 additions & 115 deletions py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py

This file was deleted.
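With the pass removed here and unregistered from the pass list above, the FP32-accumulation rewrite appears to move out of the lowering stage and into the ``matrix_multiply`` converter itself, via the casts added in ``impl/matmul.py``.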

@@ -367,7 +367,8 @@ def compile(self) -> None:
enabled_precisions=self.enabled_precisions,
**self.additional_settings,
)
deallocate_module(self.original_model, delete_module=False)
if self.additional_settings.get("offload_module_to_cpu", False):
deallocate_module(self.original_model, delete_module=False)
if self.enable_weight_streaming:
self.set_weight_streaming_ctx(self.weight_streaming_budget)
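
With this guard, the original module is deallocated only when the user opted into CPU offloading. A hedged usage sketch (the ``offload_module_to_cpu`` setting name is taken from the diff; the model is illustrative):

import torch
import torch_tensorrt

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

inputs = [torch.randn(1, 10).cuda()]
ep = torch.export.export(TinyModel().cuda().eval(), tuple(inputs))

# Opting in frees the original module's GPU copy after compilation,
# matching the guarded deallocate_module call above.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=inputs,
    offload_module_to_cpu=True,
)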
