Merged
243 changes: 193 additions & 50 deletions mindone/transformers/modeling_utils.py
@@ -373,6 +373,52 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name


def _find_missing_and_unexpected_keys(
cls,
model: "PreTrainedModel",
original_checkpoint_keys: List[str],
checkpoint_keys: List[str],
loading_base_model_from_task_state_dict: bool,
) -> Tuple[List[str], List[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
"""
prefix = model.base_model_prefix

# Compute expected keys, i.e. keys that the FULL model (not model_to_load) expects
expected_keys = list(model.state_dict().keys())

# Adjust prefix of the keys to make them match loaded keys before removing them
missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))
unexpected_keys = set(checkpoint_keys) - set(expected_keys)
# If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys
if loading_base_model_from_task_state_dict:
task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")]
unexpected_keys.update(task_specific_keys)

# Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but
# may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway
model_buffers = {n for n, _ in model.named_buffers()}
unexpected_keys = sorted(unexpected_keys - model_buffers)

# Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
# (so the buffer name has changed). Remove them in such a case
has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
if has_inv_freq_buffers:
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]

# Model-specific exceptions for missing and unexpected keys (e.g. if the modeling change over time, or any other reason...)
if cls._keys_to_ignore_on_load_missing is not None:
for pattern in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
Comment on lines +412 to +413

medium

The current implementation for filtering missing_keys re-creates the list in each iteration of the loop, which is inefficient for a large number of patterns. This can be optimized by using a single list comprehension.

Suggested change
for pattern in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
missing_keys = [
k for k in missing_keys if all(re.search(p, k) is None for p in cls._keys_to_ignore_on_load_missing)
]


if cls._keys_to_ignore_on_load_unexpected is not None:
for pattern in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
Comment on lines +416 to +417

medium

Similar to the filtering of missing_keys, the filtering of unexpected_keys can be made more efficient by avoiding list re-creation in a loop.

Suggested change
for pattern in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
unexpected_keys = [
k for k in unexpected_keys if all(re.search(p, k) is None for p in cls._keys_to_ignore_on_load_unexpected)
]
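
For concreteness, a minimal standalone sketch of the single-pass filtering described in both suggestions above; the pattern list and key names are hypothetical, not taken from the PR:

    import re

    # Hypothetical ignore patterns and missing keys, for illustration only.
    ignore_patterns = [r"position_ids$", r"rotary_emb\.inv_freq$"]
    missing_keys = ["embeddings.position_ids", "encoder.layer.0.attention.query.weight"]

    # Loop version: rebuilds the list once per pattern.
    filtered = missing_keys
    for pattern in ignore_patterns:
        filtered = [k for k in filtered if re.search(pattern, k) is None]

    # Single-pass version: one list comprehension, same result, no intermediate lists.
    filtered_single_pass = [
        k for k in missing_keys if all(re.search(p, k) is None for p in ignore_patterns)
    ]

    assert filtered == filtered_single_pass  # both keep only the attention weight key

Both variants call re.search the same number of times in the worst case; the gain is in avoiding the repeated list allocations, which matters when the pattern list is long.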


return missing_keys, unexpected_keys


class ModuleUtilsMixin:
"""
A few utilities for `mindspore.nn.Cell`, to be used as a mixin.
@@ -2438,6 +2484,83 @@ def from_pretrained(

return model

@staticmethod
def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
# This rename is logged.
if key.endswith("LayerNorm.beta"):
return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
if key.endswith("LayerNorm.gamma"):
return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

return key, False

def _get_key_renaming_mapping(
self,
checkpoint_keys: List[str],
key_mapping: Optional[Dict[str, str]] = None,
loading_base_model_from_task_state_dict: bool = False,
loading_task_model_from_base_state_dict: bool = False,
):
"""
Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model
that we are loading expects. This is the single entry point for key renaming that will be used during
loading.
Log if any parameters have been renamed.
"""
prefix = self.base_model_prefix
_prefix = f"{prefix}."

renamed_keys = {}
key_renaming_mapping = {}
for key in checkpoint_keys:
# Class specific rename
new_key, has_changed = self._fix_state_dict_key_on_load(key)

# Optionally map the key according to `key_mapping`
if key_mapping is not None:
for pattern, replacement in key_mapping.items():
new_key, n_replace = re.subn(pattern, replacement, new_key)
# Early exit of the loop
if n_replace > 0:
has_changed = True
break

# In this case, we need to add the prefix to the keys, to match them to the expected keys
if loading_task_model_from_base_state_dict:
new_key = ".".join([prefix, new_key])
key = ".".join([prefix, key])
# In this case we need to remove the prefix from the key to match them to the expected keys, and use
# only the keys starting with the prefix
elif loading_base_model_from_task_state_dict:
if not new_key.startswith(_prefix):
continue
new_key = new_key[len(_prefix) :]
key = key[len(_prefix) :]

if not has_changed:
key_renaming_mapping[new_key] = new_key
else:
key_renaming_mapping[key] = new_key

# track gamma/beta rename for logging
if has_changed:
if key.endswith("LayerNorm.gamma"):
renamed_keys["LayerNorm.gamma"] = (key, new_key)
elif key.endswith("LayerNorm.beta"):
renamed_keys["LayerNorm.beta"] = (key, new_key)

Comment on lines +2517 to +2553

critical

The logic in this loop is hard to follow and contains a bug. The loop variable key is modified during prefix handling, which causes the key_renaming_mapping to be populated with incorrect keys. This will lead to failures when the mapping is used to rename keys in the state dictionary. The logic for populating key_renaming_mapping is also unnecessarily complex and can be simplified.

        for key_from_checkpoint in checkpoint_keys:
            # Class specific rename
            new_key, has_changed = self._fix_state_dict_key_on_load(key_from_checkpoint)

            # Optionally map the key according to `key_mapping`
            if key_mapping is not None:
                for pattern, replacement in key_mapping.items():
                    new_key, n_replace = re.subn(pattern, replacement, new_key)
                    # Early exit of the loop
                    if n_replace > 0:
                        has_changed = True
                        break

            key_in_model = new_key
            # In this case, we need to add the prefix to the keys, to match them to the expected keys
            if loading_task_model_from_base_state_dict:
                key_in_model = ".".join([prefix, key_in_model])
            # In this case we need to remove the prefix from the key to match them to the expected keys, and use
            # only the keys starting with the prefix
            elif loading_base_model_from_task_state_dict:
                if not key_in_model.startswith(_prefix):
                    continue
                key_in_model = key_in_model[len(_prefix) :]

            key_renaming_mapping[key_from_checkpoint] = key_in_model

            # track gamma/beta rename for logging
            if has_changed:
                if key_from_checkpoint.endswith("LayerNorm.gamma"):
                    renamed_keys["LayerNorm.gamma"] = (key_from_checkpoint, new_key)
                elif key_from_checkpoint.endswith("LayerNorm.beta"):
                    renamed_keys["LayerNorm.beta"] = (key_from_checkpoint, new_key)

if renamed_keys:
warning_msg = f"A pretrained model of type `{self.__class__.__name__}` "
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.values():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)

return key_renaming_mapping

@classmethod
def _load_pretrained_model(
cls,
@@ -2450,71 +2573,79 @@ def _load_pretrained_model(
sharded_metadata=None,
dtype=None,
keep_in_fp32_modules=None,
key_mapping: Optional[Dict[str, str]] = None,
weights_only: bool = True,
):
model.tie_weights()

# Retrieve missing & unexpected_keys
model_state_dict = {k: v for k, v in model.parameters_and_names()}
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix
original_loaded_keys = loaded_keys

if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
# Get all the keys of the state dicts that we have to initialize the model
if sharded_metadata is not None:
original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
elif state_dict is not None:
original_checkpoint_keys = list(state_dict.keys())
else:
has_prefix_module = False
expects_prefix_module = False

# Mapping loaded_keys from pt to ms
pt2ms_mappings = _get_pt2ms_mappings(model)
loaded_keys = _get_pt2ms_mapped_k(pt2ms_mappings, has_prefix_module, expects_prefix_module, loaded_keys, prefix)

# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module

if remove_prefix_from_model:
_prefix = f"{prefix}."
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]

missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)

# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls._keys_to_ignore_on_load_missing is not None:
for pat in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
original_checkpoint_keys = list(load_state_dict(pretrained_model_name_or_path).keys())

# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
prefix = model.base_model_prefix
_prefix = f"{prefix}."
has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False
expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module

# Find the key names that the model expects from the serialized keys
key_renaming_mapping = model._get_key_renaming_mapping(
original_checkpoint_keys,
key_mapping,
loading_base_model_from_task_state_dict,
loading_task_model_from_base_state_dict,
)
checkpoint_keys = list(key_renaming_mapping.values())

# Find missing and unexpected keys from the state dict
missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
cls,
model,
original_checkpoint_keys,
checkpoint_keys,
loading_base_model_from_task_state_dict,
)

# Set some modules to fp32 if any
if keep_in_fp32_modules is not None:
for name, param in model.parameters_and_names():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
param.set_dtype(ms.float32)

# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
# Make sure we are able to load base models as well as derived models (specific task models, with heads)
model_to_load = model
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "."
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix)
base_model_expected_keys = list(k for k, v in model_to_load.parameters_and_names())
if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
# In this case, we load a ForTaskModel with keys from a BaseModel -> only load keys to the BaseModel
if loading_task_model_from_base_state_dict:
model_to_load = getattr(model, prefix)
# Here we need to remove the prefix we added to correctly find missing/unexpected keys, as we will load
# in the submodule
key_renaming_mapping = {k: v[len(_prefix) :] for k, v in key_renaming_mapping.items()}
checkpoint_keys = list(key_renaming_mapping.values())
# small sanity check: the base model should not contain task-specific head keys
task_specific_expected_keys = [s for s in model.state_dict().keys() if not s.startswith(_prefix)]
base_model_expected_keys = list(model_to_load.state_dict().keys())
if any(
key in task_specific_expected_keys and key not in base_model_expected_keys for key in checkpoint_keys
):
raise ValueError(
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"properly saved?"
)

# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "."

Comment on lines +2643 to +2648

high

This block of code appears to be a remnant of the previous implementation and is now redundant. The logic for handling model prefixes is encapsulated within the new _get_key_renaming_mapping function. Furthermore, line 2645 model_to_load = model introduces a bug by resetting model_to_load, which might have been correctly set to a submodule earlier in the function. Since start_prefix is also no longer used, this entire block can be safely removed.

def _find_mismatched_keys(
state_dict,
model_state_dict,
@@ -2559,12 +2690,17 @@ def _find_mismatched_keys(
# Whole checkpoint
state_dict = _convert_state_dict(model, state_dict, prefix)

matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s]
if matching:
# Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}

mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
loading_task_model_from_base_state_dict,
loading_base_model_from_task_state_dict,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_sharded=False)
@@ -2586,14 +2722,21 @@ def _find_mismatched_keys(
state_dict = load_state_dict(shard_file)
state_dict = _convert_state_dict(model, state_dict, prefix)

matching = [s for s in key_renaming_mapping.keys() if "LayerNorm.gamma" in s]
if matching:
# Fix the key names when model weight names contain LayerNorm.gamma/LayerNorm.beta
state_dict = {
key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping
}

# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
loading_task_model_from_base_state_dict,
loading_base_model_from_task_state_dict,
ignore_mismatched_sizes,
)
