Skip to content

Commit 638cc03

Browse files
authored
[Modular] update the collection behavior (#11963)
* only remove from the collection
1 parent 9db9be6 commit 638cc03

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

src/diffusers/modular_pipelines/components_manager.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def add(self, name: str, component: Any, collection: Optional[str] = None):
386386
id(component) is Python's built-in unique identifier for the object
387387
"""
388388
component_id = f"{name}_{id(component)}"
389+
is_new_component = True
389390

390391
# check for duplicated components
391392
for comp_id, comp in self.components.items():
@@ -394,6 +395,7 @@ def add(self, name: str, component: Any, collection: Optional[str] = None):
394395
if comp_name == name:
395396
logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
396397
component_id = comp_id
398+
is_new_component = False
397399
break
398400
else:
399401
logger.warning(
@@ -426,19 +428,39 @@ def add(self, name: str, component: Any, collection: Optional[str] = None):
426428
logger.warning(
427429
f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
428430
)
429-
self.remove(comp_id)
431+
# remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager)
432+
self.remove_from_collection(comp_id, collection)
433+
430434
self.collections[collection].add(component_id)
431435
logger.info(
432436
f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
433437
)
434438
else:
435439
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
436440

437-
if self._auto_offload_enabled:
441+
if self._auto_offload_enabled and is_new_component:
438442
self.enable_auto_cpu_offload(self._auto_offload_device)
439443

440444
return component_id
441445

446+
def remove_from_collection(self, component_id: str, collection: str):
447+
"""
448+
Remove a component from a collection.
449+
"""
450+
if collection not in self.collections:
451+
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
452+
return
453+
if component_id not in self.collections[collection]:
454+
logger.warning(f"Component '{component_id}' not found in collection '{collection}'")
455+
return
456+
# remove from the collection
457+
self.collections[collection].remove(component_id)
458+
# check if this component is in any other collection
459+
comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps]
460+
if not comp_colls: # only if no other collection contains this component, remove it
461+
logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager")
462+
self.remove(component_id)
463+
442464
def remove(self, component_id: str = None):
443465
"""
444466
Remove a component from the ComponentsManager.

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ def load_id(self) -> str:
185185
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
186186
segments).
187187
"""
188+
if self.default_creation_method == "from_config":
189+
return "null"
188190
parts = [getattr(self, k) for k in self.loading_fields()]
189191
parts = ["null" if p is None else p for p in parts]
190192
return "|".join(p for p in parts if p)

0 commit comments

Comments
 (0)