# [core] parallel loading of shards #12028
Changes from 15 commits
In `model_loading_utils.py`, the PR first pulls in the thread-pool primitives:

```diff
@@ -20,6 +20,7 @@
 import os
 from array import array
 from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
```
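This import is the enabler for the whole PR. For orientation, here is a self-contained sketch of the `submit`/`as_completed` pattern that `load_shard_files_with_threadpool` uses below (generic stand-in code, not part of the diff):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def load_one(name: str) -> str:
    # Stand-in for load_shard_file: any I/O-bound, per-shard unit of work.
    return f"loaded {name}"

shards = ["shard-00001", "shard-00002", "shard-00003"]
results = []
with ThreadPoolExecutor(max_workers=min(8, len(shards))) as executor:
    futures = [executor.submit(load_one, s) for s in shards]
    for future in as_completed(futures):  # yields futures in completion order
        results.append(future.result())   # re-raises any worker exception here
print(results)
```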
The same file gains the per-shard worker function and the thread-pool driver:

```diff
@@ -310,6 +311,130 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index
 
 
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+    parameters.
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+def load_shard_file(args):
+    (
+        model,
+        model_state_dict,
+        shard_file,
+        device_map,
+        dtype,
+        hf_quantizer,
+        keep_in_fp32_modules,
+        dduf_entries,
+        loaded_keys,
+        unexpected_keys,
+        offload_index,
+        offload_folder,
+        state_dict_index,
+        state_dict_folder,
+        ignore_mismatched_sizes,
+        low_cpu_mem_usage,
+    ) = args
+    assign_to_params_buffers = None
+    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+    mismatched_keys = _find_mismatched_keys(
+        state_dict,
+        model_state_dict,
+        loaded_keys,
+        ignore_mismatched_sizes,
+    )
+    error_msgs = []
+    if low_cpu_mem_usage:
+        offload_index, state_dict_index = load_model_dict_into_meta(
+            model,
+            state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            unexpected_keys=unexpected_keys,
+            offload_folder=offload_folder,
+            offload_index=offload_index,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+        )
+    else:
+        if assign_to_params_buffers is None:
+            assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def load_shard_files_with_threadpool(args_list):
+    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
+
+    # Do not spawn more workers than you need
+    num_workers = min(len(args_list), num_workers)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+    mismatched_keys = []
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
+            for future in as_completed(futures):
+                result = future.result()
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+                pbar.update(1)
+
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+    state_dict,
+    model_state_dict,
+    loaded_keys,
+    ignore_mismatched_sizes,
+):
+    mismatched_keys = []
+    if ignore_mismatched_sizes:
+        for checkpoint_key in loaded_keys:
+            model_key = checkpoint_key
+            # If the checkpoint is sharded, we may not have the key here.
+            if checkpoint_key not in state_dict:
+                continue
+
+            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                mismatched_keys.append(
+                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                )
+                del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
```

> Resolved review threads on this block: one on the `HF_PARALLEL_LOADING_WORKERS` default ("Would add …") and one on `_find_mismatched_keys` ("Same. Moved it out of …", a relocation rather than new logic; see the closing review note).
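Taken together, these helpers let shard loading fan out across threads. A minimal end-to-end sketch of opting in (the environment variable names come straight from this diff; the pipeline class and checkpoint id are illustrative assumptions):

```python
import os

# Opt in before loading; any value in ENV_VARS_TRUE_VALUES (e.g. "yes") enables it.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
# Optional: override the worker count; this diff defaults to 8 threads.
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

from diffusers import DiffusionPipeline

# low_cpu_mem_usage must stay enabled; the PR raises NotImplementedError otherwise.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative sharded checkpoint
    low_cpu_mem_usage=True,
)
```

Threads (rather than processes) are a reasonable fit here because shard loading is dominated by file reads and tensor copies that release the GIL for long stretches, and each worker writes to a disjoint set of parameters.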
In `modeling_utils.py`, the new helpers are imported and wired into `from_pretrained()` and `_load_pretrained_model()`:
```diff
@@ -41,6 +41,7 @@
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
+    ENV_VARS_TRUE_VALUES,
     FLAX_WEIGHTS_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
```
```diff
@@ -69,9 +70,8 @@
     _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
-    _find_mismatched_keys,
-    _load_state_dict_into_model,
-    load_model_dict_into_meta,
+    load_shard_file,
+    load_shard_files_with_threadpool,
     load_state_dict,
 )
```
`check_support_param_buffer_assignment` is deleted here, having moved to `model_loading_utils.py` above:

```diff
@@ -208,34 +208,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
     return last_tuple[1].dtype
 
 
-def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
-    """
-    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
-    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
-    parameters.
-    """
-    if model_to_load.device.type == "meta":
-        return False
-
-    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
-        return False
-
-    # Some models explicitly do not support param buffer assignment
-    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
-        logger.debug(
-            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
-        )
-        return False
-
-    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
-    first_key = next(iter(model_to_load.state_dict().keys()))
-    if start_prefix + first_key in state_dict:
-        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
-
-    return False
-
-
 @contextmanager
 def no_init_weights():
     """
```
`from_pretrained` reads the opt-in flag and refuses the unsupported combination up front:

```diff
@@ -988,6 +960,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
+        is_parallel_loading_enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
+        if is_parallel_loading_enabled and not low_cpu_mem_usage:
+            raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
```
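The gate reuses `ENV_VARS_TRUE_VALUES` from the library's utils. A minimal sketch of the pattern (the member set shown is an assumption about how the constant is defined; diffusers imports the real one):

```python
import os

# Assumed definition of the truthy-string set.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

def env_flag(name: str) -> bool:
    # Unset variables read as "" and therefore count as disabled.
    return os.environ.get(name, "").upper() in ENV_VARS_TRUE_VALUES

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
assert env_flag("HF_ENABLE_PARALLEL_LOADING")
```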
The flag is threaded through to `_load_pretrained_model`:

```diff
@@ -1323,6 +1299,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 dduf_entries=dduf_entries,
+                is_parallel_loading_enabled=is_parallel_loading_enabled,
             )
             loading_info = {
                 "missing_keys": missing_keys,
```
```diff
@@ -1518,6 +1495,7 @@ def _load_pretrained_model(
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+        is_parallel_loading_enabled: Optional[bool] = False,
     ):
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
```
The accumulators move up so both loading paths can share them:

```diff
@@ -1531,6 +1509,9 @@ def _load_pretrained_model(
         for pat in cls._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
+        mismatched_keys = []
+        error_msgs = []
+
         # Deal with offload
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
```
Finally, the per-shard loop is replaced by argument packing plus a dispatch between the sequential and thread-pool paths:

```diff
@@ -1566,37 +1547,43 @@
             # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
             resolved_model_file = [state_dict]
 
-        if len(resolved_model_file) > 1:
-            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
-
-        mismatched_keys = []
-        assign_to_params_buffers = None
-        error_msgs = []
-
-        for shard_file in resolved_model_file:
-            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-            mismatched_keys += _find_mismatched_keys(
-                state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
-            )
-
-            if low_cpu_mem_usage:
-                offload_index, state_dict_index = load_model_dict_into_meta(
-                    model,
-                    state_dict,
-                    device_map=device_map,
-                    dtype=dtype,
-                    hf_quantizer=hf_quantizer,
-                    keep_in_fp32_modules=keep_in_fp32_modules,
-                    unexpected_keys=unexpected_keys,
-                    offload_folder=offload_folder,
-                    offload_index=offload_index,
-                    state_dict_index=state_dict_index,
-                    state_dict_folder=state_dict_folder,
-                )
-            else:
-                if assign_to_params_buffers is None:
-                    assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
-                error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+        # prepare the arguments.
+        args_list = [
+            (
+                model,
+                model_state_dict,
+                shard_file,
+                device_map,
+                dtype,
+                hf_quantizer,
+                keep_in_fp32_modules,
+                dduf_entries,
+                loaded_keys,
+                unexpected_keys,
+                offload_index,
+                offload_folder,
+                state_dict_index,
+                state_dict_folder,
+                ignore_mismatched_sizes,
+                low_cpu_mem_usage,
+            )
+            for shard_file in resolved_model_file
+        ]
+
+        if is_parallel_loading_enabled:
+            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool(
+                args_list
+            )
+            error_msgs += _error_msgs
+            mismatched_keys += _mismatched_keys
+        else:
+            if len(args_list) > 1:
+                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
+
+            for args in args_list:
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_file(args)
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
 
         empty_device_cache()
```

> **Collaborator:** Since the same arguments are used across the two loading functions, it's a good candidate for:

```python
# requires: from functools import partial
load_fn = partial(
    load_shard_files_with_threadpool if is_parallel_loading_enabled else load_shard_file,
    model=model,
    model_state_dict=model_state_dict,
    device_map=device_map,
    dtype=dtype,
    hf_quantizer=hf_quantizer,
    keep_in_fp32_modules=keep_in_fp32_modules,
    dduf_entries=dduf_entries,
    loaded_keys=loaded_keys,
    unexpected_keys=unexpected_keys,
    offload_index=offload_index,
    offload_folder=offload_folder,
    state_dict_index=state_dict_index,
    state_dict_folder=state_dict_folder,
    ignore_mismatched_sizes=ignore_mismatched_sizes,
    low_cpu_mem_usage=low_cpu_mem_usage,
)
if is_parallel_loading_enabled:
    offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(resolved_model_file)
    error_msgs += _error_msgs
    mismatched_keys += _mismatched_keys
else:
    shard_files = resolved_model_file
    if len(resolved_model_file) > 1:
        shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
    for shard_file in shard_files:
        offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
        error_msgs += _error_msgs
        mismatched_keys += _mismatched_keys
```

> **Review note** (on `check_support_param_buffer_assignment` in `model_loading_utils.py`): Moved it here from `modeling_utils.py`.
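For readers unfamiliar with the suggested refactor, here is a self-contained sketch of the `partial`-based dispatch idea (generic names, not the PR's code): bind the arguments shared by both strategies once, then vary only the per-call input.

```python
from functools import partial

def load_one(shard, *, device, dtype):
    return f"{shard} -> {device}/{dtype}"

def load_many(shards, *, device, dtype):
    return [load_one(s, device=device, dtype=dtype) for s in shards]

parallel = False
# Shared keyword arguments are bound once; the strategy is chosen separately.
load_fn = partial(load_many if parallel else load_one, device="cuda", dtype="fp16")

shards = ["shard-00001", "shard-00002"]
results = load_fn(shards) if parallel else [load_fn(s) for s in shards]
print(results)
```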