11from __future__ import annotations
22import asyncio
3+ import time
4+ import weakref
5+ import uuid
36from copy import copy
47from collections import deque
5- from dataclasses import replace
8+ from dataclasses import dataclass , replace
69from datetime import datetime
710from pathlib import Path
811from enum import Enum
912from tempfile import TemporaryDirectory
10- import time
1113from typing import Any , NamedTuple
1214from PyQt5 .QtCore import QObject , QMetaObject , QUuid , pyqtSignal , Qt
1315from PyQt5 .QtGui import QPainter , QColor , QBrush
14- import uuid
1516
1617from . import eventloop , workflow , util
1718from .api import ConditioningInput , ControlInput , WorkflowKind , WorkflowInput , SamplingInput
@@ -312,8 +313,6 @@ async def _enqueue_job(self, job: Job, input: WorkflowInput, front: bool = False
312313 job .id = await client .enqueue (input , front )
313314
314315 def _prepare_upscale_image (self , dryrun = False ):
315- assert not self .arch .is_edit , "Edit models do not support upscaling"
316-
317316 client = self ._connection .client
318317 extent = self ._doc .extent
319318 image = self ._doc .get_image (Bounds (0 , 0 , * extent )) if not dryrun else DummyImage (extent )
@@ -324,13 +323,17 @@ def _prepare_upscale_image(self, dryrun=False):
324323 self .report_error (Error (ErrorKind .warning , msg + f": { params .upscale .model } " ))
325324 self .upscale .upscaler = params .upscale .model = client .models .default_upscaler
326325 bounds = Bounds (0 , 0 , * self ._doc .extent )
326+ sys_prompt = "4k uhd"
327+ if self .arch .is_edit :
328+ sys_prompt = "Enhance image quality. Preserve original content."
329+
327330 if params .use_prompt and not dryrun :
328331 conditioning , job_regions = process_regions (self .regions , bounds , min_coverage = 0 )
329332 conditioning .language = self .prompt_translation_language
330333 for region in job_regions :
331334 region .bounds = Bounds .scale (region .bounds , params .factor )
332335 else :
333- conditioning , job_regions = ConditioningInput ("4k uhd" ), []
336+ conditioning , job_regions = ConditioningInput (sys_prompt ), []
334337 models = client .models .for_arch (self .arch )
335338 has_unblur = models .control .find (ControlMode .blur , allow_universal = True ) is not None
336339 if has_unblur and params .unblur_strength > 0.0 :
@@ -978,7 +981,8 @@ def get_context(self, model: Model, mask: Mask | None):
978981 return None
979982
980983
981- class UpscaleParams (NamedTuple ):
984+ @dataclass (frozen = True )
985+ class UpscaleParams :
982986 upscale : UpscaleInput
983987 factor : float
984988 use_diffusion : bool
@@ -1019,14 +1023,15 @@ class UpscaleWorkspace(QObject, ObservableProperties):
10191023
10201024 def __init__ (self , model : Model ):
10211025 super ().__init__ ()
1022- self ._model = model
1026+ self ._model = weakref . ref ( model )
10231027 self ._in_progress = False
10241028 self .use_diffusion_changed .connect (self ._update_can_generate )
10251029 self ._init_model ()
10261030 model ._connection .models_changed .connect (self ._init_model )
10271031
10281032 def _init_model (self ):
1029- if client := self ._model ._connection .client_if_connected :
1033+ model = ensure (self ._model ())
1034+ if client := model ._connection .client_if_connected :
10301035 if self .upscaler not in client .models .upscalers :
10311036 self .upscaler = client .models .default_upscaler
10321037
@@ -1046,20 +1051,21 @@ def _update_can_generate(self):
10461051
10471052 @property
10481053 def target_extent (self ):
1049- return self ._model .document .extent * self .factor
1054+ return ensure ( self ._model ()) .document .extent * self .factor
10501055
10511056 @property
10521057 def params (self ):
1058+ model = ensure (self ._model ())
10531059 overlap = self .tile_overlap if self .tile_overlap_mode is TileOverlapMode .custom else - 1
10541060 return UpscaleParams (
10551061 upscale = UpscaleInput (self .upscaler , overlap ),
10561062 factor = self .factor ,
10571063 use_diffusion = self .use_diffusion ,
10581064 unblur_strength = self .unblur_strength ,
10591065 use_prompt = self .use_prompt ,
1060- strength = self .strength ,
1066+ strength = 1.0 if model . arch . is_edit else self .strength ,
10611067 target_extent = self .target_extent ,
1062- seed = self . _model . seed if self . _model .fixed_seed else workflow .generate_seed (),
1068+ seed = model . seed if model .fixed_seed else workflow .generate_seed (),
10631069 )
10641070
10651071
@@ -1127,7 +1133,7 @@ class LiveWorkspace(QObject, ObservableProperties):
11271133
11281134 def __init__ (self , model : Model ):
11291135 super ().__init__ ()
1130- self ._model = model
1136+ self ._model = weakref . ref ( model )
11311137 self ._scheduler = LiveScheduler ()
11321138 self ._result : Image | None = None
11331139 self ._result_composition : Image | None = None
@@ -1138,19 +1144,23 @@ def __init__(self, model: Model):
11381144 self ._keyframes : list [Path ] = []
11391145 model .jobs .job_finished .connect (self .handle_job_finished )
11401146
1147+ @property
1148+ def model (self ):
1149+ return ensure (self ._model ())
1150+
11411151 def toggle (self , active : bool ):
11421152 if self .is_active != active :
11431153 self ._is_active = active
11441154 self .is_active_changed .emit (active )
11451155 if active :
1146- eventloop .run (_report_errors (self ._model , self ._continue_generating ()))
1156+ eventloop .run (_report_errors (self .model , self ._continue_generating ()))
11471157 else :
11481158 self .is_recording = False
11491159
11501160 def toggle_record (self , active : bool ):
11511161 if self .is_recording != active :
11521162 if active and not self ._start_recording ():
1153- self ._model .report_error (
1163+ self .model .report_error (
11541164 _ ("Cannot save recorded frames, document must be saved first!" )
11551165 )
11561166 return
@@ -1164,16 +1174,16 @@ def handle_job_finished(self, job: Job):
11641174 if job .kind is JobKind .live_preview :
11651175 if len (job .results ) > 0 :
11661176 self .set_result (job .results [0 ], job .params )
1167- self .is_active = self ._is_active and self ._model .document .is_active
1177+ self .is_active = self ._is_active and self .model .document .is_active
11681178 self ._scheduler .notify_generation_finished ()
1169- eventloop .run (_report_errors (self ._model , self ._continue_generating ()))
1179+ eventloop .run (_report_errors (self .model , self ._continue_generating ()))
11701180
11711181 async def _continue_generating (self ):
11721182 while self .is_active :
1173- if self ._model .document .is_active :
1174- new_input , job_params = self ._model ._prepare_live_workflow ()
1183+ if self .model .document .is_active :
1184+ new_input , job_params = self .model ._prepare_live_workflow ()
11751185 if self ._scheduler .should_generate (new_input ):
1176- await self ._model ._generate_live (new_input , job_params )
1186+ await self .model ._generate_live (new_input , job_params )
11771187 self ._scheduler .notify_generation_started ()
11781188 return
11791189 await asyncio .sleep (self ._scheduler .poll_rate )
@@ -1182,7 +1192,7 @@ def apply_result(self, layer_only=False):
11821192 assert self .result is not None and self ._result_params is not None
11831193 params = copy (self ._result_params )
11841194 if layer_only and len (self ._result_params .regions ) > 0 :
1185- active = Region .link_target (self ._model .layers .active ).id_string
1195+ active = Region .link_target (self .model .layers .active ).id_string
11861196 if region := next ((r for r in params .regions if r .layer_id == active ), None ):
11871197 params .regions = [region ]
11881198
@@ -1191,10 +1201,10 @@ def apply_result(self, layer_only=False):
11911201 if layer_only :
11921202 behavior = ApplyBehavior .layer
11931203 region_behavior = ApplyRegionBehavior .layer_group
1194- self ._model .apply_result (self .result , params , behavior , region_behavior )
1204+ self .model .apply_result (self .result , params , behavior , region_behavior )
11951205
11961206 if settings .new_seed_after_apply :
1197- self ._model .generate_seed ()
1207+ self .model .generate_seed ()
11981208
11991209 @property
12001210 def result (self ):
@@ -1205,7 +1215,7 @@ def result_composition(self):
12051215 return self ._result_composition
12061216
12071217 def set_result (self , value : Image , params : JobParams ):
1208- canvas = self ._model ._get_current_image (params .bounds )
1218+ canvas = self .model ._get_current_image (params .bounds )
12091219 painter = QPainter (canvas ._qimage )
12101220 painter .setCompositionMode (QPainter .CompositionMode .CompositionMode_Multiply )
12111221 painter .setBrush (QBrush (QColor (0 , 0 , 96 , 192 ), Qt .BrushStyle .DiagCrossPattern ))
@@ -1223,7 +1233,7 @@ def set_result(self, value: Image, params: JobParams):
12231233 self ._save_frame (value , params .bounds )
12241234
12251235 def _start_recording (self ):
1226- doc_filename = self ._model .document .filename
1236+ doc_filename = self .model .document .filename
12271237 if doc_filename :
12281238 path = Path (doc_filename )
12291239 folder = path .parent / f"{ path .with_suffix ('.live-frames' )} "
@@ -1241,7 +1251,7 @@ def _save_frame(self, image: Image, bounds: Bounds):
12411251 filename = self ._keyframes_folder / f"frame-{ self ._keyframe_index } .webp"
12421252 self ._keyframe_index += 1
12431253
1244- extent = self ._model .document .extent
1254+ extent = self .model .document .extent
12451255 if bounds is not None and bounds .extent != extent :
12461256 image = Image .crop (image , bounds )
12471257 image .save (filename )
@@ -1250,10 +1260,10 @@ def _save_frame(self, image: Image, bounds: Bounds):
12501260 def _import_animation (self ):
12511261 if len (self ._keyframes ) == 0 :
12521262 return # button toggled without recording a frame in between
1253- self ._model .document .import_animation (self ._keyframes , self ._keyframe_start )
1263+ self .model .document .import_animation (self ._keyframes , self ._keyframe_start )
12541264 start , end = self ._keyframe_start , self ._keyframe_start + len (self ._keyframes )
1255- prompt = self ._model .regions .active_or_root .positive
1256- self ._model .layers .active .name = f"[Rec] { start } -{ end } : { prompt } "
1265+ prompt = self .model .regions .active_or_root .positive
1266+ self .model .layers .active .name = f"[Rec] { start } -{ end } : { prompt } "
12571267 self ._keyframes = []
12581268
12591269
0 commit comments