Skip to content

Commit 81fbf3f

Browse files
committed
Support edit models in upscale workspace #1896
1 parent cac760e commit 81fbf3f

File tree

3 files changed

+63
-43
lines changed

3 files changed

+63
-43
lines changed

ai_diffusion/model.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from __future__ import annotations
22
import asyncio
3+
import time
4+
import weakref
5+
import uuid
36
from copy import copy
47
from collections import deque
5-
from dataclasses import replace
8+
from dataclasses import dataclass, replace
69
from datetime import datetime
710
from pathlib import Path
811
from enum import Enum
912
from tempfile import TemporaryDirectory
10-
import time
1113
from typing import Any, NamedTuple
1214
from PyQt5.QtCore import QObject, QMetaObject, QUuid, pyqtSignal, Qt
1315
from PyQt5.QtGui import QPainter, QColor, QBrush
14-
import uuid
1516

1617
from . import eventloop, workflow, util
1718
from .api import ConditioningInput, ControlInput, WorkflowKind, WorkflowInput, SamplingInput
@@ -312,8 +313,6 @@ async def _enqueue_job(self, job: Job, input: WorkflowInput, front: bool = False
312313
job.id = await client.enqueue(input, front)
313314

314315
def _prepare_upscale_image(self, dryrun=False):
315-
assert not self.arch.is_edit, "Edit models do not support upscaling"
316-
317316
client = self._connection.client
318317
extent = self._doc.extent
319318
image = self._doc.get_image(Bounds(0, 0, *extent)) if not dryrun else DummyImage(extent)
@@ -324,13 +323,17 @@ def _prepare_upscale_image(self, dryrun=False):
324323
self.report_error(Error(ErrorKind.warning, msg + f": {params.upscale.model}"))
325324
self.upscale.upscaler = params.upscale.model = client.models.default_upscaler
326325
bounds = Bounds(0, 0, *self._doc.extent)
326+
sys_prompt = "4k uhd"
327+
if self.arch.is_edit:
328+
sys_prompt = "Enhance image quality. Preserve original content."
329+
327330
if params.use_prompt and not dryrun:
328331
conditioning, job_regions = process_regions(self.regions, bounds, min_coverage=0)
329332
conditioning.language = self.prompt_translation_language
330333
for region in job_regions:
331334
region.bounds = Bounds.scale(region.bounds, params.factor)
332335
else:
333-
conditioning, job_regions = ConditioningInput("4k uhd"), []
336+
conditioning, job_regions = ConditioningInput(sys_prompt), []
334337
models = client.models.for_arch(self.arch)
335338
has_unblur = models.control.find(ControlMode.blur, allow_universal=True) is not None
336339
if has_unblur and params.unblur_strength > 0.0:
@@ -978,7 +981,8 @@ def get_context(self, model: Model, mask: Mask | None):
978981
return None
979982

980983

981-
class UpscaleParams(NamedTuple):
984+
@dataclass(frozen=True)
985+
class UpscaleParams:
982986
upscale: UpscaleInput
983987
factor: float
984988
use_diffusion: bool
@@ -1019,14 +1023,15 @@ class UpscaleWorkspace(QObject, ObservableProperties):
10191023

10201024
def __init__(self, model: Model):
10211025
super().__init__()
1022-
self._model = model
1026+
self._model = weakref.ref(model)
10231027
self._in_progress = False
10241028
self.use_diffusion_changed.connect(self._update_can_generate)
10251029
self._init_model()
10261030
model._connection.models_changed.connect(self._init_model)
10271031

10281032
def _init_model(self):
1029-
if client := self._model._connection.client_if_connected:
1033+
model = ensure(self._model())
1034+
if client := model._connection.client_if_connected:
10301035
if self.upscaler not in client.models.upscalers:
10311036
self.upscaler = client.models.default_upscaler
10321037

@@ -1046,20 +1051,21 @@ def _update_can_generate(self):
10461051

10471052
@property
10481053
def target_extent(self):
1049-
return self._model.document.extent * self.factor
1054+
return ensure(self._model()).document.extent * self.factor
10501055

10511056
@property
10521057
def params(self):
1058+
model = ensure(self._model())
10531059
overlap = self.tile_overlap if self.tile_overlap_mode is TileOverlapMode.custom else -1
10541060
return UpscaleParams(
10551061
upscale=UpscaleInput(self.upscaler, overlap),
10561062
factor=self.factor,
10571063
use_diffusion=self.use_diffusion,
10581064
unblur_strength=self.unblur_strength,
10591065
use_prompt=self.use_prompt,
1060-
strength=self.strength,
1066+
strength=1.0 if model.arch.is_edit else self.strength,
10611067
target_extent=self.target_extent,
1062-
seed=self._model.seed if self._model.fixed_seed else workflow.generate_seed(),
1068+
seed=model.seed if model.fixed_seed else workflow.generate_seed(),
10631069
)
10641070

10651071

@@ -1127,7 +1133,7 @@ class LiveWorkspace(QObject, ObservableProperties):
11271133

11281134
def __init__(self, model: Model):
11291135
super().__init__()
1130-
self._model = model
1136+
self._model = weakref.ref(model)
11311137
self._scheduler = LiveScheduler()
11321138
self._result: Image | None = None
11331139
self._result_composition: Image | None = None
@@ -1138,19 +1144,23 @@ def __init__(self, model: Model):
11381144
self._keyframes: list[Path] = []
11391145
model.jobs.job_finished.connect(self.handle_job_finished)
11401146

1147+
@property
1148+
def model(self):
1149+
return ensure(self._model())
1150+
11411151
def toggle(self, active: bool):
11421152
if self.is_active != active:
11431153
self._is_active = active
11441154
self.is_active_changed.emit(active)
11451155
if active:
1146-
eventloop.run(_report_errors(self._model, self._continue_generating()))
1156+
eventloop.run(_report_errors(self.model, self._continue_generating()))
11471157
else:
11481158
self.is_recording = False
11491159

11501160
def toggle_record(self, active: bool):
11511161
if self.is_recording != active:
11521162
if active and not self._start_recording():
1153-
self._model.report_error(
1163+
self.model.report_error(
11541164
_("Cannot save recorded frames, document must be saved first!")
11551165
)
11561166
return
@@ -1164,16 +1174,16 @@ def handle_job_finished(self, job: Job):
11641174
if job.kind is JobKind.live_preview:
11651175
if len(job.results) > 0:
11661176
self.set_result(job.results[0], job.params)
1167-
self.is_active = self._is_active and self._model.document.is_active
1177+
self.is_active = self._is_active and self.model.document.is_active
11681178
self._scheduler.notify_generation_finished()
1169-
eventloop.run(_report_errors(self._model, self._continue_generating()))
1179+
eventloop.run(_report_errors(self.model, self._continue_generating()))
11701180

11711181
async def _continue_generating(self):
11721182
while self.is_active:
1173-
if self._model.document.is_active:
1174-
new_input, job_params = self._model._prepare_live_workflow()
1183+
if self.model.document.is_active:
1184+
new_input, job_params = self.model._prepare_live_workflow()
11751185
if self._scheduler.should_generate(new_input):
1176-
await self._model._generate_live(new_input, job_params)
1186+
await self.model._generate_live(new_input, job_params)
11771187
self._scheduler.notify_generation_started()
11781188
return
11791189
await asyncio.sleep(self._scheduler.poll_rate)
@@ -1182,7 +1192,7 @@ def apply_result(self, layer_only=False):
11821192
assert self.result is not None and self._result_params is not None
11831193
params = copy(self._result_params)
11841194
if layer_only and len(self._result_params.regions) > 0:
1185-
active = Region.link_target(self._model.layers.active).id_string
1195+
active = Region.link_target(self.model.layers.active).id_string
11861196
if region := next((r for r in params.regions if r.layer_id == active), None):
11871197
params.regions = [region]
11881198

@@ -1191,10 +1201,10 @@ def apply_result(self, layer_only=False):
11911201
if layer_only:
11921202
behavior = ApplyBehavior.layer
11931203
region_behavior = ApplyRegionBehavior.layer_group
1194-
self._model.apply_result(self.result, params, behavior, region_behavior)
1204+
self.model.apply_result(self.result, params, behavior, region_behavior)
11951205

11961206
if settings.new_seed_after_apply:
1197-
self._model.generate_seed()
1207+
self.model.generate_seed()
11981208

11991209
@property
12001210
def result(self):
@@ -1205,7 +1215,7 @@ def result_composition(self):
12051215
return self._result_composition
12061216

12071217
def set_result(self, value: Image, params: JobParams):
1208-
canvas = self._model._get_current_image(params.bounds)
1218+
canvas = self.model._get_current_image(params.bounds)
12091219
painter = QPainter(canvas._qimage)
12101220
painter.setCompositionMode(QPainter.CompositionMode.CompositionMode_Multiply)
12111221
painter.setBrush(QBrush(QColor(0, 0, 96, 192), Qt.BrushStyle.DiagCrossPattern))
@@ -1223,7 +1233,7 @@ def set_result(self, value: Image, params: JobParams):
12231233
self._save_frame(value, params.bounds)
12241234

12251235
def _start_recording(self):
1226-
doc_filename = self._model.document.filename
1236+
doc_filename = self.model.document.filename
12271237
if doc_filename:
12281238
path = Path(doc_filename)
12291239
folder = path.parent / f"{path.with_suffix('.live-frames')}"
@@ -1241,7 +1251,7 @@ def _save_frame(self, image: Image, bounds: Bounds):
12411251
filename = self._keyframes_folder / f"frame-{self._keyframe_index}.webp"
12421252
self._keyframe_index += 1
12431253

1244-
extent = self._model.document.extent
1254+
extent = self.model.document.extent
12451255
if bounds is not None and bounds.extent != extent:
12461256
image = Image.crop(image, bounds)
12471257
image.save(filename)
@@ -1250,10 +1260,10 @@ def _save_frame(self, image: Image, bounds: Bounds):
12501260
def _import_animation(self):
12511261
if len(self._keyframes) == 0:
12521262
return # button toggled without recording a frame in between
1253-
self._model.document.import_animation(self._keyframes, self._keyframe_start)
1263+
self.model.document.import_animation(self._keyframes, self._keyframe_start)
12541264
start, end = self._keyframe_start, self._keyframe_start + len(self._keyframes)
1255-
prompt = self._model.regions.active_or_root.positive
1256-
self._model.layers.active.name = f"[Rec] {start}-{end}: {prompt}"
1265+
prompt = self.model.regions.active_or_root.positive
1266+
self.model.layers.active.name = f"[Rec] {start}-{end}: {prompt}"
12571267
self._keyframes = []
12581268

12591269

ai_diffusion/ui/upscale.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __init__(self):
140140
unblur_layout = QHBoxLayout()
141141
unblur_layout.addWidget(QLabel(_("Image guidance"), self), 1)
142142
unblur_layout.addWidget(self.unblur_slider, 3)
143-
root.connection.models_changed.connect(self._update_unblur_enabled)
143+
root.connection.models_changed.connect(self._update_style)
144144

145145
self.overlap_custom_combo = QComboBox(self)
146146
self.overlap_custom_combo.addItem(_("Automatic"), TileOverlapMode.auto)
@@ -231,12 +231,12 @@ def model(self, model: Model):
231231
model.regions.added.connect(self._update_prompt),
232232
model.regions.removed.connect(self._update_prompt),
233233
model.progress_changed.connect(self.update_progress),
234-
model.style_changed.connect(self._update_unblur_enabled),
234+
model.style_changed.connect(self._update_style),
235235
]
236236
self.upscale_button.model = model
237237
self.queue_button.model = model
238238
self._update_prompt()
239-
self._update_unblur_enabled()
239+
self._update_style()
240240
self._update_overlap()
241241
self.update_progress()
242242

@@ -275,20 +275,27 @@ def _update_overlap(self):
275275
self.model.upscale.tile_overlap_mode is TileOverlapMode.custom
276276
)
277277

