diff --git a/tests/transformers_tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/transformers_tests/models/seamless_m4t/test_modeling_seamless_m4t.py index b31ee36d0a..cdf915ee1d 100644 --- a/tests/transformers_tests/models/seamless_m4t/test_modeling_seamless_m4t.py +++ b/tests/transformers_tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -230,6 +230,7 @@ def get_config(self): inputs_dict_speech["labels"], ), {}, + [1, 2, 3], { "logits": 0, "encoder_last_hidden_state": 2, @@ -248,6 +249,7 @@ def get_config(self): inputs_dict_speech["labels"], ), {}, + [1, 2, 3], { "logits": 0, "encoder_last_hidden_state": 2, @@ -266,6 +268,7 @@ def get_config(self): inputs_dict_text["labels"], ), {}, + [0, 1, 2, 3], { "logits": 0, "encoder_last_hidden_state": 2, @@ -284,6 +287,7 @@ def get_config(self): inputs_dict_text["labels"], ), {}, + [0, 1, 2, 3], { "logits": 0, "encoder_last_hidden_state": 2, @@ -303,6 +307,7 @@ def get_config(self): inputs_dict_text["labels"], ), {}, + [0, 2, 3, 4], { "logits": 0, "encoder_last_hidden_state": 2, @@ -312,7 +317,7 @@ def get_config(self): @pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,inputs_type_idx,outputs_map,dtype,mode", [ case + [ @@ -334,6 +339,7 @@ def test_named_modules( init_kwargs, inputs_args, inputs_kwargs, + inputs_type_idx, outputs_map, dtype, mode, @@ -357,6 +363,16 @@ def test_named_modules( ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) ms_inputs_kwargs["return_dict"] = False + pt_inputs_args = tuple( + tensor.long() if i in inputs_type_idx else tensor.to(PT_DTYPE_MAPPING[pt_dtype]) + for i, tensor in enumerate(pt_inputs_args) + ) + + ms_inputs_args = tuple( + tensor.to(ms.int64) if i in inputs_type_idx else tensor.to(MS_DTYPE_MAPPING[ms_dtype]) + for i, tensor in enumerate(ms_inputs_args) + ) + with torch.no_grad(): pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs)