diff --git a/mindone/transformers/models/helium/modeling_helium.py b/mindone/transformers/models/helium/modeling_helium.py
index 9841d6fee2..b5200fe302 100644
--- a/mindone/transformers/models/helium/modeling_helium.py
+++ b/mindone/transformers/models/helium/modeling_helium.py
@@ -701,7 +701,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             causal_mask = attention_mask
         else:
             min_dtype = _DTYPE_2_MIN[dtype]
-            causal_mask = mint.full(
+            # FIXME: mint.full raises "TypeError: Can not convert Tensor(shape=[], dtype=BFloat16, value=-3.38953e+38) to number" in mindspore 2.6.0 and 2.7.0
+            # Use ms.ops.full instead
+            causal_mask = ms.ops.full(
                 (sequence_length, target_length),
                 fill_value=min_dtype,
                 dtype=dtype,
diff --git a/tests/transformers_tests/models/helium/test_modeling_helium.py b/tests/transformers_tests/models/helium/test_modeling_helium.py
index a14e78a0be..84bb263465 100644
--- a/tests/transformers_tests/models/helium/test_modeling_helium.py
+++ b/tests/transformers_tests/models/helium/test_modeling_helium.py
@@ -18,7 +18,8 @@ from tests.transformers_tests.models.modeling_common import ids_numpy
 
 
 DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-3}
-MODES = [0, 1]
+# TODO: currently only supports PyNative mode. Add graph mode support later.
+MODES = [1]
 
 
 class HeliumModelTester:
@@ -33,7 +34,7 @@ def __init__(
         vocab_size=99,
         hidden_size=32,
         num_hidden_layers=2,
-        num_attention_heads=4,
+        num_attention_heads=2,
         num_key_value_heads=2,
         intermediate_size=37,
         hidden_act="silu",
@@ -94,6 +95,7 @@ def get_config(self):
         return self.config_class(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
+            head_dim=self.head_dim,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads,
             num_key_value_heads=self.num_key_value_heads,