278-
def _update_unblur_enabled(self):
279-
has_unblur = False
280-
if client := root.connection.client_if_connected:
281-
models = client.models.for_arch(self.model.arch)
282-
has_unblur = models.control.find(ControlMode.blur, allow_universal=True) is not None
283-
self.unblur_slider.setEnabled(has_unblur)
284-
if not has_unblur:
285-
self.unblur_slider.setToolTip(_("The tile/unblur control model is not installed."))
278+
def _update_style(self):
279+
arch = self.model.arch
280+
if arch.is_edit:
281+
tooltip = _("Not supported for edit models")
282+
self.strength_slider.setEnabled(False)
283+
self.strength_slider.setToolTip(tooltip)
284+
self.unblur_slider.setEnabled(False)
285+
self.unblur_slider.setToolTip(tooltip)
286286
else:
287-
self.unblur_slider.setToolTip(
288-
_(
287+
has_unblur = False
288+
if client := root.connection.client_if_connected:
289+
models = client.models.for_arch(self.model.arch)
290+
has_unblur = models.control.find(ControlMode.blur, allow_universal=True) is not None
291+
self.unblur_slider.setEnabled(has_unblur)
292+
if not has_unblur:
293+
tooltip = _("The tile/unblur control model is not installed.")
294+
else:
295+
tooltip = _(
289296
"When enabled, the low resolution image is used as guidance for refining the upscaled image.\nThis produces results which are closer to the original while enhancing local details."
290297
)
291-
)
298+
self.unblur_slider.setToolTip(tooltip)
292299

293300
def _update_prompt(self):
294301
self.use_prompt_value.setText(_("On") if self.model.upscale.use_prompt else _("Off"))

ai_diffusion/workflow.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,9 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds):
12711271

12721272
latent = vae_encode(w, vae, tile_image, checkpoint.tiled_vae)
12731273
latent = w.set_latent_noise_mask(latent, tile_mask)
1274+
positive = apply_edit_conditioning(
1275+
w, positive, tile_image, latent, control, vae, models.arch, checkpoint.tiled_vae
1276+
)
12741277
sampler = w.sampler_custom_advanced(
12751278
tile_model, positive, negative, latent, models.arch, **_sampler_params(sampling)
12761279
)

0 commit comments

Comments
 (0)