Skip to content

Commit a8e14e9

Browse files
authored
Merge pull request #1962 from Capsize-Games/develop
Improve Z-Image and add Pixtral support
2 parents 0a7f43c + 87a77da commit a8e14e9

33 files changed

+2813
-1013
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@
187187

188188
setup(
189189
name="airunner",
190-
version="5.4.2",
190+
version="5.5.0",
191191
author="Capsize LLC",
192192
description="Run local opensource AI models (Stable Diffusion, LLMs, TTS, STT, chatbots) in a lightweight Python GUI",
193193
long_description=open("README.md", "r", encoding="utf-8").read(),

src/airunner/components/art/gui/widgets/lora/lora_container_widget.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class LoraContainerWidget(BaseWidget):
2323

2424
def __init__(self, *args, **kwargs):
2525
self.signal_handlers = {
26-
SignalCode.APPLICATION_SETTINGS_CHANGED_SIGNAL: self.on_application_settings_changed_signal,
26+
# SignalCode.APPLICATION_SETTINGS_CHANGED_SIGNAL: self.on_application_settings_changed_signal,
2727
SignalCode.LORA_UPDATED_SIGNAL: self.on_lora_updated_signal,
2828
SignalCode.MODEL_STATUS_CHANGED_SIGNAL: self.on_model_status_changed_signal,
2929
SignalCode.LORA_STATUS_CHANGED: self.on_lora_modified,
@@ -115,7 +115,7 @@ def scan_for_lora(self):
115115
self._load_lora(force_reload=force_reload)
116116

117117
@Slot()
118-
def apply_lora(self):
118+
def on_apply_lora_button_clicked(self):
119119
self._apply_button_enabled = False
120120
self.api.art.lora.update()
121121

@@ -197,6 +197,7 @@ def _load_lora(self, force_reload=False):
197197
if self.search_filter.lower() in lora.name.lower()
198198
]
199199
for lora in filtered_loras:
200+
print("adding lora widget for", lora.name)
200201
self._add_lora(lora)
201202
self.add_spacer()
202203

src/airunner/components/art/gui/widgets/lora/templates/lora_container.ui

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,8 @@
280280
<slot>toggle_all(bool)</slot>
281281
<hints>
282282
<hint type="sourcelabel">
283-
<x>562</x>
284-
<y>91</y>
283+
<x>561</x>
284+
<y>95</y>
285285
</hint>
286286
<hint type="destinationlabel">
287287
<x>52</x>
@@ -312,31 +312,15 @@
312312
<slot>scan_for_lora()</slot>
313313
<hints>
314314
<hint type="sourcelabel">
315-
<x>399</x>
316-
<y>830</y>
315+
<x>409</x>
316+
<y>820</y>
317317
</hint>
318318
<hint type="destinationlabel">
319319
<x>486</x>
320320
<y>-1</y>
321321
</hint>
322322
</hints>
323323
</connection>
324-
<connection>
325-
<sender>apply_lora_button</sender>
326-
<signal>clicked()</signal>
327-
<receiver>lora_container</receiver>
328-
<slot>apply_lora()</slot>
329-
<hints>
330-
<hint type="sourcelabel">
331-
<x>532</x>
332-
<y>17</y>
333-
</hint>
334-
<hint type="destinationlabel">
335-
<x>474</x>
336-
<y>-7</y>
337-
</hint>
338-
</hints>
339-
</connection>
340324
</connections>
341325
<slots>
342326
<slot>toggle_all(bool)</slot>

src/airunner/components/art/gui/widgets/lora/templates/lora_container_ui.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ def setupUi(self, lora_container):
148148
self.toggleAllLora.toggled.connect(lora_container.toggle_all)
149149
self.lineEdit.textEdited.connect(lora_container.search_text_changed)
150150
self.pushButton.clicked.connect(lora_container.scan_for_lora)
151-
self.apply_lora_button.clicked.connect(lora_container.apply_lora)
152151

153152
QMetaObject.connectSlotsByName(lora_container)
154153
# setupUi

src/airunner/components/art/managers/stablediffusion/mixins/sd_model_loading_mixin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def _load_lora_weights(self, lora: Lora):
251251
lora_base_path,
252252
weight_name=filename,
253253
adapter_name=adapter_name,
254+
scale=getattr(lora, "scale", 1.0),
254255
)
255256
self._loaded_lora[lora.path] = lora
256257
except AttributeError:

