
Commit 421ee07

stevhliu and sayakpaul authored
[docs] Parallel loading of shards (#12135)
* initial
* feedback
* Update docs/source/en/using-diffusers/loading.md

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 123506e commit 421ee07

File tree

1 file changed (+24, -0)

docs/source/en/using-diffusers/loading.md

@@ -112,6 +112,30 @@ print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)

If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
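For instance, here is a minimal sketch of the dictionary form; the checkpoint name is assumed purely for illustration, and any pipeline with a `transformer` component works the same way. The `default` key covers every component not named explicitly.

```py
import torch
from diffusers import DiffusionPipeline

# Assumed checkpoint for illustration; any pipeline with a transformer works.
pipe = DiffusionPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
)
print(pipe.transformer.dtype, pipe.vae.dtype)  # (torch.bfloat16, torch.float16)
```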

### Parallel loading

Large models are often [sharded](../training/distributed_inference#model-sharding) into smaller files so that they are easier to load. Diffusers supports loading shards in parallel to speed up the loading process.

Set the environment variables below to enable parallel loading.

- Set `HF_ENABLE_PARALLEL_LOADING` to `"YES"` to enable parallel loading of shards.
- Set `HF_PARALLEL_LOADING_WORKERS` to configure the number of threads used to load shards in parallel. More workers load a model faster but use more memory (a sketch of setting this variable follows the example below).

Set the `device_map` argument to `"cuda"` to pre-allocate a large chunk of memory up front based on the model size. This substantially reduces model load time because warming up the memory allocator once avoids many smaller allocation calls later.
```py
import os
import torch
from diffusers import DiffusionPipeline

# Must be set before the pipeline is created.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"

pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```
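`HF_PARALLEL_LOADING_WORKERS` is set the same way, before the pipeline is created. A minimal sketch; the value `8` is only an illustrative choice, not a recommendation:

```py
import os

# Both variables are read when the pipeline loads, so set them first.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "8"  # illustrative value; tune for your machine
```

Since each extra worker trades memory for load speed, start with a small worker count and increase it while watching peak memory.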
### Local pipeline

To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.
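Once the checkpoint is on disk, pass the local directory to `DiffusionPipeline.from_pretrained` in place of a Hub id. A minimal sketch, assuming a hypothetical local clone at `./Wan2.2-I2V-A14B-Diffusers`:

```py
import torch
from diffusers import DiffusionPipeline

# Assumes the repository was downloaded first, e.g.:
#   git lfs install
#   git clone https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers
pipeline = DiffusionPipeline.from_pretrained(
    "./Wan2.2-I2V-A14B-Diffusers",  # hypothetical local path
    torch_dtype=torch.bfloat16,
)
```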
