
Commit 39e24f5

GdoongMathew and Borda authored
Torch-TensorRT Integration with LightningModule (#20808)
* feat: add `to_tensorrt` in the `LightningModule`.
* refactor: fix `to_tensorrt` impl
* test: add test_torch_tensorrt.py
* add dependency in test requirement
* limit the torch-tensorrt condition again
* update tensorrt version
* update tensorrt source
* update test.txt
* ci: add extra-index
* ci: use find-links instead. works on my computer...
* fix: fix bug in torch-tensorrt 2.8.0
* add find links in ci test.
* chlog

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
1 parent 3ed9d4e commit 39e24f5

File tree

.github/workflows/ci-tests-pytorch.yml
requirements/pytorch/test.txt
src/lightning/pytorch/CHANGELOG.md
src/lightning/pytorch/core/module.py
src/lightning/pytorch/utilities/testing/_runif.py
test_torch_tensorrt.py (new)

6 files changed: +281 −5 lines changed
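
For orientation, a minimal usage sketch of the new API this commit introduces; this is illustrative only, assuming `torch-tensorrt` is installed and a CUDA device is available, and it borrows the `BoringModel` demo class that the new tests also use:

    import torch
    from lightning.pytorch.demos.boring_classes import BoringModel

    model = BoringModel()
    # `to_tensorrt` falls back to this attribute when no input_sample is passed
    model.example_input_array = torch.randn(4, 32)

    # compiles with the default IR and writes an ExportedProgram to disk
    trt_module = model.to_tensorrt(file_path="model.trt")

As the `file_path=None` default in the new method indicates, the returned module can also be used directly for inference without saving.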

.github/workflows/ci-tests-pytorch.yml

Lines changed: 1 addition & 1 deletion

@@ -139,7 +139,7 @@ jobs:
       pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \
         -U --upgrade-strategy=eager --prefer-binary \
         -r requirements/_integrations/accelerators.txt \
-        --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}"
+        --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" --find-links="https://download.pytorch.org/whl/torch-tensorrt"
       pip list
   - name: Drop LAI from extensions
     if: ${{ matrix.pkg-name != 'lightning' }}
requirements/pytorch/test.txt

Lines changed: 3 additions & 0 deletions

@@ -18,3 +18,6 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in
 uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App

 tensorboard >=2.9.1, <2.21.0  # for `TensorBoardLogger`
+
+--find-links https://download.pytorch.org/whl/torch-tensorrt
+torch-tensorrt; platform_system == "Linux" and python_version >= "3.12"
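
The trailing environment marker restricts the new dependency to Linux with Python 3.12+. As an illustrative aside (not part of this diff), such PEP 508 markers can be evaluated with the `packaging` library:

    from packaging.markers import Marker

    # evaluates the marker against the current interpreter and platform
    marker = Marker('platform_system == "Linux" and python_version >= "3.12"')
    print(marker.evaluate())  # True only on Linux with Python >= 3.12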

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

-
+- Added Torch-TensorRT integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))


 ### Changed

src/lightning/pytorch/core/module.py

Lines changed: 115 additions & 2 deletions

@@ -13,11 +13,12 @@
 # limitations under the License.
 """The LightningModule - an nn.Module with many additional features."""

+import copy
 import logging
 import numbers
 import weakref
 from collections.abc import Generator, Mapping, Sequence
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from io import BytesIO
 from pathlib import Path
 from typing import (
@@ -47,7 +48,7 @@
 from lightning.fabric.utilities.apply_func import convert_to_tensors
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_5
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
@@ -75,6 +76,7 @@

 _ONNX_AVAILABLE = RequirementCache("onnx")
 _ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")
+_TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt")

 if TYPE_CHECKING:
     from torch.distributed.device_mesh import DeviceMesh
@@ -1570,6 +1572,117 @@ def forward(self, x):

         return torchscript_module

+    @torch.no_grad()
+    def to_tensorrt(
+        self,
+        file_path: Optional[Union[str, Path, BytesIO]] = None,
+        input_sample: Optional[Any] = None,
+        ir: Literal["default", "dynamo", "ts"] = "default",
+        output_format: Literal["exported_program", "torchscript"] = "exported_program",
+        retrace: bool = False,
+        default_device: Union[str, torch.device] = "cuda",
+        **compile_kwargs: Any,
+    ) -> Union[ScriptModule, torch.fx.GraphModule]:
+        """Export the model to a ScriptModule or GraphModule using the TensorRT compile backend.
+
+        Args:
+            file_path: Path where to save the TensorRT model. Default: None (no file saved).
+            input_sample: inputs to be used during `torch_tensorrt.compile`.
+                Default: None (uses :attr:`example_input_array`).
+            ir: The IR mode to use for TensorRT compilation. Default: "default".
+            output_format: The format of the output model. Default: "exported_program".
+            retrace: Whether to retrace the model. Default: False.
+            default_device: The device to move the model to when it is not already on a CUDA device.
+                Default: "cuda".
+            **compile_kwargs: Additional arguments that will be passed to the TensorRT compile function.
+
+        Example::
+
+            class SimpleModel(LightningModule):
+                def __init__(self):
+                    super().__init__()
+                    self.l1 = torch.nn.Linear(in_features=64, out_features=4)
+
+                def forward(self, x):
+                    return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+            model = SimpleModel()
+            input_sample = torch.randn(1, 64)
+            exported_program = model.to_tensorrt(
+                file_path="export.ep",
+                input_sample=input_sample,
+            )
+
+        """
+        if not _TORCH_GREATER_EQUAL_2_2:
+            raise MisconfigurationException(
+                f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}."
+            )
+
+        if not _TORCH_TRT_AVAILABLE:
+            raise ModuleNotFoundError(
+                f"`{type(self).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. "
+            )
+
+        mode = self.training
+        device = self.device
+        if self.device.type != "cuda":
+            default_device = torch.device(default_device) if isinstance(default_device, str) else default_device
+
+            if not torch.cuda.is_available() or default_device.type != "cuda":
+                raise MisconfigurationException(
+                    f"TensorRT only supports CUDA devices. The current device is {self.device}."
+                    " Please set the `default_device` argument to a CUDA device."
+                )
+
+            self.to(default_device)
+
+        if input_sample is None:
+            if self.example_input_array is None:
+                raise ValueError(
+                    "Could not export to TensorRT since neither `input_sample` nor"
+                    " `model.example_input_array` attribute is set."
+                )
+            input_sample = self.example_input_array
+
+        import torch_tensorrt
+
+        input_sample = copy.deepcopy((input_sample,) if isinstance(input_sample, torch.Tensor) else input_sample)
+        input_sample = self._on_before_batch_transfer(input_sample)
+        input_sample = self._apply_batch_transfer_handler(input_sample)
+
+        with _jit_is_scripting() if ir == "ts" else nullcontext():
+            trt_obj = torch_tensorrt.compile(
+                module=self.eval(),
+                ir=ir,
+                inputs=input_sample,
+                **compile_kwargs,
+            )
+        self.train(mode)
+        self.to(device)
+
+        if file_path is not None:
+            if ir == "ts":
+                if output_format != "torchscript":
+                    raise ValueError(
+                        "TensorRT with IR mode 'ts' only supports output format 'torchscript'."
+                        f" The current output format is {output_format}."
+                    )
+                assert isinstance(trt_obj, (torch.jit.ScriptModule, torch.jit.ScriptFunction)), (
+                    f"Expected TensorRT object to be a ScriptModule, but got {type(trt_obj)}."
+                )
+                # Because of https://github.com/pytorch/TensorRT/issues/3775,
+                # the ScriptModule needs to be saved via `torch.jit.save` directly
+                torch.jit.save(trt_obj, file_path)
+            else:
+                torch_tensorrt.save(
+                    trt_obj,
+                    file_path,
+                    inputs=input_sample,
+                    output_format=output_format,
+                    retrace=retrace,
+                )
+        return trt_obj
+
     @_restricted_classmethod
     def load_from_checkpoint(
         cls,
src/lightning/pytorch/utilities/testing/_runif.py

Lines changed: 6 additions & 1 deletion

@@ -17,7 +17,7 @@

 from lightning.fabric.utilities.testing import _runif_reasons as _fabric_run_if
 from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
-from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE
+from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE, _TORCH_TRT_AVAILABLE
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _RICH_AVAILABLE

 _SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
@@ -43,6 +43,7 @@ def _runif_reasons(
     onnx: bool = False,
     linux_only: bool = False,
     onnxscript: bool = False,
+    tensorrt: bool = False,
 ) -> tuple[list[str], dict[str, bool]]:
     """Construct reasons for pytest skipif.
@@ -66,6 +67,7 @@ def _runif_reasons(
         sklearn: Require that scikit-learn is installed.
         onnx: Require that onnx is installed.
         onnxscript: Require that onnxscript is installed.
+        tensorrt: Require that torch-tensorrt is installed.

     """
@@ -102,4 +104,7 @@ def _runif_reasons(
     if onnxscript and not _ONNXSCRIPT_AVAILABLE:
         reasons.append("onnxscript")

+    if tensorrt and not _TORCH_TRT_AVAILABLE:
+        reasons.append("torch-tensorrt")
+
     return reasons, kwargs
test_torch_tensorrt.py (new file)

Lines changed: 155 additions & 0 deletions

@@ -0,0 +1,155 @@
+import os
+import re
+from io import BytesIO
+from pathlib import Path
+
+import pytest
+import torch
+
+import tests_pytorch.helpers.pipelines as pipes
+from lightning.pytorch.core.module import _TORCH_TRT_AVAILABLE
+from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from tests_pytorch.helpers.runif import RunIf
+
+
+@RunIf(max_torch="2.2.0")
+def test_torch_minimum_version():
+    model = BoringModel()
+    with pytest.raises(
+        MisconfigurationException,
+        match=re.escape(f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}."),
+    ):
+        model.to_tensorrt("model.trt")
+
+
+@pytest.mark.skipif(_TORCH_TRT_AVAILABLE, reason="Run this test only if tensorrt is not available.")
+@RunIf(min_torch="2.2.0")
+def test_missing_tensorrt_package():
+    model = BoringModel()
+    with pytest.raises(
+        ModuleNotFoundError,
+        match=re.escape(f"`{type(model).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. "),
+    ):
+        model.to_tensorrt("model.trt")
+
+
+@RunIf(tensorrt=True, min_torch="2.2.0")
+def test_tensorrt_with_wrong_default_device(tmp_path):
+    model = BoringModel()
+    input_sample = torch.randn((1, 32))
+    file_path = os.path.join(tmp_path, "model.trt")
+    with pytest.raises(MisconfigurationException):
+        model.to_tensorrt(file_path, input_sample, default_device="cpu")
+
+
+@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
+def test_tensorrt_saves_with_input_sample(tmp_path):
+    model = BoringModel()
+    ori_device = model.device
+    input_sample = torch.randn((1, 32))
+
+    file_path = os.path.join(tmp_path, "model.trt")
+    model.to_tensorrt(file_path, input_sample)
+
+    assert os.path.isfile(file_path)
+    assert os.path.getsize(file_path) > 4e2
+    assert model.device == ori_device
+
+    file_path = Path(tmp_path) / "model.trt"
+    model.to_tensorrt(file_path, input_sample)
+    assert os.path.isfile(file_path)
+    assert os.path.getsize(file_path) > 4e2
+    assert model.device == ori_device
+
+    file_path = BytesIO()
+    model.to_tensorrt(file_path, input_sample)
+    assert len(file_path.getvalue()) > 4e2
+
+
+@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
+def test_tensorrt_error_if_no_input(tmp_path):
+    model = BoringModel()
+    model.example_input_array = None
+    file_path = os.path.join(tmp_path, "model.trt")
+
+    with pytest.raises(
+        ValueError,
+        match=r"Could not export to TensorRT since neither `input_sample` nor "
+        r"`model.example_input_array` attribute is set.",
+    ):
+        model.to_tensorrt(file_path)
+
+
+@RunIf(tensorrt=True, min_cuda_gpus=2, min_torch="2.2.0")
+def test_tensorrt_saves_on_multi_gpu(tmp_path):
+    trainer_options = {
+        "default_root_dir": tmp_path,
+        "max_epochs": 1,
+        "limit_train_batches": 10,
+        "limit_val_batches": 10,
+        "accelerator": "gpu",
+        "devices": [0, 1],
+        "strategy": "ddp_spawn",
+        "enable_progress_bar": False,
+    }
+
+    model = BoringModel()
+    model.example_input_array = torch.randn((4, 32))
+
+    pipes.run_model_test(trainer_options, model, min_acc=0.08)
+
+    file_path = os.path.join(tmp_path, "model.trt")
+    model.to_tensorrt(file_path)
+
+    assert os.path.exists(file_path)
+
+
+@pytest.mark.parametrize(
+    ("ir", "export_type"),
+    [
+        ("default", torch.fx.GraphModule),
+        ("dynamo", torch.fx.GraphModule),
+        ("ts", torch.jit.ScriptModule),
+    ],
+)
+@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
+def test_tensorrt_save_ir_type(ir, export_type):
+    model = BoringModel()
+    model.example_input_array = torch.randn((4, 32))
+
+    ret = model.to_tensorrt(ir=ir)
+    assert isinstance(ret, export_type)
+
+
+@pytest.mark.parametrize(
+    "output_format",
+    ["exported_program", "torchscript"],
+)
+@pytest.mark.parametrize(
+    "ir",
+    ["default", "dynamo", "ts"],
+)
+@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
+def test_tensorrt_export_reload(output_format, ir, tmp_path):
+    if ir == "ts" and output_format == "exported_program":
+        pytest.skip("TorchScript cannot be exported as exported_program")
+
+    import torch_tensorrt
+
+    model = BoringModel()
+    model.cuda().eval()
+    model.example_input_array = torch.ones((4, 32))
+
+    file_path = os.path.join(tmp_path, "model.trt")
+    model.to_tensorrt(file_path, output_format=output_format, ir=ir)
+
+    loaded_model = torch_tensorrt.load(file_path)
+    if output_format == "exported_program":
+        loaded_model = loaded_model.module()
+
+    with torch.no_grad(), torch.inference_mode():
+        model_output = model(model.example_input_array.to("cuda"))
+        jit_output = loaded_model(model.example_input_array.to("cuda"))
+
+    assert torch.allclose(model_output, jit_output, rtol=1e-03, atol=1e-06)