66from airunner .components .application .exceptions import InterruptedException
77from airunner .utils .memory import clear_memory
88from airunner .utils .settings .get_qsettings import get_qsettings
9+ from airunner .components .art .schedulers .flow_match_scheduler_factory import (
10+ is_flow_match_scheduler ,
11+ create_flow_match_scheduler ,
12+ )
13+ from airunner .enums import Scheduler , ModelType , ModelStatus
14+
15+ from accelerate .hooks import remove_hook_from_module
916
1017
1118def _aggressive_memory_cleanup ():
@@ -57,7 +64,6 @@ def _prepare_data(self, active_rect=None) -> Dict:
5764 self ._strip_zimage_incompatible_params (data )
5865 self ._enforce_zimage_guidance (data )
5966 data ["max_sequence_length" ] = 512
60- self ._log_zimage_generation_params (data )
6167 return data
6268
6369 def _strip_zimage_incompatible_params (self , data : Dict ) -> None :
@@ -81,34 +87,40 @@ def _enforce_zimage_guidance(self, data: Dict) -> None:
8187 """
8288 pass
8389
84- def _log_zimage_generation_params (self , data : Dict ) -> None :
85- """Log core generation parameters for debugging."""
86- debug_fields = {
87- "prompt" : data .get ("prompt" , "MISSING!" )[:50 ] + "..." ,
88- "guidance_scale" : data .get ("guidance_scale" , "MISSING!" ),
89- "steps" : data .get ("num_inference_steps" , "MISSING!" ),
90- "size" : f"{ data .get ('width' )} x{ data .get ('height' )} " ,
91- "max_sequence_length" : data .get ("max_sequence_length" , "MISSING!" ),
92- }
93- self .logger .info (
94- "[Z-IMAGE DEBUG] Keys: %s | Values: %s" ,
95- list (data .keys ()),
96- debug_fields ,
97- )
98-
9990 def _unload_loras (self ):
10091 """Unload Z-Image LoRA weights if any are loaded.
10192
102- Z-Image supports LoRA weights through ZImageLoraLoaderMixin .
93+ Z-Image uses additive LoRA that can be removed without model reload .
10394 """
104- if hasattr (self ._pipe , 'unload_lora_weights' ):
95+ self .logger .debug ("Unloading Z-Image LoRA weights" )
96+ if self ._pipe is not None and hasattr (self ._pipe , 'unload_lora_weights' ):
10597 try :
10698 self ._pipe .unload_lora_weights ()
107- self .logger .debug ( " Unloaded Z-Image LoRA weights" )
99+ self .logger .info ( "✓ Unloaded all Z-Image LoRA weights" )
108100 except Exception as e :
109- self .logger .debug (f"No LoRA weights to unload : { e } " )
101+ self .logger .warning (f"Error unloading LoRA weights: { e } " )
110102 self ._loaded_lora = {}
111103 self ._disabled_lora = []
104+
105+ def _disable_lora (self , adapter_name : str ):
106+ """Disable a specific LoRA adapter without removing it.
107+
108+ Args:
109+ adapter_name: Name of the adapter to disable
110+ """
111+ if self ._pipe is not None and hasattr (self ._pipe , 'set_lora_enabled' ):
112+ self ._pipe .set_lora_enabled (adapter_name , False )
113+ self .logger .debug (f"Disabled LoRA: { adapter_name } " )
114+
115+ def _enable_lora (self , adapter_name : str ):
116+ """Enable a specific LoRA adapter.
117+
118+ Args:
119+ adapter_name: Name of the adapter to enable
120+ """
121+ if self ._pipe is not None and hasattr (self ._pipe , 'set_lora_enabled' ):
122+ self ._pipe .set_lora_enabled (adapter_name , True )
123+ self .logger .debug (f"Enabled LoRA: { adapter_name } " )
112124
113125 def _load_scheduler (self , scheduler_name = None ):
114126 """Load a flow-match scheduler for Z-Image.
@@ -118,11 +130,7 @@ def _load_scheduler(self, scheduler_name=None):
118130 Args:
119131 scheduler_name: Display name of the scheduler to load.
120132 """
121- from airunner .components .art .schedulers .flow_match_scheduler_factory import (
122- is_flow_match_scheduler ,
123- create_flow_match_scheduler ,
124- )
125- from airunner .enums import Scheduler , ModelType , ModelStatus
133+ # imports moved to module level for performance and clarity
126134
127135 # Get scheduler name
128136 requested_name = (
@@ -232,11 +240,8 @@ def _unload_pipe(self):
232240 return
233241
234242 # Import accelerate hooks removal if available
235- try :
236- from accelerate .hooks import remove_hook_from_module
237- has_accelerate_hooks = True
238- except ImportError :
239- has_accelerate_hooks = False
243+ has_accelerate_hooks = remove_hook_from_module is not None
244+ if not has_accelerate_hooks :
240245 self .logger .debug ("accelerate.hooks not available, using manual cleanup" )
241246
242247 # List of all Z-Image components to clean up (ordered by size)
@@ -259,78 +264,12 @@ def _unload_pipe(self):
259264 self .logger .debug (f"Error removing hook: { e } " )
260265 self ._pipe ._all_hooks .clear ()
261266
262- # Process each component
267+ # Process each component via helper
263268 for component_name in component_names :
264269 component = getattr (self ._pipe , component_name , None )
265270 if component is None :
266271 continue
267-
268- self .logger .debug (f"Cleaning up { component_name } ..." )
269-
270- # 1. Remove accelerate hooks using official API
271- if has_accelerate_hooks :
272- try :
273- remove_hook_from_module (component , recurse = True )
274- self .logger .debug (f"Removed hooks from { component_name } " )
275- except Exception as e :
276- self .logger .debug (f"Hook removal for { component_name } : { e } " )
277-
278- # 2. Manual hook cleanup for any remaining hooks
279- if hasattr (component , "_hf_hook" ):
280- try :
281- hook = component ._hf_hook
282- # Clear any offloaded weights first
283- if hasattr (hook , "weights_map" ) and hook .weights_map is not None :
284- hook .weights_map .clear ()
285- if hasattr (hook , "offload" ):
286- try :
287- hook .offload (component )
288- except Exception :
289- pass
290- delattr (component , "_hf_hook" )
291- except Exception as e :
292- self .logger .debug (f"Manual hook cleanup for { component_name } : { e } " )
293-
294- # 3. For models with device_map (quantized models), clear the device_map state
295- if hasattr (component , "hf_device_map" ):
296- try :
297- component .hf_device_map = None
298- except Exception :
299- pass
300-
301- # 4. CRITICAL: Do NOT move to CPU - this creates new tensors in CPU RAM
302- # Instead, delete tensors in-place on their current device
303-
304- # 5. Clear parameter data to free memory
305- if hasattr (component , "parameters" ):
306- try :
307- for param in component .parameters ():
308- param .data = torch .empty (0 , device = param .device )
309- if param .grad is not None :
310- param .grad = None
311- except Exception as e :
312- self .logger .debug (f"Error clearing params for { component_name } : { e } " )
313-
314- # 6. Detach the component from the pipeline
315- try :
316- setattr (self ._pipe , component_name , None )
317- except Exception :
318- pass
319-
320- # 7. Delete component reference
321- try :
322- del component
323- except Exception :
324- pass
325-
326- # 8. Run gc after each large component to free memory immediately
327- if component_name in ("text_encoder" , "transformer" ):
328- _aggressive_memory_cleanup ()
329-
330- # Clear execution device reference
331- if hasattr (self ._pipe , "_execution_device" ):
332- self ._pipe ._execution_device = None
333-
272+ self ._cleanup_pipeline_component (component_name , component , has_accelerate_hooks )
334273 except Exception as e :
335274 self .logger .warning (f"Error during hook removal: { e } " )
336275
@@ -347,20 +286,46 @@ def _unload_pipe(self):
347286
348287 self .logger .info ("✓ Z-Image pipeline unloaded and memory freed" )
349288
350- def _clear_pipeline_caches (self ):
351- """Clear internal pipeline caches to free RAM and VRAM.
289+ def _cleanup_pipeline_component (self , component_name : str , component : Any , has_accelerate_hooks : bool ) -> None :
290+ """Cleanup a component attached to the pipeline.
291+
292+ Extraction from the prior implementation to reduce _unload_pipe size.
293+ """
294+ self .logger .debug (f"Cleaning up { component_name } ..." )
295+
296+ # 1. Remove accelerate hooks using official API
297+ if has_accelerate_hooks :
298+ try :
299+ remove_hook_from_module (component , recurse = True )
300+ self .logger .debug (f"Removed hooks from { component_name } " )
301+ except Exception as e :
302+ self .logger .debug (f"Hook removal for { component_name } : { e } " )
303+
304+ # 2. Manual hook cleanup for any remaining hooks
305+ if hasattr (component , "_hf_hook" ):
306+ try :
307+ hook = component ._hf_hook
308+ if hasattr (hook , "weights_map" ) and hook .weights_map is not None :
309+ hook .weights_map .clear ()
310+ if hasattr (hook , "offload" ):
311+ try :
312+ hook .offload (component )
313+ except Exception :
314+ pass
315+ delattr (component , "_hf_hook" )
316+ except Exception as e :
317+ self .logger .debug (f"Manual hook cleanup for { component_name } : { e } " )
352318
319+ def _clear_pipeline_caches (self ):
320+ """Clear cached tensors and per-component caches on the active pipeline.
321+
353322 This is called after each generation to prevent memory accumulation.
354323 """
355324 if self ._pipe is None :
356325 return
357326
358327 self .logger .debug ("Clearing pipeline caches to free RAM" )
359328
360- # Clear any cached tensors on the pipeline
361- if hasattr (self ._pipe , "_callback_tensor_inputs" ):
362- self ._pipe ._callback_tensor_inputs = None
363-
364329 # For text encoder, clear any cached key/values
365330 text_encoder = getattr (self ._pipe , "text_encoder" , None )
366331 if text_encoder is not None and hasattr (text_encoder , "past_key_values" ):
0 commit comments