fix(transformers): supplement key_renaming_mapping func #1216
@@ -373,6 +373,52 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
        return weights_name


def _find_missing_and_unexpected_keys(
    cls,
    model: "PreTrainedModel",
    original_checkpoint_keys: List[str],
    checkpoint_keys: List[str],
    loading_base_model_from_task_state_dict: bool,
) -> Tuple[List[str], List[str]]:
    """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict
    keys) and unexpected keys (keys found in the loaded state dict keys, but that are NOT part of the model
    parameters).
    """
    prefix = model.base_model_prefix

    # Compute expected keys, i.e. keys that the FULL model (not model_to_load) expects
    expected_keys = list(model.state_dict().keys())

    # Adjust prefix of the keys to make them match loaded keys before removing them
    missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))
    unexpected_keys = set(checkpoint_keys) - set(expected_keys)
    # If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys
    if loading_base_model_from_task_state_dict:
        task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")]
        unexpected_keys.update(task_specific_keys)

    # Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but
    # may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway
    model_buffers = {n for n, _ in model.named_buffers()}
    unexpected_keys = sorted(unexpected_keys - model_buffers)

    # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main
    # model (so the buffer name has changed). Remove them in such a case
    has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
    if has_inv_freq_buffers:
        unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]

    # Model-specific exceptions for missing and unexpected keys (e.g. if the modeling changes over time, or any other reason...)
    if cls._keys_to_ignore_on_load_missing is not None:
        for pattern in cls._keys_to_ignore_on_load_missing:
            missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]

    if cls._keys_to_ignore_on_load_unexpected is not None:
        for pattern in cls._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
Comment on lines +416 to +417

Similar to the filtering of `missing_keys`, this rebuilds `unexpected_keys` once per pattern; a single list comprehension over all the patterns would do the same work in one pass.
    return missing_keys, unexpected_keys


class ModuleUtilsMixin:
    """
    A few utilities for `mindspore.nn.Cell`, to be used as a mixin.
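For orientation, the core of the helper above is plain set arithmetic between the model's expected parameter names and the checkpoint's keys. A minimal, self-contained sketch with made-up key names, ignoring buffers, prefixes and the `_keys_to_ignore_*` patterns:

```python
expected_keys = ["embeddings.weight", "encoder.layer.0.weight", "classifier.weight"]      # model.state_dict()
checkpoint_keys = ["embeddings.weight", "encoder.layer.0.weight", "pooler.dense.weight"]  # loaded file

missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))     # ['classifier.weight']
unexpected_keys = sorted(set(checkpoint_keys) - set(expected_keys))  # ['pooler.dense.weight']
```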
@@ -2438,6 +2484,83 @@ def from_pretrained(

        return model

    @staticmethod
    def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
        """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
        # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
        # This rename is logged.
        if key.endswith("LayerNorm.beta"):
            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
        if key.endswith("LayerNorm.gamma"):
            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

        return key, False
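The expected behaviour of this class-specific rename, on a few made-up key names (a sketch, assuming the enclosing class is the `PreTrainedModel` base class referenced in the type hints above; as a `@staticmethod` it can be called on the class directly):

```python
PreTrainedModel._fix_state_dict_key_on_load("encoder.layer.0.output.LayerNorm.gamma")
# -> ("encoder.layer.0.output.LayerNorm.weight", True)
PreTrainedModel._fix_state_dict_key_on_load("encoder.layer.0.output.LayerNorm.beta")
# -> ("encoder.layer.0.output.LayerNorm.bias", True)
PreTrainedModel._fix_state_dict_key_on_load("embeddings.word_embeddings.weight")
# -> ("embeddings.word_embeddings.weight", False)
```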
    def _get_key_renaming_mapping(
        self,
        checkpoint_keys: List[str],
        key_mapping: Optional[Dict[str, str]] = None,
        loading_base_model_from_task_state_dict: bool = False,
        loading_task_model_from_base_state_dict: bool = False,
    ):
        """
        Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model
        that we are loading expects. This is the single entry point for key renaming that will be used during
        loading.
        Log if any parameters have been renamed.
        """
        prefix = self.base_model_prefix
        _prefix = f"{prefix}."

        renamed_keys = {}
        key_renaming_mapping = {}
        for key in checkpoint_keys:
            # Class specific rename
            new_key, has_changed = self._fix_state_dict_key_on_load(key)

            # Optionally map the key according to `key_mapping`
            if key_mapping is not None:
                for pattern, replacement in key_mapping.items():
                    new_key, n_replace = re.subn(pattern, replacement, new_key)
                    # Early exit of the loop
                    if n_replace > 0:
                        has_changed = True
                        break

            # In this case, we need to add the prefix to the keys, to match them to the expected keys
            if loading_task_model_from_base_state_dict:
                new_key = ".".join([prefix, new_key])
                key = ".".join([prefix, key])
            # In this case we need to remove the prefix from the key to match them to the expected keys, and use
            # only the keys starting with the prefix
            elif loading_base_model_from_task_state_dict:
                if not new_key.startswith(_prefix):
                    continue
                new_key = new_key[len(_prefix) :]
                key = key[len(_prefix) :]

            if not has_changed:
                key_renaming_mapping[new_key] = new_key
            else:
                key_renaming_mapping[key] = new_key

            # track gamma/beta rename for logging
            if has_changed:
                if key.endswith("LayerNorm.gamma"):
                    renamed_keys["LayerNorm.gamma"] = (key, new_key)
                elif key.endswith("LayerNorm.beta"):
                    renamed_keys["LayerNorm.beta"] = (key, new_key)
Comment on lines +2517 to +2553

The logic in this loop is hard to follow and contains a bug. The loop variable `key` is reassigned inside the loop, which makes the mapping inconsistent about whether it is keyed by the original checkpoint key. Keeping the checkpoint key and the model key in separate variables keeps the mapping keyed by the name actually stored on disk:

        for key_from_checkpoint in checkpoint_keys:
            # Class specific rename
            new_key, has_changed = self._fix_state_dict_key_on_load(key_from_checkpoint)

            # Optionally map the key according to `key_mapping`
            if key_mapping is not None:
                for pattern, replacement in key_mapping.items():
                    new_key, n_replace = re.subn(pattern, replacement, new_key)
                    # Early exit of the loop
                    if n_replace > 0:
                        has_changed = True
                        break

            key_in_model = new_key
            # In this case, we need to add the prefix to the keys, to match them to the expected keys
            if loading_task_model_from_base_state_dict:
                key_in_model = ".".join([prefix, key_in_model])
            # In this case we need to remove the prefix from the key to match them to the expected keys, and use
            # only the keys starting with the prefix
            elif loading_base_model_from_task_state_dict:
                if not key_in_model.startswith(_prefix):
                    continue
                key_in_model = key_in_model[len(_prefix) :]

            key_renaming_mapping[key_from_checkpoint] = key_in_model

            # track gamma/beta rename for logging
            if has_changed:
                if key_from_checkpoint.endswith("LayerNorm.gamma"):
                    renamed_keys["LayerNorm.gamma"] = (key_from_checkpoint, new_key)
                elif key_from_checkpoint.endswith("LayerNorm.beta"):
                    renamed_keys["LayerNorm.beta"] = (key_from_checkpoint, new_key)
        if renamed_keys:
            warning_msg = f"A pretrained model of type `{self.__class__.__name__}` "
            warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
            for old_key, new_key in renamed_keys.values():
                warning_msg += f"* `{old_key}` -> `{new_key}`\n"
            warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
            logger.info_once(warning_msg)

        return key_renaming_mapping
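A hypothetical call showing how an optional regex `key_mapping` composes with the gamma/beta rename. The `model` instance, key names and pattern below are made up for illustration, and the prefix flags are left at their defaults:

```python
checkpoint_keys = [
    "transformer.h.0.LayerNorm.gamma",
    "transformer.h.0.attn.weight",
]
key_mapping = {r"^transformer\.h\.": "model.layers."}

mapping = model._get_key_renaming_mapping(checkpoint_keys, key_mapping=key_mapping)
# {
#     "transformer.h.0.LayerNorm.gamma": "model.layers.0.LayerNorm.weight",
#     "transformer.h.0.attn.weight": "model.layers.0.attn.weight",
# }
```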
    @classmethod
    def _load_pretrained_model(
        cls,
@@ -2450,71 +2573,79 @@ def _load_pretrained_model(
        sharded_metadata=None,
        dtype=None,
        keep_in_fp32_modules=None,
        key_mapping: Optional[Dict[str, str]] = None,
        weights_only: bool = True,
    ):
        model.tie_weights()

        # Retrieve missing & unexpected_keys
        model_state_dict = {k: v for k, v in model.parameters_and_names()}
        expected_keys = list(model_state_dict.keys())
        prefix = model.base_model_prefix
        original_loaded_keys = loaded_keys

        if len(prefix) > 0:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
        # Get all the keys of the state dicts that we have to initialize the model
        if sharded_metadata is not None:
            original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
        elif state_dict is not None:
            original_checkpoint_keys = list(state_dict.keys())
        else:
            has_prefix_module = False
            expects_prefix_module = False

        # Mapping loaded_keys from pt to ms
        pt2ms_mappings = _get_pt2ms_mappings(model)
        loaded_keys = _get_pt2ms_mapped_k(pt2ms_mappings, has_prefix_module, expects_prefix_module, loaded_keys, prefix)

        # key re-naming operations are never done on the keys
        # that are loaded, but always on the keys of the newly initialized model
        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
        add_prefix_to_model = has_prefix_module and not expects_prefix_module

        if remove_prefix_from_model:
            _prefix = f"{prefix}."
            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
            expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
        elif add_prefix_to_model:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]

        missing_keys = sorted(set(expected_keys) - set(loaded_keys))
        unexpected_keys = set(loaded_keys) - set(expected_keys)

        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
            original_checkpoint_keys = list(load_state_dict(pretrained_model_name_or_path).keys())

        # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
        prefix = model.base_model_prefix
        _prefix = f"{prefix}."
        has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False
        expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
        loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
        loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module
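A small sketch of how these two flags come out in the two typical cross-architecture cases, with an assumed `base_model_prefix = "bert"` and made-up key names:

```python
prefix = "bert"

# Case 1: loading a task model (which has a `bert` submodule) from a base-model checkpoint.
checkpoint_keys = ["embeddings.word_embeddings.weight", "pooler.dense.weight"]  # no "bert." prefix
has_prefix_module = any(k.startswith(prefix) for k in checkpoint_keys)          # False
expects_prefix_module = True                                                    # the task model has a `bert` attribute
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module  # True

# Case 2: loading the base model from a task-model checkpoint.
checkpoint_keys = ["bert.embeddings.word_embeddings.weight", "classifier.weight"]
has_prefix_module = any(k.startswith(prefix) for k in checkpoint_keys)          # True
expects_prefix_module = False                                                   # the base model has no `bert` attribute
loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module  # True
```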
        # Find the key names that the model expects from the serialized keys
        key_renaming_mapping = model._get_key_renaming_mapping(
            original_checkpoint_keys,
            key_mapping,
            loading_base_model_from_task_state_dict,
            loading_task_model_from_base_state_dict,
        )
        checkpoint_keys = list(key_renaming_mapping.values())

        # Find missing and unexpected keys from the state dict
        missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
            cls,
            model,
            original_checkpoint_keys,
            checkpoint_keys,
            loading_base_model_from_task_state_dict,
        )

        # Set some modules to fp32 if any
        if keep_in_fp32_modules is not None:
            for name, param in model.parameters_and_names():
                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
                    param.set_dtype(ms.float32)
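Note that the check above matches whole dotted-name segments rather than substrings. A quick sketch with assumed names:

```python
keep_in_fp32_modules = ["layer_norm"]

name = "decoder.block.0.layer_norm.weight"
any(m in name.split(".") for m in keep_in_fp32_modules)   # True  -> parameter is upcast to float32

name = "decoder.block.0.final_layer_norm_proj.weight"
any(m in name.split(".") for m in keep_in_fp32_modules)   # False -> segment match, not substring match
```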
        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        # Make sure we are able to load base models as well as derived models (specific task models, with heads)
        model_to_load = model
        if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
            start_prefix = cls.base_model_prefix + "."
        if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
            model_to_load = getattr(model, cls.base_model_prefix)
            base_model_expected_keys = list(k for k, v in model_to_load.parameters_and_names())
            if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
        # In this case, we load a ForTaskModel with keys from a BaseModel -> only load keys to the BaseModel
        if loading_task_model_from_base_state_dict:
            model_to_load = getattr(model, prefix)
            # Here we need to remove the prefix we added to correctly find missing/unexpected keys, as we will load
            # in the submodule
            key_renaming_mapping = {k: v[len(_prefix) :] for k, v in key_renaming_mapping.items()}
            checkpoint_keys = list(key_renaming_mapping.values())
            # small sanity check: the base model should not contain task-specific head keys
            task_specific_expected_keys = [s for s in model.state_dict().keys() if not s.startswith(_prefix)]
            base_model_expected_keys = list(model_to_load.state_dict().keys())
            if any(
                key in task_specific_expected_keys and key not in base_model_expected_keys for key in checkpoint_keys
            ):
                raise ValueError(
                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                    "properly saved?"
                )

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
            start_prefix = cls.base_model_prefix + "."
Comment on lines +2643 to +2648

This block of code appears to be a remnant of the previous implementation and is now redundant. The logic for handling model prefixes is encapsulated within the new key renaming logic above.
        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
@@ -2559,12 +2690,17 @@ def _find_mismatched_keys(
            # Whole checkpoint
            state_dict = _convert_state_dict(model, state_dict, prefix)

            matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s]
            if matching:
                # Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta
                state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}

            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                loading_task_model_from_base_state_dict,
                loading_base_model_from_task_state_dict,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_sharded=False)
@@ -2586,14 +2722,21 @@ def _find_mismatched_keys(
                state_dict = load_state_dict(shard_file)
                state_dict = _convert_state_dict(model, state_dict, prefix)

                matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s]
                if matching:
                    # Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta
                    state_dict = {
                        key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping
                    }

                # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                # matching the weights in the model.
                mismatched_keys += _find_mismatched_keys(
                    state_dict,
                    model_state_dict,
                    original_loaded_keys,
                    add_prefix_to_model,
                    remove_prefix_from_model,
                    loading_task_model_from_base_state_dict,
                    loading_base_model_from_task_state_dict,
                    ignore_mismatched_sizes,
                )
The current implementation for filtering `missing_keys` re-creates the list in each iteration of the loop, which is inefficient for a large number of patterns. This can be optimized by using a single list comprehension.
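A sketch of the suggested optimization, for the `missing_keys` filter in `_find_missing_and_unexpected_keys` (same semantics as the per-pattern loop shown earlier; the symmetric change would apply to `unexpected_keys`):

```python
if cls._keys_to_ignore_on_load_missing is not None:
    # Keep a key only if no ignore pattern matches it; one pass instead of one list per pattern.
    missing_keys = [
        k
        for k in missing_keys
        if not any(re.search(pattern, k) for pattern in cls._keys_to_ignore_on_load_missing)
    ]
```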