
Conversation

Collaborator

@wtomin wtomin commented Aug 18, 2025

Fix the Helium UT tests under transformers==4.50.0 and mindone.

  • Before the fix:
  1. UT tests failed in graph mode.
  2. UT tests failed in pynative mode with a shape error. Investigation showed that in the test configuration num_attention_heads and num_key_value_heads must be equal, so num_attention_heads was changed from 4 to 2.
  3. The BF16 UT tests failed with the following error:
test_modeling_helium.py:187:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/home/ddd/.conda/envs/ddd_ms_2.6/lib/python3.10/site-packages/mindspore/nn/cell.py:1270: in __call__
    return self.construct(*args, **kwargs)
../../../../mindone/transformers/models/helium/modeling_helium.py:547: in construct
    causal_mask = self._update_causal_mask(
../../../../mindone/transformers/models/helium/modeling_helium.py:651: in _update_causal_mask
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
../../../../mindone/transformers/models/helium/modeling_helium.py:704: in _prepare_4d_causal_attention_mask_with_cache_position
    causal_mask = mint.full(
/home/ddd/.conda/envs/ddd_ms_2.6/lib/python3.10/site-packages/mindspore/ops/function/array_func.py:900: in full_ext
    return fill_scalar_(size, fill_value, dtype)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = Prim[FillScalar], size = (7, 7), fill_value = Tensor(shape=[], dtype=BFloat16, value= -3.38953e+38), dtype = mindspore.bfloat

    def __call__(self, size, fill_value, dtype=None):
        # Add for jit context.
        if jit_context() and jit_context().compiled:
            return None
>       res = pyboost_fill_scalar(self, [size, fill_value, dtype if dtype is None else dtype_to_type_id('FillScalar', 'dtype', dtype
E       TypeError: Can not convert Tensor(shape=[], dtype=BFloat16, value=-3.38953e+38) to number
E
E       ----------------------------------------------------
E       - C++ Call Stack: (For framework developers)
E       ----------------------------------------------------
E       mindspore/ccsrc/pipeline/jit/ps/parse/data_converter.cc:1816 ConvertTensorToNumber

/home/ddd/.conda/envs/ddd_ms_2.6/lib/python3.10/site-packages/mindspore/ops/auto_generate/gen_ops_prim.py:13279: TypeError

mint.full is not robust here: it lowers to FillScalar, which rejects a 0-d Tensor as fill_value. After changing mint.full to mindspore.ops.full, the error disappeared. A minimal sketch reproducing the difference follows.
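The sketch below is illustrative rather than the exact model code; it assumes a 0-d BF16 Tensor fill value like the _DTYPE_2_MIN[dtype] lookup in modeling_helium.py, and that ms.ops.full accepts it (as observed in this PR).

```python
import mindspore as ms
from mindspore import mint

# 0-d Tensor holding the BF16 minimum, as in the traceback above.
min_dtype = ms.Tensor(-3.38953e+38, dtype=ms.bfloat16)

try:
    # mint.full lowers to FillScalar, which cannot convert a Tensor
    # fill_value to a number.
    mint.full((7, 7), min_dtype, dtype=ms.bfloat16)
except TypeError as e:
    print(e)  # "Can not convert Tensor(...) to number"

# ms.ops.full handles the 0-d Tensor fill value, hence the fix:
causal_mask = ms.ops.full((7, 7), min_dtype, dtype=ms.bfloat16)
print(causal_mask.shape, causal_mask.dtype)
```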

  • After the fix:
  1. All UT tests pass; only pynative mode is supported for now (a note on forcing pynative mode follows the generation log below):
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================ 3 passed, 58 warnings in 29.73s ===================================
  2. Ran the generation example. It behaves the same as before, as expected.
    Given the text input "What is your favorite condiment?":
==> sampling, step: 0, time cost: 0.80285s
==> sampling, step: 1, time cost: 1.35132s
==> sampling, step: 2, time cost: 0.14865s, running avg speed: 6.72740token/s
==> sampling, step: 3, time cost: 0.14874s, running avg speed: 6.72536token/s
==> sampling, step: 4, time cost: 0.14833s, running avg speed: 6.73075token/s
==> sampling, step: 5, time cost: 0.14788s, running avg speed: 6.73856token/s
==> sampling, step: 6, time cost: 0.14760s, running avg speed: 6.74585token/s
==> sampling, step: 7, time cost: 0.14859s, running avg speed: 6.74323token/s
==> sampling, step: 8, time cost: 0.15055s, running avg speed: 6.72864token/s
==> sampling, step: 9, time cost: 0.15290s, running avg speed: 6.70449token/s
==> sampling, step: 10, time cost: 0.15004s, running avg speed: 6.70006token/s
==> sampling, step: 11, time cost: 0.14744s, running avg speed: 6.70821token/s
==> sampling, step: 12, time cost: 0.14789s, running avg speed: 6.71303token/s
==> sampling, step: 13, time cost: 0.14613s, running avg speed: 6.72369token/s
==> sampling, step: 14, time cost: 0.14628s, running avg speed: 6.73221token/s
==> sampling, step: 15, time cost: 0.15444s, running avg speed: 6.71316token/s
==> sampling, step: 16, time cost: 0.15005s, running avg speed: 6.70989token/s
==> sampling, step: 17, time cost: 0.14805s, running avg speed: 6.71265token/s
==> sampling, step: 18, time cost: 0.14695s, running avg speed: 6.71801token/s
==> sampling, step: 19, time cost: 0.14841s, running avg speed: 6.71913token/s
==> sampling, step: 20, time cost: 0.14672s, running avg speed: 6.72415token/s
Inference time: 13.401s
I'm not sure if this is the right place to ask this question, but I'
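As noted above, the tests currently run in pynative mode only. For reference, a minimal sketch of forcing that mode with the standard MindSpore context API:

```python
import mindspore as ms

# Graph mode still fails for these tests, so force pynative mode.
ms.set_context(mode=ms.PYNATIVE_MODE)
```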

@wtomin wtomin requested a review from vigo999 as a code owner August 18, 2025 06:28
Contributor

@gemini-code-assist gemini-code-assist bot left a comment
Summary of Changes

Hello @wtomin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses and resolves issues with Helium unit tests that were failing under the transformers==4.50.0 and mindone environments. The changes ensure that all unit tests now pass, specifically by enforcing pynative mode for testing and making necessary adjustments to the Helium model testing configuration. Additionally, it confirms that existing generation examples continue to function correctly after these fixes.

Highlights

  • Test Compatibility Fix: The core issue addressed is the failure of Helium unit tests when run with transformers==4.50.0 and mindone. This PR ensures these tests now pass.
  • Pynative Mode Enforcement: The test suite has been configured to run exclusively in pynative mode for Helium tests, with a future goal to re-enable graph mode support.
  • Test Configuration Updates: Adjustments were made to the HeliumModelTester configuration, specifically reducing the number of attention heads and adding a head_dim parameter, likely to resolve test failures or optimize test execution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request aims to fix unit tests for the Helium model, likely to ensure compatibility with an updated version of the transformers library. The changes involve an API update in the model implementation and several adjustments to the test configuration. While most changes seem appropriate for fixing the tests, one modification reduces test coverage by disabling the test for Grouped-Query Attention (GQA). My review highlights this issue to ensure that this feature is not left untested if it's still supported.

```diff
  hidden_size=32,
  num_hidden_layers=2,
- num_attention_heads=4,
+ num_attention_heads=2,
```
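For reference, a quick sketch of the GQA bookkeeping this diff changes (num_key_value_heads=2 is inferred from the review comment below):

```python
# Standard grouped-query attention relation: each group of query heads
# shares one key/value head.
num_attention_heads = 2   # was 4
num_key_value_heads = 2   # unchanged, per the review comment below
num_key_value_groups = num_attention_heads // num_key_value_heads  # now 1 (MHA); previously 2 (GQA)
```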
Contributor

Severity: medium

By setting num_attention_heads to be equal to num_key_value_heads, this test case now only covers Multi-Head Attention (MHA), where num_key_value_groups is 1. The previous configuration tested Grouped-Query Attention (GQA) with num_key_value_groups = 2.

While this change may be necessary to fix the immediate test failure, it results in the GQA implementation (specifically the n_rep > 1 path in the repeat_kv function) no longer being covered by this unit test. If GQA is a feature that should still be supported, it would be beneficial to add a separate test case or configuration to ensure its functionality is verified.
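For context, a sketch of the repeat_kv helper in the standard transformers pattern (assuming mindone's Helium port mirrors upstream; illustrative, not the exact source):

```python
import mindspore as ms

def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor:
    """Expand KV heads from (batch, num_kv_heads, seq, head_dim)
    to (batch, num_kv_heads * n_rep, seq, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        # MHA: with num_attention_heads == num_key_value_heads, the unit test
        # now always takes this no-op branch.
        return hidden_states
    # GQA (n_rep > 1): this path is no longer covered by the unit test.
    expanded = ms.ops.broadcast_to(
        hidden_states[:, :, None, :, :],
        (batch, num_key_value_heads, n_rep, slen, head_dim),
    )
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```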

@wtomin wtomin changed the title [Bug]: Fix Helium UT tests [Bug Fix]: Fix Helium UT errors Aug 21, 2025
@wtomin wtomin added the bug Something isn't working label Sep 1, 2025
@wtomin wtomin changed the title [Bug Fix]: Fix Helium UT errors fix(transformers): Fix Helium UT errors Sep 11, 2025
```diff
  else:
      min_dtype = _DTYPE_2_MIN[dtype]
-     causal_mask = mint.full(
+     causal_mask = ms.ops.full(
```
Collaborator

Add a comment to explain why ops is used instead of mint.

Collaborator Author

Fixed.

Contributor

vigo999 commented Sep 17, 2025

Same error with MindSpore 2.7?

Collaborator Author

wtomin commented Sep 17, 2025

> Same error with MindSpore 2.7?

Just tested: using mint.full instead of ms.ops.full raises the same error in both MindSpore 2.6.0 and 2.7.0, as shown above.

@vigo999 vigo999 added this to mindone Sep 22, 2025
@vigo999 vigo999 moved this to In Progress in mindone Sep 22, 2025
@vigo999 vigo999 added this pull request to the merge queue Sep 29, 2025
Merged via the queue into mindspore-lab:master with commit e07557b Sep 29, 2025
3 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in mindone Sep 29, 2025