Skip to content
Merged
3 changes: 3 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,9 @@ def _assign_components_to_devices(


def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
# TODO: seperate out different device_map methods when it gets to it.
if device_map != "balanced":
return device_map
# To avoid circular import problem.
from diffusers import pipelines

Expand Down
18 changes: 11 additions & 7 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)

SUPPORTED_DEVICE_MAP = ["balanced"]
# TODO: support single-device namings
SUPPORTED_DEVICE_MAP = ["balanced", "cuda"]

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -988,12 +989,15 @@ def load_module(name, value):
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
# 7.1 device_map shenanigans
if final_device_map is not None and len(final_device_map) > 0:
component_device = final_device_map.get(name, None)
if component_device is not None:
current_device_map = {"": component_device}
else:
current_device_map = None
if final_device_map is not None:
if isinstance(final_device_map, dict) and len(final_device_map) > 0:
component_device = final_device_map.get(name, None)
if component_device is not None:
current_device_map = {"": component_device}
else:
current_device_map = None
elif isinstance(final_device_map, str):
current_device_map = final_device_map

# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
Expand Down