Skip to content

Commit 13d5b08

Browse files
committed
Show warning message when a LoRA doesn't match the base model
1 parent 38f20d4 commit 13d5b08

File tree

4 files changed

+39
-6
lines changed

4 files changed

+39
-6
lines changed

ai_diffusion/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Generative AI plugin for Krita"""
22

3-
__version__ = "1.32.0"
3+
__version__ = "1.33.0"
44

55
import importlib.util
66

ai_diffusion/cloud_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,11 @@ async def _process_job(self, job: JobInfo):
189189
images = await self.receive_images(output["images"])
190190
pose = output.get("pose", None)
191191
log.info(f"{job} completed, got {len(images)} images{', got pose' if pose else ''}")
192-
yield ClientMessage(ClientEvent.finished, job.local_id, 1, images, pose)
192+
lora_warning = output.get("lora_warning", False)
193+
if lora_warning:
194+
log.warning(f"{job} encountered LoRA that could not be applied to the checkpoint")
195+
error = "incompatible_lora" if lora_warning else None
196+
yield ClientMessage(ClientEvent.finished, job.local_id, 1, images, pose, error=error)
193197

194198
elif response["status"] == "FAILED":
195199
err_msg, err_trace = _extract_error(response, job.remote_id)

ai_diffusion/model.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,15 @@ class ProgressKind(Enum):
5252

5353
class ErrorKind(Enum):
5454
none = 0
55-
plugin_error = 1
56-
server_error = 2
57-
insufficient_funds = 3
55+
plugin_error = 100
56+
server_error = 200
57+
insufficient_funds = 201
58+
warning = 300
59+
incompatible_lora = 301
60+
61+
@property
62+
def is_warning(self):
63+
return self.value >= ErrorKind.warning.value
5864

5965

6066
class Error(NamedTuple):
@@ -65,6 +71,11 @@ class Error(NamedTuple):
6571
def __bool__(self):
6672
return self.kind is not ErrorKind.none
6773

74+
@staticmethod
75+
def from_string(s: str, fallback: ErrorKind | None = None):
76+
kind = ErrorKind[s] if s in ErrorKind.__members__ else fallback or ErrorKind.warning
77+
return Error(kind, s)
78+
6879

6980
no_error = Error(ErrorKind.none, "")
7081

@@ -463,7 +474,7 @@ def cancel(self, active=False, queued=False):
463474

464475
def report_error(self, error: Error | str):
465476
if isinstance(error, str):
466-
error = Error(ErrorKind.server_error, error)
477+
error = Error.from_string(error, ErrorKind.server_error)
467478
self.error = error
468479
self.live.is_active = False
469480
self.custom.is_live = False
@@ -493,6 +504,8 @@ def handle_message(self, message: ClientMessage):
493504
elif message.event is ClientEvent.output:
494505
self.custom.show_output(message.result)
495506
elif message.event is ClientEvent.finished:
507+
if message.error: # successful jobs may have encountered some warnings
508+
self.report_error(Error.from_string(message.error, ErrorKind.warning))
496509
if message.images:
497510
self.jobs.set_results(job, message.images)
498511
if job.kind is JobKind.control_layer:

ai_diffusion/ui/widget.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,8 @@ def error(self, error: Error):
878878
self._original_error = error.message if error else ""
879879
if error.kind is ErrorKind.insufficient_funds:
880880
self._show_payment_error(error.data)
881+
elif error.kind.is_warning:
882+
self._show_warning(error.kind, error.message)
881883
elif error:
882884
self._show_error(error.message)
883885

@@ -897,6 +899,20 @@ def _show_error(self, text: str):
897899
self._copy_button.setVisible(True)
898900
self.setVisible(True)
899901

902+
def _show_warning(self, kind: ErrorKind, text: str):
903+
self.reset(theme.yellow)
904+
if kind is ErrorKind.incompatible_lora:
905+
text = (
906+
_(
907+
"Selected LoRA model could not be applied. Please make sure it is compatible with the checkpoint base model you are using."
908+
)
909+
+ " <a href='https://docs.interstice.cloud/base-models'>"
910+
+ _("Lean more")
911+
+ "</a>"
912+
)
913+
self._label.setText(text)
914+
self.setVisible(True)
915+
900916
def _show_payment_error(self, data: dict[str, Any] | None):
901917
self.reset(theme.yellow)
902918
message = "Insufficient funds"

0 commit comments

Comments
 (0)