src/airunner/components/art/managers/stablediffusion/model_loader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,13 @@ def load_lora_weights(
127127
) -> bool:
128128
filename = os.path.basename(lora.path)
129129
adapter_name = os.path.splitext(filename)[0].replace(".", "_")
130+
# Scale is stored as 0-100 integer, convert to 0.0-1.0 float
131+
scale = lora.scale / 100.0 if hasattr(lora, 'scale') else 1.0
130132
try:
131133
pipe.load_lora_weights(
132-
lora_base_path, weight_name=filename, adapter_name=adapter_name
134+
lora_base_path, weight_name=filename, adapter_name=adapter_name, scale=scale
133135
)
134-
logger.info(f"Loaded LORA weights: {filename}")
136+
logger.info(f"Loaded LORA weights: {filename} (scale={scale:.2f})")
135137
return True
136138
except Exception as e:
137139
logger.warning(f"Failed to load LORA {filename}: {e}")

src/airunner/components/art/managers/zimage/mixins/zimage_generation_mixin.py

Lines changed: 69 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
from airunner.components.application.exceptions import InterruptedException
77
from airunner.utils.memory import clear_memory
88
from airunner.utils.settings.get_qsettings import get_qsettings
9+
from airunner.components.art.schedulers.flow_match_scheduler_factory import (
10+
is_flow_match_scheduler,
11+
create_flow_match_scheduler,
12+
)
13+
from airunner.enums import Scheduler, ModelType, ModelStatus
14+
15+
from accelerate.hooks import remove_hook_from_module
916

1017

1118
def _aggressive_memory_cleanup():
@@ -57,7 +64,6 @@ def _prepare_data(self, active_rect=None) -> Dict:
5764
self._strip_zimage_incompatible_params(data)
5865
self._enforce_zimage_guidance(data)
5966
data["max_sequence_length"] = 512
60-
self._log_zimage_generation_params(data)
6167
return data
6268

