Skip to content

feat: cuda device_map for pipelines. #12122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Aug 11, 2025

What does this PR do?

TL;DR: This PR adds a device_map option at the pipeline-level to speed up end-to-end pipeline loading on a target device.

To benefit from #11904, users have to follow this pattern:

from diffusers import AutoModel, DiffusionPipeline
import torch 

# first initialize the model on device_map to benefit from fast loading
model = AutoModel.from_pretrained(..., device_map="cuda", torch_dtype=...)

# initialize the pipeline
pipe = DiffusionPipeline.from_pretrained(..., transformer=model, torch_dtype=...)

# place other stuff on cuda
for name, component in pipe.components.items():
     if name != "model_already_loaded_above":
         component.to("cuda")

# run inference
...

We could improve the UX a bit by letting the users pass a device_map="cuda" (or whatever valid value) WHILE initializing the pipe. This PR tackles that =>

from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(..., device_map="cuda", torch_dtype=...)
...

For pipelines like Flux, passing device_map for loading the text encoders might be as same as doing to(). However, for pipelines like Qwen-Image, that use a mid-range model like Qwen25VL-7B, passing device_map="cuda" while initializing the pipeline should be beneficial (of course, the target device should have enough VRAM to support this). Below are the results I got for Qwen-Image (with cold cache):

time: 8.494s (no device_map)
time: 6.678s (device_map)
Code
import time
t_ini = time.time()

import torch
from diffusers import DiffusionPipeline
print(f"import time: {time.time() - t_ini:.3f}s")

model_id = "Qwen/Qwen-Image"

t0 = time.time()
torch.cuda.synchronize()
print(f"CUDA sync time: {time.time() - t0:.3f}s")

print("starting pipe load")
t1 = time.time()
pipe = DiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)
torch.cuda.synchronize()
t2 = time.time()

diff = t2 - t1
print(f"time: {diff:.3f}s")

print(getattr(pipe, "hf_device_map", None))
_ = pipe("dog", num_inference_steps=2)

Any objections?

@sayakpaul sayakpaul requested review from SunMarc and a-r-r-o-w August 11, 2025 06:35
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Could we add a simple fast GPU test

@sayakpaul sayakpaul changed the title [wip] feat: cuda device_map for pipelines. feat: cuda device_map for pipelines. Aug 12, 2025
@sayakpaul sayakpaul requested a review from DN6 August 12, 2025 16:16
@sayakpaul sayakpaul marked this pull request as ready for review August 12, 2025 16:16
@sayakpaul
Copy link
Member Author

@DN6 done! I have added a test, too.

@@ -67,6 +67,7 @@
numpy_to_pil,
)
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.testing_utils import torch_device
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer not to import from testing_utils for non-test modules (in fact we should move this module out of src). It would be better to redefine the relevant torch_device functionality in torch_utils and import from there.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah my thoughts too. We should actually move torch_device to utils.

@sayakpaul sayakpaul requested a review from DN6 August 13, 2025 08:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants