
Commit d5f992f

inf3rnus, github-actions[bot], and Cyrilvallez authored
Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag (#36835)
* Get parallel loader working. Include tests.
* Update the tests for parallel loading
* Rename env variables.
* Add docs for parallel model weight loading.
* Touch up parallel model loading docs.
* Touch up parallel model loading docs again.
* Edit comment in test_modeling_utils_parallel_loading.py
* Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py
* Correct times for parallelized loading, previous times were for a "hot" filesystem
* Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule.
* Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally.
* Fix style on model loading parallelism changes.
* Merge latest version of master's modeling_utils.
* Removed unused variable.
* Fix argument packing for the parallel loader.
* Fix state dict being undefined in the parallel model loader.
* Rename variables used in parallel model loading for clarity. Use get_module_from_name().
* Switch to the use of threads for parallel model loading.
* Update docs for parallel loading.
* Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting.
* Move parallelized shard loading into its own function.
* Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING.
* Update copyright to 2025 in readme for paralell model loading.
* Remove garbage collection line in load_shard_file, implicit garbage collection already occurs.
* Run formatter on modeling_utils.py
* Apply style fixes
* Delete tests/utils/test_modeling_utils_parallel_loading.py

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Cyril Vallez <[email protected]>
1 parent 1ed1936 commit d5f992f

File tree

4 files changed: +234 -76 lines changed

docs/source/en/_toctree.yml

Lines changed: 5 additions & 0 deletions
@@ -1121,4 +1121,9 @@
     - local: internal/time_series_utils
       title: Utilities for Time Series
     title: Internal helpers
+  - sections:
+    - local: reference/environment_variables
+      title: Environment Variables
+    title: Reference
   title: API
docs/source/en/reference/environment_variables.md

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Environment Variables

## HF_ENABLE_PARALLEL_LOADING

Disabled by default. When enabled, torch- and safetensors-based weights are loaded in parallel, which can significantly decrease the time it takes to load large models, often producing speed-ups of around 50%.

Can be set to the string `"true"` or `"false"`, e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`.

For example, `facebook/opt-30b` on an AWS EC2 g4dn.metal instance loads in ~30s with this enabled vs. ~55s without it.

Profile before committing to this environment variable; it will not produce speed-ups for smaller models.

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

from transformers import pipeline

model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```

## HF_PARALLEL_LOADING_WORKERS

Determines how many threads to use when parallel loading is enabled. Defaults to `8`.

If fewer files are being loaded than the number of threads specified, only as many threads as there are files will be spawned.

For example, if you specify 8 workers but only 2 files need to be loaded, only 2 workers will be spawned.

Tune as you see fit.

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

from transformers import pipeline

model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
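Since the docs recommend profiling first, here is a minimal sketch of one way to compare load times on your own hardware; the model ID, warm-up strategy, and `timed_load` helper are illustrative assumptions, not part of this commit:

```py
import os
import time

from transformers import AutoModelForCausalLM


def timed_load(parallel: bool, model_id: str = "facebook/opt-1.3b") -> float:
    # The flag is read while weights are loaded, so toggling it at runtime is enough
    os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" if parallel else "false"
    start = time.perf_counter()
    AutoModelForCausalLM.from_pretrained(model_id)
    return time.perf_counter() - start


# Warm the filesystem cache once so the comparison is not skewed by a cold first read
timed_load(parallel=False)

serial_s = timed_load(parallel=False)
parallel_s = timed_load(parallel=True)
print(f"serial: {serial_s:.1f}s, parallel: {parallel_s:.1f}s")
```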

src/transformers/modeling_utils.py

Lines changed: 150 additions & 58 deletions
@@ -27,6 +27,7 @@
 import tempfile
 import warnings
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
@@ -870,6 +871,116 @@ def _load_state_dict_into_meta_model(
     return disk_offload_index, cpu_offload_index


+def load_shard_file(args):
+    (
+        shard_file,
+        state_dict,
+        disk_only_shard_files,
+        is_hqq_or_bnb,
+        is_quantized,
+        device_map,
+        hf_quantizer,
+        key_renaming_mapping,
+        weights_only,
+        model_to_load,
+        expected_keys,
+        reverse_key_renaming_mapping,
+        disk_offload_folder,
+        disk_offload_index,
+        cpu_offload_folder,
+        cpu_offload_index,
+        is_offloaded_safetensors,
+        keep_in_fp32_regex,
+        unexpected_keys,
+        device_mesh,
+    ) = args
+
+    # Skip the load for shards that only contain disk-offloaded weights
+    if shard_file in disk_only_shard_files:
+        return [], disk_offload_index, cpu_offload_index
+
+    map_location = "cpu"
+    if (
+        shard_file.endswith(".safetensors")
+        and not is_hqq_or_bnb
+        and not (is_deepspeed_zero3_enabled() and not is_quantized)
+    ):
+        map_location = "meta"
+    elif (
+        device_map is not None
+        and hf_quantizer is not None
+        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+        and (
+            hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+            or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
+        )
+    ):
+        map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
+
+    # If shard_file is "", we use the existing state_dict instead of loading it
+    if shard_file != "":
+        state_dict = load_state_dict(
+            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+        )
+
+    # Fix the key names
+    state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
+
+    error_msgs = []
+
+    if is_deepspeed_zero3_enabled() and not is_quantized:
+        error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
+    # Skip it with fsdp on ranks other than 0
+    elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
+        disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
+            model_to_load,
+            state_dict,
+            shard_file,
+            expected_keys,
+            reverse_key_renaming_mapping,
+            device_map=device_map,
+            disk_offload_folder=disk_offload_folder,
+            disk_offload_index=disk_offload_index,
+            cpu_offload_folder=cpu_offload_folder,
+            cpu_offload_index=cpu_offload_index,
+            hf_quantizer=hf_quantizer,
+            is_safetensors=is_offloaded_safetensors,
+            keep_in_fp32_regex=keep_in_fp32_regex,
+            unexpected_keys=unexpected_keys,
+            device_mesh=device_mesh,
+        )
+
+    return error_msgs, disk_offload_index, cpu_offload_index
+
+
+def load_shard_files_with_threadpool(args_list):
+    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
+
+    # Do not spawn anymore workers than you need
+    num_workers = min(len(args_list), num_workers)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
+            for future in as_completed(futures):
+                result = future.result()
+                (
+                    _error_msgs,
+                    disk_offload_index,
+                    cpu_offload_index,
+                ) = result
+
+                error_msgs += _error_msgs
+
+                pbar.update(1)
+
+    return error_msgs, disk_offload_index, cpu_offload_index
+
+
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None:
         path, name = weights_name.rsplit(".", 1)
@@ -4973,9 +5084,6 @@ def _load_pretrained_model(
             cpu_offload_folder = tempfile.mkdtemp()
             cpu_offload_index = {}

-        # For nice tqdm bars
-        if checkpoint_files is not None and len(checkpoint_files) > 1:
-            checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
         # To be able to iterate, even if we don't use it if the state_dict is already provided
         elif state_dict is not None:
             checkpoint_files = [""]
@@ -4993,64 +5101,48 @@ def _load_pretrained_model(
             expanded_device_map = expand_device_map(device_map, expected_keys)
             caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)

-        error_msgs = []
-        # Iterate on all the shards to load the weights
-        for shard_file in checkpoint_files:
-            # Skip the load for shards that only contain disk-offloaded weights
-            if shard_file in disk_only_shard_files:
-                continue
-
-            map_location = "cpu"
-            if (
-                shard_file.endswith(".safetensors")
-                and not is_hqq_or_bnb
-                and not (is_deepspeed_zero3_enabled() and not is_quantized)
-            ):
-                map_location = "meta"
-            elif (
-                device_map is not None
-                and hf_quantizer is not None
-                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
-                and (
-                    hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
-                    or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
-                )
-            ):
-                map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
-
-            # If shard_file is "", we use the existing state_dict instead of loading it
-            if shard_file != "":
-                state_dict = load_state_dict(
-                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
-                )
+        # Prepare and compatabilize arguments for serial and parallel shard loading
+        args_list = [
+            (
+                shard_file,
+                state_dict,
+                disk_only_shard_files,
+                is_hqq_or_bnb,
+                is_quantized,
+                device_map,
+                hf_quantizer,
+                key_renaming_mapping,
+                weights_only,
+                model_to_load,
+                expected_keys,
+                reverse_key_renaming_mapping,
+                disk_offload_folder,
+                disk_offload_index,
+                cpu_offload_folder,
+                cpu_offload_index,
+                is_offloaded_safetensors,
+                keep_in_fp32_regex,
+                unexpected_keys,
+                device_mesh,
+            )
+            for shard_file in checkpoint_files
+        ]

-            # Fix the key names
-            state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
+        error_msgs = []

-            if is_deepspeed_zero3_enabled() and not is_quantized:
-                error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
-            # Skip it with fsdp on ranks other than 0
-            elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
-                disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
-                    model_to_load,
-                    state_dict,
-                    shard_file,
-                    expected_keys,
-                    reverse_key_renaming_mapping,
-                    device_map=device_map,
-                    disk_offload_folder=disk_offload_folder,
-                    disk_offload_index=disk_offload_index,
-                    cpu_offload_folder=cpu_offload_folder,
-                    cpu_offload_index=cpu_offload_index,
-                    hf_quantizer=hf_quantizer,
-                    is_safetensors=is_offloaded_safetensors,
-                    keep_in_fp32_regex=keep_in_fp32_regex,
-                    unexpected_keys=unexpected_keys,
-                    device_mesh=device_mesh,
-                )
+        if (
+            os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
+            and not is_deepspeed_zero3_enabled()
+        ):
+            _error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list)
+            error_msgs += _error_msgs
+        else:
+            if len(args_list) > 1:
+                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")

-            # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
-            del state_dict
+            for args in args_list:
+                _error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args)
+                error_msgs += _error_msgs

         # Adjust offloaded weights name and save if needed
         if disk_offload_index is not None and len(disk_offload_index) > 0:
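
For readers who want the new loading flow in isolation, the following is a self-contained sketch of the fan-out pattern `load_shard_files_with_threadpool` relies on: submit one task per shard to a bounded `ThreadPoolExecutor` and collect results as they complete. The shard names and the `load_one` stand-in are illustrative, not transformers APIs:

```py
import os
from concurrent.futures import ThreadPoolExecutor, as_completed


def load_one(shard: str) -> list[str]:
    # Stand-in for load_shard_file: pretend to load a shard and return any error messages
    return []


shards = [f"model-{i:05d}-of-00003.safetensors" for i in range(1, 4)]

# Same sizing rule as the commit: honor HF_PARALLEL_LOADING_WORKERS, but never exceed the shard count
num_workers = min(len(shards), int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")))

error_msgs: list[str] = []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = [executor.submit(load_one, shard) for shard in shards]
    for future in as_completed(futures):
        # Results arrive in completion order, not submission order
        error_msgs += future.result()

print(f"loaded {len(shards)} shards with {num_workers} workers, {len(error_msgs)} errors")
```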

tests/utils/test_modeling_utils.py

Lines changed: 21 additions & 18 deletions
@@ -297,6 +297,27 @@ def test_local_files_only(self):
             hub.TRANSFORMERS_CACHE = transformers_cache


+# Need to be serializable, which means they cannot be in a test class method
+class TestGammaBetaNorm(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gamma = torch.nn.Parameter(torch.ones(1))
+        self.beta = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self):
+        return self.gamma.sum() + self.beta.sum()
+
+
+class TestModelGammaBeta(PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.LayerNorm = TestGammaBetaNorm()
+        self.post_init()
+
+    def forward(self):
+        return self.LayerNorm()
+
+
 if is_flax_available():
     from transformers import FlaxBertModel

@@ -1636,24 +1657,6 @@ def test_model_from_pretrained_from_mlx(self):
         torch.testing.assert_close(outputs_from_saved["logits"], outputs["logits"])

     def test_warning_for_beta_gamma_parameters(self):
-        class TestGammaBetaNorm(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.gamma = torch.nn.Parameter(torch.ones(1))
-                self.beta = torch.nn.Parameter(torch.zeros(1))
-
-            def forward(self):
-                return self.gamma.sum() + self.beta.sum()
-
-        class TestModelGammaBeta(PreTrainedModel):
-            def __init__(self, config):
-                super().__init__(config)
-                self.LayerNorm = TestGammaBetaNorm()
-                self.post_init()
-
-            def forward(self):
-                return self.LayerNorm()
-
         logger = logging.get_logger("transformers.modeling_utils")
         config = PretrainedConfig()
         warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"