
feat: cuda device_map for pipelines. #12122


Merged: 11 commits, Aug 14, 2025
3 changes: 3 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -613,6 +613,9 @@ def _assign_components_to_devices(


def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
    # TODO: separate out different device_map methods when it gets to it.
    if device_map != "balanced":
        return device_map
    # To avoid circular import problem.
    from diffusers import pipelines

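With this hunk, `_get_final_device_map` short-circuits for any `device_map` other than "balanced": a plain accelerator string is handed back unchanged, while "balanced" still falls through to the accelerate-based placement logic further down in the function. A minimal sketch of the two return shapes (illustrative only, with hypothetical component names for the "balanced" case):

```python
# Illustrative sketch, not the diffusers implementation.
def _sketch_get_final_device_map(device_map):
    if device_map != "balanced":
        return device_map                      # e.g. "cuda" stays the string "cuda"
    return {"unet": 0, "text_encoder": 0}      # hypothetical per-component assignment
```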
17 changes: 10 additions & 7 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -108,7 +108,7 @@
for library in LOADABLE_CLASSES:
    LIBRARIES.append(library)

-SUPPORTED_DEVICE_MAP = ["balanced"]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]

logger = logging.get_logger(__name__)

@@ -988,12 +988,15 @@ def load_module(name, value):
        _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
            # 7.1 device_map shenanigans
-            if final_device_map is not None and len(final_device_map) > 0:
-                component_device = final_device_map.get(name, None)
-                if component_device is not None:
-                    current_device_map = {"": component_device}
-                else:
-                    current_device_map = None
+            if final_device_map is not None:
+                if isinstance(final_device_map, dict) and len(final_device_map) > 0:
+                    component_device = final_device_map.get(name, None)
+                    if component_device is not None:
+                        current_device_map = {"": component_device}
+                    else:
+                        current_device_map = None
+                elif isinstance(final_device_map, str):
+                    current_device_map = final_device_map

            # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
            class_name = class_name[4:] if class_name.startswith("Flax") else class_name
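In the loading loop above, `current_device_map` can now take two shapes: when `final_device_map` is a dict (the "balanced" path), each named component gets a `{"": device}` map, and when it is a plain string (the new accelerator-string path), the string is forwarded as-is so the whole component loads onto that device. A hedged illustration with hypothetical component names and device indices:

```python
# Dict path, e.g. produced by device_map="balanced" (values are hypothetical):
final_device_map = {"unet": 0, "text_encoder": 1}
current_device_map = {"": final_device_map["unet"]}   # -> {"": 0} when loading the unet

# String path, new in this PR, e.g. device_map="cuda":
final_device_map = "cuda"
current_device_map = final_device_map                 # -> "cuda" for every component
```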
2 changes: 2 additions & 0 deletions src/diffusers/utils/torch_utils.py
@@ -15,6 +15,7 @@
PyTorch utilities: Utilities related to PyTorch
"""

import functools
from typing import List, Optional, Tuple, Union

from . import logging
@@ -168,6 +169,7 @@ def get_torch_cuda_device_capability():
return None


@functools.lru_cache
def get_device():
    if torch.cuda.is_available():
        return "cuda"
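`get_device()` is now called at import time to build `SUPPORTED_DEVICE_MAP` and may be consulted again while loading, so the result is memoized with `functools.lru_cache` and the hardware probe runs only once per process. A minimal sketch of that caching behaviour (the `probe` function is a hypothetical stand-in; the real helper also checks non-CUDA backends before falling back to CPU):

```python
import functools


@functools.lru_cache
def probe() -> str:
    # Stands in for an expensive hardware check; the body executes only on the first call.
    print("probing accelerators...")
    return "cuda"


probe()  # prints and returns "cuda"
probe()  # served from the cache: no print, same result
```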
23 changes: 23 additions & 0 deletions tests/pipelines/test_pipelines_common.py
@@ -2339,6 +2339,29 @@ def test_torch_dtype_dict(self):
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)

    @require_torch_accelerator
    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        torch.manual_seed(0)
        inputs = self.get_dummy_inputs(torch_device)
        inputs["generator"] = torch.manual_seed(0)
        out = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)
            for component in loaded_pipe.components.values():
                if hasattr(component, "set_default_attn_processor"):
                    component.set_default_attn_processor()
            inputs["generator"] = torch.manual_seed(0)
            loaded_out = loaded_pipe(**inputs)[0]
            max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
            self.assertLess(max_diff, expected_max_difference)


@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):
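The new test saves a dummy pipeline, reloads it with `device_map=torch_device`, and asserts that the outputs match the plain `.to(torch_device)` path within `1e-4`. The user-facing call it exercises looks roughly like the sketch below; the checkpoint id is hypothetical and the dtype is only an example, assuming a CUDA machine:

```python
import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-diffusion-checkpoint",  # hypothetical repo id
    torch_dtype=torch.float16,
    device_map="cuda",                     # accepted alongside "balanced" after this PR
)
image = pipe("a photo of an astronaut riding a horse").images[0]
```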