
Conversation

wcrzlh
Contributor

@wcrzlh wcrzlh commented Aug 15, 2025

What does this PR do?

Fixes # (issue)
This PR fixes a weight-loading bug in the Ernie model. The LayerNorm parameters in this model are named differently (gamma --> weight, beta --> bias) depending on the torch version, so a key-renaming mapping function is implemented to handle this conversion.
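
For reference, a minimal sketch of the per-key rename described above (the helper name mirrors the PR's `_fix_state_dict_key_on_load`, but this standalone version is illustrative only, not the exact implementation):

```python
def fix_state_dict_key_on_load(key: str) -> tuple[str, bool]:
    # Checkpoints saved with older torch versions store LayerNorm parameters as
    # gamma/beta, while newer versions expect weight/bias, so normalize on load.
    if key.endswith("LayerNorm.gamma"):
        return key[: -len("gamma")] + "weight", True
    if key.endswith("LayerNorm.beta"):
        return key[: -len("beta")] + "bias", True
    return key, False

# e.g. fix_state_dict_key_on_load("encoder.layers.0.LayerNorm.gamma")
#      -> ("encoder.layers.0.LayerNorm.weight", True)
```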

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the
    documentation guidelines
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type/MS version) and performance in the doc? (better record it for data loading, model inference, or training tasks)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@wcrzlh wcrzlh requested a review from vigo999 as a code owner August 15, 2025 08:33

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @wcrzlh, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a weight loading bug in the Ernie model by introducing a robust key renaming mechanism within the mindone.transformers library. Specifically, it resolves issues where LayerNorm parameters (gamma and beta) are named differently across various PyTorch versions, ensuring models can be loaded correctly regardless of these naming conventions. The changes involve adding new utility functions for key mapping and integrating them into the model loading process, enhancing the overall flexibility and compatibility of model loading.

Highlights

  • Enhanced Key Renaming for Model Loading: Implemented a new _get_key_renaming_mapping function to dynamically map checkpoint keys to the model's expected parameter names, improving compatibility across different model versions and frameworks.
  • LayerNorm Parameter Name Normalization: Introduced _fix_state_dict_key_on_load to specifically handle the renaming of LayerNorm.gamma to LayerNorm.weight and LayerNorm.beta to LayerNorm.bias, directly addressing a known loading issue for models like Ernie.
  • Improved Missing/Unexpected Key Detection: Added _find_missing_and_unexpected_keys to provide more accurate and robust reporting of discrepancies between loaded state dictionaries and the model's expected parameters (a rough sketch follows this list).
  • Refactored Pretrained Model Loading: Updated the _load_pretrained_model method to seamlessly integrate the new key renaming and key detection logic, making the model loading process more resilient to variations in parameter naming conventions.
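
A rough, self-contained sketch of the missing/unexpected key detection mentioned above (the function name and signature here are assumptions for illustration; the actual `_find_missing_and_unexpected_keys` in mindone.transformers may differ):

```python
import re

def find_missing_and_unexpected_keys(expected_keys, checkpoint_keys,
                                     ignore_missing=(), ignore_unexpected=()):
    """Compare the model's expected parameter names against the checkpoint keys."""
    expected, loaded = set(expected_keys), set(checkpoint_keys)
    missing = sorted(expected - loaded)      # required by the model, absent from the checkpoint
    unexpected = sorted(loaded - expected)   # present in the checkpoint, unused by the model
    # Drop anything matching the model class's ignore patterns.
    missing = [k for k in missing if all(re.search(p, k) is None for p in ignore_missing)]
    unexpected = [k for k in unexpected if all(re.search(p, k) is None for p in ignore_unexpected)]
    return missing, unexpected
```
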
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a key renaming mechanism to handle differences in LayerNorm parameter names across PyTorch versions, which is a valuable addition. The refactoring of key handling logic into separate functions improves modularity. However, I've identified a critical bug in the new _get_key_renaming_mapping function where the loop variable is modified, leading to incorrect behavior. I've also found some opportunities for efficiency improvements and removal of redundant code. Please review my comments for details.

Comment on lines +2517 to +2553
for key in checkpoint_keys:
    # Class specific rename
    new_key, has_changed = self._fix_state_dict_key_on_load(key)

    # Optionally map the key according to `key_mapping`
    if key_mapping is not None:
        for pattern, replacement in key_mapping.items():
            new_key, n_replace = re.subn(pattern, replacement, new_key)
            # Early exit of the loop
            if n_replace > 0:
                has_changed = True
                break

    # In this case, we need to add the prefix to the keys, to match them to the expected keys
    if loading_task_model_from_base_state_dict:
        new_key = ".".join([prefix, new_key])
        key = ".".join([prefix, key])
    # In this case we need to remove the prefix from the key to match them to the expected keys, and use
    # only the keys starting with the prefix
    elif loading_base_model_from_task_state_dict:
        if not new_key.startswith(_prefix):
            continue
        new_key = new_key[len(_prefix) :]
        key = key[len(_prefix) :]

    if not has_changed:
        key_renaming_mapping[new_key] = new_key
    else:
        key_renaming_mapping[key] = new_key

    # track gamma/beta rename for logging
    if has_changed:
        if key.endswith("LayerNorm.gamma"):
            renamed_keys["LayerNorm.gamma"] = (key, new_key)
        elif key.endswith("LayerNorm.beta"):
            renamed_keys["LayerNorm.beta"] = (key, new_key)


critical

The logic in this loop is hard to follow and contains a bug. The loop variable key is modified during prefix handling, which causes the key_renaming_mapping to be populated with incorrect keys. This will lead to failures when the mapping is used to rename keys in the state dictionary. The logic for populating key_renaming_mapping is also unnecessarily complex and can be simplified.

        for key_from_checkpoint in checkpoint_keys:
            # Class specific rename
            new_key, has_changed = self._fix_state_dict_key_on_load(key_from_checkpoint)

            # Optionally map the key according to `key_mapping`
            if key_mapping is not None:
                for pattern, replacement in key_mapping.items():
                    new_key, n_replace = re.subn(pattern, replacement, new_key)
                    # Early exit of the loop
                    if n_replace > 0:
                        has_changed = True
                        break

            key_in_model = new_key
            # In this case, we need to add the prefix to the keys, to match them to the expected keys
            if loading_task_model_from_base_state_dict:
                key_in_model = ".".join([prefix, key_in_model])
            # In this case we need to remove the prefix from the key to match them to the expected keys, and use
            # only the keys starting with the prefix
            elif loading_base_model_from_task_state_dict:
                if not key_in_model.startswith(_prefix):
                    continue
                key_in_model = key_in_model[len(_prefix) :]

            key_renaming_mapping[key_from_checkpoint] = key_in_model

            # track gamma/beta rename for logging
            if has_changed:
                if key_from_checkpoint.endswith("LayerNorm.gamma"):
                    renamed_keys["LayerNorm.gamma"] = (key_from_checkpoint, new_key)
                elif key_from_checkpoint.endswith("LayerNorm.beta"):
                    renamed_keys["LayerNorm.beta"] = (key_from_checkpoint, new_key)

Comment on lines +2643 to +2648
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
    start_prefix = cls.base_model_prefix + "."


high

This block of code appears to be a remnant of the previous implementation and is now redundant. The logic for handling model prefixes is encapsulated within the new _get_key_renaming_mapping function. Furthermore, line 2645 model_to_load = model introduces a bug by resetting model_to_load, which might have been correctly set to a submodule earlier in the function. Since start_prefix is also no longer used, this entire block can be safely removed.

Comment on lines +412 to +413
for pattern in cls._keys_to_ignore_on_load_missing:
    missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]


medium

The current implementation for filtering missing_keys re-creates the list in each iteration of the loop, which is inefficient for a large number of patterns. This can be optimized by using a single list comprehension.

Suggested change
-for pattern in cls._keys_to_ignore_on_load_missing:
-    missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
+missing_keys = [
+    k for k in missing_keys if all(re.search(p, k) is None for p in cls._keys_to_ignore_on_load_missing)
+]

Comment on lines +416 to +417
for pattern in cls._keys_to_ignore_on_load_unexpected:
    unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]


medium

Similar to the filtering of missing_keys, the filtering of unexpected_keys can be made more efficient by avoiding list re-creation in a loop.

Suggested change
-for pattern in cls._keys_to_ignore_on_load_unexpected:
-    unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
+unexpected_keys = [
+    k for k in unexpected_keys if all(re.search(p, k) is None for p in cls._keys_to_ignore_on_load_unexpected)
+]

@wcrzlh wcrzlh force-pushed the refactor_load_pretrained_model branch from ef41aa8 to 3cb1756 on August 18, 2025 01:52
@zhanghuiyao zhanghuiyao added this pull request to the merge queue Sep 4, 2025
Merged via the queue into mindspore-lab:master with commit 0c2f638 Sep 4, 2025
3 checks passed