@@ -401,10 +401,10 @@ def attribute(
401
401
if attr_progress is not None :
402
402
attr_progress .close ()
403
403
404
- # pyre-fixme[7]: Expected `Variable[TensorOrTupleOfTensorsGeneric <:
405
- # [Tensor, typing.Tuple[Tensor, ...]]]`
406
- # but got `Union[Tensor, typing.Tuple[Tensor, ...]]`.
407
- return self . _generate_result ( total_attrib , weights , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
404
+ return cast (
405
+ TensorOrTupleOfTensorsGeneric ,
406
+ self . _generate_result ( total_attrib , weights , is_inputs_tuple ),
407
+ )
408
408
409
409
def _attribute_with_independent_feature_masks (
410
410
self ,
@@ -629,8 +629,7 @@ def _should_skip_inputs_and_warn(
629
629
all_empty = False
630
630
if self ._min_examples_per_batch_grouped is not None and (
631
631
formatted_inputs [tensor_idx ].shape [0 ]
632
- # pyre-ignore[58]: Type has been narrowed to int
633
- < self ._min_examples_per_batch_grouped
632
+ < cast (int , self ._min_examples_per_batch_grouped )
634
633
):
635
634
should_skip = True
636
635
break
@@ -789,35 +788,35 @@ def attribute_future(
789
788
)
790
789
791
790
if enable_cross_tensor_attribution :
792
- # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
793
- # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
794
- # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
795
- return self . _attribute_with_cross_tensor_feature_masks_future ( # type: ignore # noqa: E501 line too long
796
- formatted_inputs = formatted_inputs ,
797
- formatted_additional_forward_args = formatted_additional_forward_args ,
798
- target = target ,
799
- baselines = baselines ,
800
- formatted_feature_mask = formatted_feature_mask ,
801
- attr_progress = attr_progress ,
802
- processed_initial_eval_fut = processed_initial_eval_fut ,
803
- is_inputs_tuple = is_inputs_tuple ,
804
- perturbations_per_eval = perturbations_per_eval ,
791
+ return cast (
792
+ Future [ TensorOrTupleOfTensorsGeneric ],
793
+ self . _attribute_with_cross_tensor_feature_masks_future ( # type: ignore # noqa: E501 line too long
794
+ formatted_inputs = formatted_inputs ,
795
+ formatted_additional_forward_args = formatted_additional_forward_args ,
796
+ target = target ,
797
+ baselines = baselines ,
798
+ formatted_feature_mask = formatted_feature_mask ,
799
+ attr_progress = attr_progress ,
800
+ processed_initial_eval_fut = processed_initial_eval_fut ,
801
+ is_inputs_tuple = is_inputs_tuple ,
802
+ perturbations_per_eval = perturbations_per_eval ,
803
+ ) ,
805
804
)
806
805
else :
807
- # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
808
- # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
809
- # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
810
- return self . _attribute_with_independent_feature_masks_future ( # type: ignore # noqa: E501 line too long
811
- formatted_inputs ,
812
- formatted_additional_forward_args ,
813
- target ,
814
- baselines ,
815
- formatted_feature_mask ,
816
- perturbations_per_eval ,
817
- attr_progress ,
818
- processed_initial_eval_fut ,
819
- is_inputs_tuple ,
820
- ** kwargs ,
806
+ return cast (
807
+ Future [ TensorOrTupleOfTensorsGeneric ],
808
+ self . _attribute_with_independent_feature_masks_future ( # type: ignore # noqa: E501 line too long
809
+ formatted_inputs ,
810
+ formatted_additional_forward_args ,
811
+ target ,
812
+ baselines ,
813
+ formatted_feature_mask ,
814
+ perturbations_per_eval ,
815
+ attr_progress ,
816
+ processed_initial_eval_fut ,
817
+ is_inputs_tuple ,
818
+ ** kwargs ,
819
+ ) ,
821
820
)
822
821
823
822
def _attribute_with_independent_feature_masks_future (
0 commit comments