Skip to content

Commit cf9e7bd

Browse files
committed
fix identity issue
1 parent d1d18b9 commit cf9e7bd

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def check_weight_equal(
440440
except Exception:
441441
return torch.all(sd_weight == network_weight)
442442

443-
@needs_refit
443+
@needs_refit # type: ignore[misc]
444444
def _save_weight_mapping(self) -> None:
445445
"""
446446
Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -577,7 +577,7 @@ def _save_weight_mapping(self) -> None:
577577
gc.collect()
578578
torch.cuda.empty_cache()
579579

580-
@needs_refit
580+
@needs_refit # type: ignore[misc]
581581
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
582582
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
583583
# if not self.compilation_settings.strip_engine_weights:
@@ -605,7 +605,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
605605
),
606606
)
607607

608-
@needs_refit
608+
@needs_refit # type: ignore[misc]
609609
def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
610610
# query the cached TRT engine
611611
cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr]
@@ -941,7 +941,14 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
941941
f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
942942
)
943943

944+
marked_outputs_ids = []
944945
for i, output in enumerate(outputs):
946+
# In some cases, the same output tensor may be marked multiple times, such as _to_oppy,
947+
# so we skip marking if the output is already marked
948+
if id(output) in marked_outputs_ids:
949+
continue
950+
marked_outputs_ids.append(id(output))
951+
945952
name = f"output{i}"
946953

947954
output_dtype = dtype.unknown

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,7 +1123,7 @@ def aten_ops_clone_copy_placeholder(
11231123
name,
11241124
args[0],
11251125
kwargs.get("dtype", args[0].dtype),
1126-
force_layer=True,
1126+
force_layer=False,
11271127
)
11281128

11291129

@@ -1226,7 +1226,7 @@ def aten_ops_sum(
12261226
name,
12271227
sum_,
12281228
kwargs["output_dtype"],
1229-
force_layer=True,
1229+
force_layer=False,
12301230
)
12311231
else:
12321232
return sum_
@@ -3229,7 +3229,7 @@ def aten_ops_copy(
32293229
name,
32303230
src,
32313231
src.dtype,
3232-
force_layer=True,
3232+
force_layer=False,
32333233
)
32343234

32353235

0 commit comments

Comments
 (0)