fix(transformers): supplement key_renaming_mapping func #1216

Conversation
Summary of Changes
Hello @wcrzlh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a weight loading bug in the Ernie model by introducing a robust key renaming mechanism within the mindone.transformers library. Specifically, it resolves issues where LayerNorm parameters (gamma and beta) are named differently across PyTorch versions, ensuring models can be loaded correctly regardless of these naming conventions. The changes add new utility functions for key mapping and integrate them into the model loading process, improving the flexibility and compatibility of model loading.
Highlights
- Enhanced Key Renaming for Model Loading: Implemented a new _get_key_renaming_mapping function to dynamically map checkpoint keys to the model's expected parameter names, improving compatibility across different model versions and frameworks.
- LayerNorm Parameter Name Normalization: Introduced _fix_state_dict_key_on_load to handle the renaming of LayerNorm.gamma to LayerNorm.weight and LayerNorm.beta to LayerNorm.bias, directly addressing a known loading issue for models like Ernie (a minimal sketch of this rename follows this list).
- Improved Missing/Unexpected Key Detection: Added _find_missing_and_unexpected_keys to provide more accurate and robust reporting of discrepancies between loaded state dictionaries and the model's expected parameters.
- Refactored Pretrained Model Loading: Updated the _load_pretrained_model method to integrate the new key renaming and key detection logic, making the model loading process more resilient to variations in parameter naming conventions.
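For reviewers who want a concrete picture of the gamma/beta normalization described above, here is a minimal, hypothetical sketch of what a _fix_state_dict_key_on_load-style rename could look like. The function name matches the PR, but the body and signature below are assumptions for illustration only, not the actual implementation in mindone.transformers.

def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]:
    # Hypothetical sketch: normalize old-style LayerNorm parameter names.
    # Returns the (possibly renamed) key plus a flag indicating whether a
    # rename happened, mirroring the (new_key, has_changed) pair consumed
    # by _get_key_renaming_mapping.
    if key.endswith("LayerNorm.gamma"):
        return key[: -len("gamma")] + "weight", True
    if key.endswith("LayerNorm.beta"):
        return key[: -len("beta")] + "bias", True
    return key, False

# Example: "encoder.layer.0.LayerNorm.gamma" -> ("encoder.layer.0.LayerNorm.weight", True)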
Code Review
This pull request introduces a key renaming mechanism to handle differences in LayerNorm parameter names across PyTorch versions, which is a valuable addition. The refactoring of key handling logic into separate functions improves modularity. However, I've identified a critical bug in the new _get_key_renaming_mapping function where the loop variable is modified, leading to incorrect behavior. I've also found some opportunities for efficiency improvements and removal of redundant code. Please review my comments for details.
for key in checkpoint_keys:
    # Class specific rename
    new_key, has_changed = self._fix_state_dict_key_on_load(key)

    # Optionally map the key according to `key_mapping`
    if key_mapping is not None:
        for pattern, replacement in key_mapping.items():
            new_key, n_replace = re.subn(pattern, replacement, new_key)
            # Early exit of the loop
            if n_replace > 0:
                has_changed = True
                break

    # In this case, we need to add the prefix to the keys, to match them to the expected keys
    if loading_task_model_from_base_state_dict:
        new_key = ".".join([prefix, new_key])
        key = ".".join([prefix, key])
    # In this case we need to remove the prefix from the key to match them to the expected keys, and use
    # only the keys starting with the prefix
    elif loading_base_model_from_task_state_dict:
        if not new_key.startswith(_prefix):
            continue
        new_key = new_key[len(_prefix) :]
        key = key[len(_prefix) :]

    if not has_changed:
        key_renaming_mapping[new_key] = new_key
    else:
        key_renaming_mapping[key] = new_key

    # track gamma/beta rename for logging
    if has_changed:
        if key.endswith("LayerNorm.gamma"):
            renamed_keys["LayerNorm.gamma"] = (key, new_key)
        elif key.endswith("LayerNorm.beta"):
            renamed_keys["LayerNorm.beta"] = (key, new_key)
The logic in this loop is hard to follow and contains a bug. The loop variable key is modified during prefix handling, which causes the key_renaming_mapping to be populated with incorrect keys. This will lead to failures when the mapping is used to rename keys in the state dictionary. The logic for populating key_renaming_mapping is also unnecessarily complex and can be simplified.
for key_from_checkpoint in checkpoint_keys:
    # Class specific rename
    new_key, has_changed = self._fix_state_dict_key_on_load(key_from_checkpoint)

    # Optionally map the key according to `key_mapping`
    if key_mapping is not None:
        for pattern, replacement in key_mapping.items():
            new_key, n_replace = re.subn(pattern, replacement, new_key)
            # Early exit of the loop
            if n_replace > 0:
                has_changed = True
                break

    key_in_model = new_key
    # In this case, we need to add the prefix to the keys, to match them to the expected keys
    if loading_task_model_from_base_state_dict:
        key_in_model = ".".join([prefix, key_in_model])
    # In this case we need to remove the prefix from the key to match them to the expected keys, and use
    # only the keys starting with the prefix
    elif loading_base_model_from_task_state_dict:
        if not key_in_model.startswith(_prefix):
            continue
        key_in_model = key_in_model[len(_prefix) :]

    key_renaming_mapping[key_from_checkpoint] = key_in_model

    # track gamma/beta rename for logging
    if has_changed:
        if key_from_checkpoint.endswith("LayerNorm.gamma"):
            renamed_keys["LayerNorm.gamma"] = (key_from_checkpoint, new_key)
        elif key_from_checkpoint.endswith("LayerNorm.beta"):
            renamed_keys["LayerNorm.beta"] = (key_from_checkpoint, new_key)
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
    start_prefix = cls.base_model_prefix + "."
This block of code appears to be a remnant of the previous implementation and is now redundant. The logic for handling model prefixes is encapsulated within the new _get_key_renaming_mapping function. Furthermore, line 2645 (model_to_load = model) introduces a bug by resetting model_to_load, which might have been correctly set to a submodule earlier in the function. Since start_prefix is also no longer used, this entire block can be safely removed.
for pattern in cls._keys_to_ignore_on_load_missing:
    missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
The current implementation for filtering missing_keys re-creates the list in each iteration of the loop, which is inefficient for a large number of patterns. This can be optimized by using a single list comprehension.
Suggested change:

missing_keys = [
    k for k in missing_keys if all(re.search(p, k) is None for p in cls._keys_to_ignore_on_load_missing)
]
for pattern in cls._keys_to_ignore_on_load_unexpected:
    unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
Similar to the filtering of missing_keys, the filtering of unexpected_keys can be made more efficient by avoiding list re-creation in a loop.
Suggested change:

unexpected_keys = [
    k for k in unexpected_keys if all(re.search(p, k) is None for p in cls._keys_to_ignore_on_load_unexpected)
]
Branch updated from ef41aa8 to 3cb1756.
What does this PR do?
Fixes # (issue)
This PR fixes a weight-loading bug for the Ernie model. The LayerNorm weights in this model have different parameter names (gamma --> weight / beta --> bias) depending on the torch version, so a key renaming mapping function should be implemented for this conversion.
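To make the symptom concrete, here is a hypothetical illustration of the mismatch (the key names are invented for this example and do not come from an actual Ernie checkpoint):

# Older-format checkpoints may store LayerNorm parameters as:
checkpoint_keys = [
    "encoder.layer.0.attention.output.LayerNorm.gamma",
    "encoder.layer.0.attention.output.LayerNorm.beta",
]
# while the model defines the same parameters as:
model_param_names = [
    "encoder.layer.0.attention.output.LayerNorm.weight",
    "encoder.layer.0.attention.output.LayerNorm.bias",
]
# Without a gamma/beta -> weight/bias renaming step, each such parameter shows up
# as both "missing" and "unexpected" during loading and its weight is not restored.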
Before submitting
What's New. Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@xxx