@@ -440,7 +440,7 @@ def check_weight_equal(
440
440
except Exception :
441
441
return torch .all (sd_weight == network_weight )
442
442
443
- @needs_refit
443
+ @needs_refit # type: ignore[misc]
444
444
def _save_weight_mapping (self ) -> None :
445
445
"""
446
446
Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -577,7 +577,7 @@ def _save_weight_mapping(self) -> None:
577
577
gc .collect ()
578
578
torch .cuda .empty_cache ()
579
579
580
- @needs_refit
580
+ @needs_refit # type: ignore[misc]
581
581
def _insert_engine_to_cache (self , hash_val : str , serialized_engine : bytes ) -> None :
582
582
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
583
583
# if not self.compilation_settings.strip_engine_weights:
@@ -605,7 +605,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
605
605
),
606
606
)
607
607
608
- @needs_refit
608
+ @needs_refit # type: ignore[misc]
609
609
def _pull_cached_engine (self , hash_val : str ) -> Optional [TRTInterpreterResult ]:
610
610
# query the cached TRT engine
611
611
cached_data = self .engine_cache .check (hash_val ) # type: ignore[union-attr]
@@ -941,7 +941,14 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
941
941
f"Specified output dtypes ({ len (self .output_dtypes )} ) differ from number of outputs ({ len (outputs )} )"
942
942
)
943
943
944
+ marked_outputs_ids = []
944
945
for i , output in enumerate (outputs ):
946
+ # In some cases, the same output tensor may be marked multiple times, such as _to_copy,
947
+ # so we skip marking if the output is already marked
948
+ if id (output ) in marked_outputs_ids :
949
+ continue
950
+ marked_outputs_ids .append (id (output ))
951
+
945
952
name = f"output{ i } "
946
953
947
954
output_dtype = dtype .unknown
0 commit comments