6369
def _strip_zimage_incompatible_params(self, data: Dict) -> None:
@@ -81,34 +87,40 @@ def _enforce_zimage_guidance(self, data: Dict) -> None:
8187
"""
8288
pass
8389

84-
def _log_zimage_generation_params(self, data: Dict) -> None:
85-
"""Log core generation parameters for debugging."""
86-
debug_fields = {
87-
"prompt": data.get("prompt", "MISSING!")[:50] + "...",
88-
"guidance_scale": data.get("guidance_scale", "MISSING!"),
89-
"steps": data.get("num_inference_steps", "MISSING!"),
90-
"size": f"{data.get('width')}x{data.get('height')}",
91-
"max_sequence_length": data.get("max_sequence_length", "MISSING!"),
92-
}
93-
self.logger.info(
94-
"[Z-IMAGE DEBUG] Keys: %s | Values: %s",
95-
list(data.keys()),
96-
debug_fields,
97-
)
98-
9990
def _unload_loras(self):
10091
"""Unload Z-Image LoRA weights if any are loaded.
10192
102-
Z-Image supports LoRA weights through ZImageLoraLoaderMixin.
93+
Z-Image uses additive LoRA that can be removed without model reload.
10394
"""
104-
if hasattr(self._pipe, 'unload_lora_weights'):
95+
self.logger.debug("Unloading Z-Image LoRA weights")
96+
if self._pipe is not None and hasattr(self._pipe, 'unload_lora_weights'):
10597
try:
10698
self._pipe.unload_lora_weights()
107-
self.logger.debug("Unloaded Z-Image LoRA weights")
99+
self.logger.info("✓ Unloaded all Z-Image LoRA weights")
108100
except Exception as e:
109-
self.logger.debug(f"No LoRA weights to unload: {e}")
101+
self.logger.warning(f"Error unloading LoRA weights: {e}")
110102
self._loaded_lora = {}
111103
self._disabled_lora = []
104+
105+
def _disable_lora(self, adapter_name: str):
106+
"""Disable a specific LoRA adapter without removing it.
107+
108+
Args:
109+
adapter_name: Name of the adapter to disable
110+
"""
111+
if self._pipe is not None and hasattr(self._pipe, 'set_lora_enabled'):
112+
self._pipe.set_lora_enabled(adapter_name, False)
113+
self.logger.debug(f"Disabled LoRA: {adapter_name}")
114+
115+
def _enable_lora(self, adapter_name: str):
116+
"""Enable a specific LoRA adapter.
117+
118+
Args:
119+
adapter_name: Name of the adapter to enable
120+
"""
121+
if self._pipe is not None and hasattr(self._pipe, 'set_lora_enabled'):
122+
self._pipe.set_lora_enabled(adapter_name, True)
123+
self.logger.debug(f"Enabled LoRA: {adapter_name}")
112124

113125
def _load_scheduler(self, scheduler_name=None):
114126
"""Load a flow-match scheduler for Z-Image.
@@ -118,11 +130,7 @@ def _load_scheduler(self, scheduler_name=None):
118130
Args:
119131
scheduler_name: Display name of the scheduler to load.
120132
"""
121-
from airunner.components.art.schedulers.flow_match_scheduler_factory import (
122-
is_flow_match_scheduler,
123-
create_flow_match_scheduler,
124-
)
125-
from airunner.enums import Scheduler, ModelType, ModelStatus
133+
# imports moved to module level for performance and clarity
126134

127135
# Get scheduler name
128136
requested_name = (
@@ -232,11 +240,8 @@ def _unload_pipe(self):
232240
return
233241

234242
# Import accelerate hooks removal if available
235-
try:
236-
from accelerate.hooks import remove_hook_from_module
237-
has_accelerate_hooks = True
238-
except ImportError:
239-
has_accelerate_hooks = False
243+
has_accelerate_hooks = remove_hook_from_module is not None
244+
if not has_accelerate_hooks:
240245
self.logger.debug("accelerate.hooks not available, using manual cleanup")
241246

242247
# List of all Z-Image components to clean up (ordered by size)
@@ -259,78 +264,12 @@ def _unload_pipe(self):
259264
self.logger.debug(f"Error removing hook: {e}")
260265
self._pipe._all_hooks.clear()
261266

262-
# Process each component
267+
# Process each component via helper
263268
for component_name in component_names:
264269
component = getattr(self._pipe, component_name, None)
265270
if component is None:
266271
continue
267-
268-
self.logger.debug(f"Cleaning up {component_name}...")
269-
270-
# 1. Remove accelerate hooks using official API
271-
if has_accelerate_hooks:
272-
try:
273-
remove_hook_from_module(component, recurse=True)
274-
self.logger.debug(f"Removed hooks from {component_name}")
275-
except Exception as e:
276-
self.logger.debug(f"Hook removal for {component_name}: {e}")
277-
278-
# 2. Manual hook cleanup for any remaining hooks
279-
if hasattr(component, "_hf_hook"):
280-
try:
281-
hook = component._hf_hook
282-
# Clear any offloaded weights first
283-
if hasattr(hook, "weights_map") and hook.weights_map is not None:
284-
hook.weights_map.clear()
285-
if hasattr(hook, "offload"):
286-
try:
287-
hook.offload(component)
288-
except Exception:
289-
pass
290-
delattr(component, "_hf_hook")
291-
except Exception as e:
292-
self.logger.debug(f"Manual hook cleanup for {component_name}: {e}")
293-
294-
# 3. For models with device_map (quantized models), clear the device_map state
295-
if hasattr(component, "hf_device_map"):
296-
try:
297-
component.hf_device_map = None
298-
except Exception:
299-
pass
300-
301-
# 4. CRITICAL: Do NOT move to CPU - this creates new tensors in CPU RAM
302-
# Instead, delete tensors in-place on their current device
303-
304-
# 5. Clear parameter data to free memory
305-
if hasattr(component, "parameters"):
306-
try:
307-
for param in component.parameters():
308-
param.data = torch.empty(0, device=param.device)
309-
if param.grad is not None:
310-
param.grad = None
311-
except Exception as e:
312-
self.logger.debug(f"Error clearing params for {component_name}: {e}")
313-
314-
# 6. Detach the component from the pipeline
315-
try:
316-
setattr(self._pipe, component_name, None)
317-
except Exception:
318-
pass
319-
320-
# 7. Delete component reference
321-
try:
322-
del component
323-
except Exception:
324-
pass
325-
326-
# 8. Run gc after each large component to free memory immediately
327-
if component_name in ("text_encoder", "transformer"):
328-
_aggressive_memory_cleanup()
329-
330-
# Clear execution device reference
331-
if hasattr(self._pipe, "_execution_device"):
332-
self._pipe._execution_device = None
333-
272+
self._cleanup_pipeline_component(component_name, component, has_accelerate_hooks)
334273
except Exception as e:
335274
self.logger.warning(f"Error during hook removal: {e}")
336275

@@ -347,20 +286,46 @@ def _unload_pipe(self):
347286

348287
self.logger.info("✓ Z-Image pipeline unloaded and memory freed")
349288

350-
def _clear_pipeline_caches(self):
351-
"""Clear internal pipeline caches to free RAM and VRAM.
289+
def _cleanup_pipeline_component(self, component_name: str, component: Any, has_accelerate_hooks: bool) -> None:
290+
"""Cleanup a component attached to the pipeline.
291+
292+
Extraction from the prior implementation to reduce _unload_pipe size.
293+
"""
294+
self.logger.debug(f"Cleaning up {component_name}...")
295+
296+
# 1. Remove accelerate hooks using official API
297+
if has_accelerate_hooks:
298+
try:
299+
remove_hook_from_module(component, recurse=True)
300+
self.logger.debug(f"Removed hooks from {component_name}")
301+
except Exception as e:
302+
self.logger.debug(f"Hook removal for {component_name}: {e}")
303+
304+
# 2. Manual hook cleanup for any remaining hooks
305+
if hasattr(component, "_hf_hook"):
306+
try:
307+
hook = component._hf_hook
308+
if hasattr(hook, "weights_map") and hook.weights_map is not None:
309+
hook.weights_map.clear()
310+
if hasattr(hook, "offload"):
311+
try:
312+
hook.offload(component)
313+
except Exception:
314+
pass
315+
delattr(component, "_hf_hook")
316+
except Exception as e:
317+
self.logger.debug(f"Manual hook cleanup for {component_name}: {e}")
352318

319+
def _clear_pipeline_caches(self):
320+
"""Clear cached tensors and per-component caches on the active pipeline.
321+
353322
This is called after each generation to prevent memory accumulation.
354323
"""
355324
if self._pipe is None:
356325
return
357326

358327
self.logger.debug("Clearing pipeline caches to free RAM")
359328

360-
# Clear any cached tensors on the pipeline
361-
if hasattr(self._pipe, "_callback_tensor_inputs"):
362-
self._pipe._callback_tensor_inputs = None
363-
364329
# For text encoder, clear any cached key/values
365330
text_encoder = getattr(self._pipe, "text_encoder", None)
366331
if text_encoder is not None and hasattr(text_encoder, "past_key_values"):

0 commit comments

Comments
 (0)