[docs] Parallel loading of shards #12135

Merged
merged 5 commits on Aug 14, 2025
24 changes: 24 additions & 0 deletions docs/source/en/using-diffusers/loading.md
@@ -112,6 +112,30 @@ print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)

If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
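For illustration, a minimal sketch of the dictionary form described above. The repository id is only an example of a pipeline with `transformer` and `vae` components; it is not part of this PR.

```py
import torch
from diffusers import DiffusionPipeline

# Only the transformer dtype is specified and no "default" entry is given,
# so every other component (e.g. the VAE) falls back to torch.float32.
pipe = DiffusionPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",  # example repo id
    torch_dtype={"transformer": torch.bfloat16},
)
print(pipe.transformer.dtype, pipe.vae.dtype)  # (torch.bfloat16, torch.float32)
```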

### Parallel loading

Large models are often [sharded](../training/distributed_inference#model-sharding) into smaller files so that they are easier to load. Diffusers supports loading shards in parallel to speed up the loading process.

Set the environment variables below to enable parallel loading.

- Set `HF_ENABLE_PARALLEL_LOADING` to `"YES"` to enable parallel loading of shards.
- Set `HF_PARALLEL_LOADING_WORKERS` to configure the number of parallel threads used to load shards. More workers load a model faster but use more memory (a short example follows the main code block below).

Set the `device_map` argument to `"cuda"` to pre-allocate a large chunk of memory up front based on the model size. This substantially reduces model load time because warming up the memory allocator avoids many smaller allocation calls later.

```py
import os
import torch
from diffusers import DiffusionPipeline

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
```

A Member left a review comment on the `device_map="cuda"` line:

> Do we want to talk a bit about device_map? The motivation behind passing device_map essentially comes from this PR: #11904. I have also tried providing a bit of motivation behind adding it at the pipeline-level here: #12122 (comment).
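To also tune the worker count described in the list above, set `HF_PARALLEL_LOADING_WORKERS` before loading. A minimal sketch follows; the value `8` is only an illustrative choice and should be adjusted to the available memory. Both variables can equally be exported in the shell before launching Python.

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
# Illustrative worker count; more workers load shards faster but use more memory.
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "8"
```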

### Local pipeline

To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.