Skip to content

Commit 0a7f43c

Browse files
authored
Merge pull request #1961 from Capsize-Games/develop
Develop
2 parents ff9c488 + c8f6667 commit 0a7f43c

File tree

10 files changed

+674
-182
lines changed

10 files changed

+674
-182
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@
187187

188188
setup(
189189
name="airunner",
190-
version="5.4.1",
190+
version="5.4.2",
191191
author="Capsize LLC",
192192
description="Run local opensource AI models (Stable Diffusion, LLMs, TTS, STT, chatbots) in a lightweight Python GUI",
193193
long_description=open("README.md", "r", encoding="utf-8").read(),

src/airunner/components/art/managers/stablediffusion/base_diffusers_model_manager.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,13 @@ def load(self):
288288
- DeepCache helper
289289
- Memory optimizations
290290
"""
291+
self.logger.debug(
292+
f"[LOAD ENTRY] sd_is_loading={self.sd_is_loading}, "
293+
f"model_is_loaded={self.model_is_loaded}, "
294+
f"model_status={self.model_status}, model_type={self.model_type}"
295+
)
291296
if self.sd_is_loading or self.model_is_loaded:
297+
self.logger.debug("[LOAD ENTRY] Returning early - already loading or loaded")
292298
return
293299
if self.model_path is None or self.model_path == "":
294300
self.logger.error("No model selected")
@@ -327,7 +333,9 @@ def load(self):
327333

328334
self.load_controlnet()
329335

336+
self.logger.debug("[LOAD] About to call _load_pipe()")
330337
if self._load_pipe():
338+
self.logger.debug("[LOAD] _load_pipe() returned True, continuing load sequence")
331339
self._send_pipeline_loaded_signal()
332340
self._move_pipe_to_device()
333341
self._load_scheduler()
@@ -507,8 +515,20 @@ def _load_pipe(self) -> bool:
507515
Returns:
508516
True if loaded successfully, False otherwise
509517
"""
518+
self.logger.debug("[_load_pipe] ENTERING METHOD")
519+
try:
520+
pipeline_class = self._pipeline_class
521+
self.logger.debug(f"[_load_pipe] pipeline_class={pipeline_class}")
522+
section = self.section
523+
self.logger.debug(f"[_load_pipe] section={section}")
524+
except Exception as e:
525+
self.logger.error(f"[_load_pipe] Error accessing properties: {e}")
526+
import traceback
527+
self.logger.error(traceback.format_exc())
528+
return False
529+
510530
self.logger.debug(
511-
f"Loading pipe {self._pipeline_class} for {self.section}"
531+
f"Loading pipe {pipeline_class} for {section}"
512532
)
513533
self.change_model_status(self.model_type, ModelStatus.LOADING)
514534
data = self._prepare_pipe_data()

src/airunner/components/art/managers/zimage/mixins/zimage_pipeline_loading_mixin.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,14 @@ def _set_pipe(self, config_path: str, data: Dict):
9696

9797
# Debug: verify _pipe was set
9898
self.logger.info(f"[ZIMAGE DEBUG] After _set_pipe: self._pipe={self._pipe}, self={id(self)}")
99+
100+
# Load LoRA adapters if available for this pipeline
101+
try:
102+
if hasattr(self, "_load_lora") and self._pipe is not None:
103+
self.logger.info("[ZIMAGE] Loading LoRA adapters")
104+
self._load_lora()
105+
except Exception as exc: # pragma: no cover - defensive logging
106+
self.logger.warning(f"[ZIMAGE] Failed to load LoRA adapters: {exc}")
99107

100108
_clear_gpu_memory()
101109

src/airunner/components/art/managers/zimage/native/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@
5050
from airunner.components.art.managers.zimage.native.zimage_native_wrapper import (
5151
NativePipelineWrapper,
5252
)
53+
from airunner.components.art.managers.zimage.native.native_lora import (
54+
NativeLoraLoader,
55+
load_lora_into_transformer,
56+
load_lora_state_dict,
57+
)
5358

5459
__all__ = [
5560
# FP8 Operations
@@ -85,4 +90,9 @@
8590
"SimpleTextEncoder",
8691
# Pipeline
8792
"ZImageNativePipeline",
93+
"NativePipelineWrapper",
94+
# LoRA
95+
"NativeLoraLoader",
96+
"load_lora_into_transformer",
97+
"load_lora_state_dict",
8898
]

src/airunner/components/art/managers/zimage/native/fp8_ops.py

Lines changed: 39 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,10 @@ class FP8Linear(nn.Module):
618618
619619
This layer can hold FP8 weights with per-layer scale factors and
620620
automatically dequantizes during forward pass for compatibility.
621+
622+
The FP8 weight is stored as `fp8_weight_storage` to avoid nn.Module's
623+
special handling of 'weight'. Access via `.weight` property is provided
624+
for compatibility with LoRA loaders.
621625
"""
622626

623627
def __init__(
@@ -627,6 +631,7 @@ def __init__(
627631
bias: bool = True,
628632
device: Optional[torch.device] = None,
629633
dtype: Optional[torch.dtype] = None,
634+
compute_dtype: Optional[torch.dtype] = None,
630635
):
631636
"""
632637
Initialize FP8 linear layer.
@@ -636,24 +641,49 @@ def __init__(
636641
out_features: Output feature dimension
637642
bias: Whether to include bias
638643
device: Target device
639-
dtype: Compute dtype (for activations)
644+
dtype: Compute dtype (for activations) - deprecated, use compute_dtype
645+
compute_dtype: Compute dtype (for activations)
640646
"""
641647
super().__init__()
642648

643649
self.in_features = in_features
644650
self.out_features = out_features
645-
self.compute_dtype = dtype or torch.bfloat16
651+
self.compute_dtype = compute_dtype or dtype or torch.bfloat16
646652

647-
# Initialize as empty - will be filled during checkpoint loading
648-
self.weight: Optional[QuantizedTensor] = None
649-
self.bias: Optional[torch.Tensor] = None
653+
# Use unique name to avoid nn.Module intercepting 'weight'
654+
self.fp8_weight_storage: Optional[QuantizedTensor] = None
655+
self._has_bias = bias
650656

651657
if bias:
652658
self.register_buffer(
653659
'_bias',
654660
torch.zeros(out_features, device=device, dtype=self.compute_dtype)
655661
)
656662

663+
@property
664+
def weight(self) -> Optional[QuantizedTensor]:
665+
"""Get the FP8 quantized weight (for LoRA compatibility)."""
666+
return self.fp8_weight_storage
667+
668+
@weight.setter
669+
def weight(self, value: Optional[QuantizedTensor]):
670+
"""Set the FP8 quantized weight."""
671+
self.fp8_weight_storage = value
672+
673+
@property
674+
def bias(self) -> Optional[torch.Tensor]:
675+
"""Get bias tensor."""
676+
return self._bias if self._has_bias and hasattr(self, '_bias') else None
677+
678+
@bias.setter
679+
def bias(self, value: Optional[torch.Tensor]):
680+
"""Set bias tensor."""
681+
if value is not None and self._has_bias:
682+
if hasattr(self, '_bias'):
683+
self._bias.copy_(value)
684+
else:
685+
self.register_buffer('_bias', value)
686+
657687
def set_fp8_weight(
658688
self,
659689
fp8_weight: torch.Tensor,
@@ -668,7 +698,7 @@ def set_fp8_weight(
668698
scale: Per-layer scale factor
669699
orig_dtype: Original dtype for dequantization
670700
"""
671-
self.weight = QuantizedTensor.from_fp8_with_scale(
701+
self.fp8_weight_storage = QuantizedTensor.from_fp8_with_scale(
672702
fp8_weight, scale, orig_dtype
673703
)
674704

@@ -682,13 +712,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
682712
Returns:
683713
Output tensor
684714
"""
685-
if self.weight is None:
715+
if self.fp8_weight_storage is None:
686716
raise RuntimeError("Weight not set. Call set_fp8_weight first.")
687717

688718
# Dequantize weight for forward pass
689-
weight = self.weight.dequantize().to(x.dtype)
719+
weight = self.fp8_weight_storage.dequantize().to(x.dtype)
690720

691-
bias = self.bias if hasattr(self, '_bias') and self._bias is not None else None
721+
bias = self._bias if self._has_bias and hasattr(self, '_bias') else None
692722
if bias is not None:
693723
bias = bias.to(x.dtype)
694724

@@ -748,169 +778,3 @@ def is_fp8_scaled_checkpoint(state_dict: Dict[str, torch.Tensor]) -> bool:
748778
return True
749779

750780
return has_fp8 and has_scale
751-
752-
753-
class FP8Linear(nn.Module):
754-
"""
755-
Linear layer that stores weights in FP8 format and dequantizes on forward.
756-
757-
This allows ~50% memory savings compared to BF16 while maintaining
758-
numerical accuracy through per-tensor scaling.
759-
760-
Args:
761-
in_features: Input feature dimension
762-
out_features: Output feature dimension
763-
bias: Whether to include bias
764-
device: Target device
765-
compute_dtype: Dtype for computation (default: bfloat16)
766-
"""
767-
768-
def __init__(
769-
self,
770-
in_features: int,
771-
out_features: int,
772-
bias: bool = True,
773-
device: Optional[torch.device] = None,
774-
compute_dtype: torch.dtype = torch.bfloat16,
775-
):
776-
super().__init__()
777-
self.in_features = in_features
778-
self.out_features = out_features
779-
self.compute_dtype = compute_dtype
780-
781-
# FP8 weight storage
782-
self.register_buffer(
783-
'weight_fp8',
784-
torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn, device=device)
785-
)
786-
# Scale is always a scalar but store as 0-dim tensor
787-
self.register_buffer(
788-
'weight_scale',
789-
torch.tensor(1.0, dtype=torch.float32, device=device)
790-
)
791-
792-
# Optional bias in compute dtype
793-
if bias:
794-
self.bias = nn.Parameter(
795-
torch.zeros(out_features, dtype=compute_dtype, device=device)
796-
)
797-
else:
798-
self.register_parameter('bias', None)
799-
800-
def set_fp8_weight(
801-
self,
802-
weight_fp8: torch.Tensor,
803-
scale: torch.Tensor,
804-
) -> None:
805-
"""
806-
Set the FP8 weight and scale.
807-
808-
Args:
809-
weight_fp8: Weight tensor in FP8 format
810-
scale: Scale factor for dequantization (scalar)
811-
"""
812-
self.weight_fp8.copy_(weight_fp8)
813-
# Handle scalar scale - extract the value if it's a tensor
814-
if scale.numel() == 1:
815-
self.weight_scale.fill_(scale.item())
816-
else:
817-
self.weight_scale.fill_(scale.flatten()[0].item())
818-
819-
def forward(self, x: torch.Tensor) -> torch.Tensor:
820-
"""
821-
Forward pass with on-the-fly dequantization.
822-
823-
Args:
824-
x: Input tensor
825-
826-
Returns:
827-
Output tensor
828-
"""
829-
# Dequantize weight: fp8 * scale -> compute_dtype
830-
weight = self.weight_fp8.to(self.compute_dtype) * self.weight_scale.to(self.compute_dtype)
831-
832-
# Convert input to compute dtype if needed
833-
if x.dtype != self.compute_dtype:
834-
x = x.to(self.compute_dtype)
835-
836-
# Standard linear operation
837-
return F.linear(x, weight, self.bias)
838-
839-
@classmethod
840-
def from_linear(
841-
cls,
842-
linear: nn.Linear,
843-
device: Optional[torch.device] = None,
844-
) -> 'FP8Linear':
845-
"""
846-
Create FP8Linear from an existing nn.Linear layer.
847-
848-
Quantizes weights to FP8 format with computed scale.
849-
850-
Args:
851-
linear: Source linear layer
852-
device: Target device
853-
854-
Returns:
855-
FP8Linear layer with quantized weights
856-
"""
857-
device = device or linear.weight.device
858-
has_bias = linear.bias is not None
859-
compute_dtype = linear.weight.dtype
860-
861-
fp8_linear = cls(
862-
linear.in_features,
863-
linear.out_features,
864-
bias=has_bias,
865-
device=device,
866-
compute_dtype=compute_dtype,
867-
)
868-
869-
# Quantize weight
870-
qdata, params = TensorCoreFP8Layout.quantize(
871-
linear.weight.detach(),
872-
scale="recalculate",
873-
)
874-
fp8_linear.set_fp8_weight(qdata.to(device), params['scale'].to(device))
875-
876-
# Copy bias
877-
if has_bias:
878-
fp8_linear.bias.data.copy_(linear.bias.data)
879-
880-
return fp8_linear
881-
882-
def extra_repr(self) -> str:
883-
return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
884-
885-
886-
def convert_linear_to_fp8(
887-
module: nn.Module,
888-
device: Optional[torch.device] = None,
889-
skip_patterns: Optional[list] = None,
890-
) -> nn.Module:
891-
"""
892-
Recursively convert all nn.Linear layers to FP8Linear.
893-
894-
Args:
895-
module: Module to convert
896-
device: Target device
897-
skip_patterns: List of name patterns to skip
898-
899-
Returns:
900-
Module with converted layers
901-
"""
902-
skip_patterns = skip_patterns or []
903-
904-
for name, child in list(module.named_children()):
905-
# Check skip patterns
906-
should_skip = any(pattern in name for pattern in skip_patterns)
907-
908-
if isinstance(child, nn.Linear) and not should_skip:
909-
# Convert to FP8
910-
fp8_linear = FP8Linear.from_linear(child, device=device)
911-
setattr(module, name, fp8_linear)
912-
else:
913-
# Recurse
914-
convert_linear_to_fp8(child, device=device, skip_patterns=skip_patterns)
915-
916-
return module

0 commit comments

Comments (0)