Skip to content

Commit 7acc619

Browse files
authored
Custom workflows: overwrite saved workflow & spinner for number parameters (#2216)
* Added override workflow popup when workflow names match. If "no" then add suffix like before. * Added input section on ETN_Parameter value section nodes. These will clamp and sync with krita sliders.
1 parent 57001ce commit 7acc619

File tree

2 files changed

+105
-33
lines changed

2 files changed

+105
-33
lines changed

ai_diffusion/custom_workflow.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,17 @@ def save_as(self, id: str, graph: dict):
170170
self._process_workflow(id, WorkflowSource.local, graph, path)
171171
return id
172172

173+
def overwrite(self, id: str, graph: dict):
174+
existing = self.find(id)
175+
if existing is None or existing.source is not WorkflowSource.local or existing.path is None:
176+
raise KeyError(f"Workflow {id} cannot be overwritten")
177+
178+
path = existing.path
179+
self._folder.mkdir(exist_ok=True)
180+
path.write_text(json.dumps(graph, indent=2))
181+
self._process_workflow(id, WorkflowSource.local, graph, path)
182+
return id
183+
173184
def import_file(self, filepath: Path):
174185
try:
175186
with filepath.open("r") as f:
@@ -431,9 +442,12 @@ def set_graph(self, id: str, graph: dict, document_name: str):
431442
def import_file(self, filepath: Path):
432443
self.workflow_id = self._workflows.import_file(filepath)
433444

434-
def save_as(self, id: str):
445+
def save_as(self, id: str, overwrite: bool = False):
435446
assert self._graph, "Save as: no workflow selected"
436-
self.workflow_id = self._workflows.save_as(id, self._graph.root)
447+
if overwrite:
448+
self.workflow_id = self._workflows.overwrite(id, self._graph.root)
449+
else:
450+
self.workflow_id = self._workflows.save_as(id, self._graph.root)
437451

438452
def remove_workflow(self):
439453
if id := self.workflow_id:
@@ -467,6 +481,10 @@ def workflow(self):
467481
def graph(self):
468482
return self._graph
469483

484+
@property
485+
def workflows(self):
486+
return self._workflows
487+
470488
@property
471489
def metadata(self):
472490
return self._metadata

ai_diffusion/ui/custom_workflow.py

Lines changed: 85 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .widget import TextPromptWidget, WorkspaceSelectWidget, StyleSelectWidget, ErrorBox
2828
from .settings_widgets import ExpanderButton
2929
from . import theme
30+
from .theme import SignalBlocker
3031

3132

3233
class LayerSelect(QComboBox):
@@ -93,31 +94,41 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None):
9394
layout.setContentsMargins(0, 0, 0, 0)
9495
self.setLayout(layout)
9596

97+
self._slider: QSlider | None = None
9698
assert param.min is not None and param.max is not None and param.default is not None
9799
if param.max - param.min <= 200:
98-
self._widget = QSlider(Qt.Orientation.Horizontal, parent)
99-
self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4)
100-
self._widget.valueChanged.connect(self._notify)
101-
self._label = QLabel(self)
102-
self._label.setFixedWidth(QFontMetrics(self.font()).width("000.00") + 4)
103-
self._label.setAlignment(Qt.AlignmentFlag.AlignRight)
100+
self._slider = QSlider(Qt.Orientation.Horizontal, parent)
101+
self._slider.setMinimumHeight(self._slider.minimumSizeHint().height() + 4)
102+
self._slider.valueChanged.connect(self._slider_changed)
103+
self._widget = QSpinBox(parent)
104+
self._widget.valueChanged.connect(self._input_changed)
105+
layout.addWidget(self._slider)
104106
layout.addWidget(self._widget)
105-
layout.addWidget(self._label)
106107
else:
107108
self._widget = QSpinBox(parent)
108109
self._widget.valueChanged.connect(self._notify)
109-
self._label = None
110110
layout.addWidget(self._widget)
111111

112112
min_range = clamp(int(param.min), -(2**31), 2**31 - 1)
113113
max_range = clamp(int(param.max), -(2**31), 2**31 - 1)
114114
self._widget.setRange(min_range, max_range)
115+
if self._slider is not None:
116+
self._slider.setRange(min_range, max_range)
115117

116118
self.value = param.default
117119

120+
def _slider_changed(self, value: int):
121+
with SignalBlocker(self._widget):
122+
self._widget.setValue(value)
123+
self._notify()
124+
125+
def _input_changed(self, value: int):
126+
if self._slider is not None:
127+
with SignalBlocker(self._slider):
128+
self._slider.setValue(value)
129+
self._notify()
130+
118131
def _notify(self):
119-
if self._label:
120-
self._label.setText(str(self._widget.value()))
121132
self.value_changed.emit()
122133

123134
@property
@@ -126,7 +137,13 @@ def value(self):
126137

127138
@value.setter
128139
def value(self, value: int | float):
129-
self._widget.setValue(int(value))
140+
v = int(value)
141+
v = max(self._widget.minimum(), min(self._widget.maximum(), v))
142+
with SignalBlocker(self._widget):
143+
self._widget.setValue(v)
144+
if self._slider is not None:
145+
with SignalBlocker(self._slider):
146+
self._slider.setValue(v)
130147

131148

132149
class FloatParamWidget(QWidget):
@@ -141,44 +158,55 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None):
141158
layout.setContentsMargins(0, 0, 0, 0)
142159
self.setLayout(layout)
143160

161+
self._slider: QSlider | None = None
144162
assert param.min is not None and param.max is not None and param.default is not None
145163
if param.max - param.min <= 100:
146-
self._widget = QSlider(Qt.Orientation.Horizontal, parent)
147-
self._widget.setRange(round(param.min * 100), round(param.max * 100))
148-
self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4)
149-
self._widget.valueChanged.connect(self._notify)
150-
self._label = QLabel(self)
151-
self._label.setFixedWidth(QFontMetrics(self.font()).width("000.00") + 4)
152-
self._label.setAlignment(Qt.AlignmentFlag.AlignRight)
164+
self._slider = QSlider(Qt.Orientation.Horizontal, parent)
165+
self._slider.setRange(round(param.min * 100), round(param.max * 100))
166+
self._slider.setMinimumHeight(self._slider.minimumSizeHint().height() + 4)
167+
self._slider.valueChanged.connect(self._slider_changed)
168+
self._widget = QDoubleSpinBox(parent)
169+
self._widget.setRange(param.min, param.max)
170+
self._widget.setDecimals(2)
171+
self._widget.valueChanged.connect(self._input_changed)
172+
layout.addWidget(self._slider)
153173
layout.addWidget(self._widget)
154-
layout.addWidget(self._label)
155174
else:
156175
self._widget = QDoubleSpinBox(parent)
157176
self._widget.setRange(param.min, param.max)
158177
self._widget.valueChanged.connect(self._notify)
159-
self._label = None
160178
layout.addWidget(self._widget)
161179

162180
self.value = param.default
163181

182+
def _slider_changed(self, value: int):
183+
v = value / 100.0
184+
with SignalBlocker(self._widget):
185+
self._widget.setValue(v)
186+
self._notify()
187+
188+
def _input_changed(self, value: float):
189+
if self._slider is not None:
190+
with SignalBlocker(self._slider):
191+
self._slider.setValue(round(value * 100))
192+
self._notify()
193+
164194
def _notify(self):
165-
if self._label:
166-
self._label.setText(f"{self.value:.2f}")
167195
self.value_changed.emit()
168196

169197
@property
170198
def value(self):
171-
if isinstance(self._widget, QSlider):
172-
return self._widget.value() / 100
173-
else:
174-
return self._widget.value()
199+
return float(self._widget.value())
175200

176201
@value.setter
177202
def value(self, value: float | int):
178-
if isinstance(self._widget, QSlider):
179-
self._widget.setValue(round(value * 100))
180-
else:
181-
self._widget.setValue(float(value))
203+
v = float(value)
204+
v = max(self._widget.minimum(), min(self._widget.maximum(), v))
205+
with SignalBlocker(self._widget):
206+
self._widget.setValue(v)
207+
if self._slider is not None:
208+
with SignalBlocker(self._slider):
209+
self._slider.setValue(round(v * 100))
182210

183211

184212
class BoolParamWidget(QWidget):
@@ -943,7 +971,33 @@ def _edit_name(self):
943971

944972
@popup_on_error
945973
def _accept_name(self, *args):
946-
self.model.custom.save_as(self._workflow_name_edit.text())
974+
name = self._workflow_name_edit.text().strip()
975+
workspace = self.model.custom
976+
overwrite = False
977+
978+
current = workspace.workflow
979+
existing = workspace.workflows.find(name)
980+
if (
981+
current is not None
982+
and current.source is WorkflowSource.remote
983+
and existing is not None
984+
and existing.source is WorkflowSource.local
985+
):
986+
details = f"\n{existing.path}" if existing.path is not None else ""
987+
q = QMessageBox.question(
988+
self,
989+
_("Overwrite Workflow"),
990+
_("A workflow named '{name}' already exists. Do you want to overwrite it?").format(
991+
name=name
992+
)
993+
+ details,
994+
QMessageBox.Yes | QMessageBox.No,
995+
QMessageBox.StandardButton.No,
996+
)
997+
if q == QMessageBox.StandardButton.Yes:
998+
overwrite = True
999+
1000+
workspace.save_as(name, overwrite=overwrite)
9471001
self.is_edit_mode = False
9481002

9491003
def _cancel_name(self):

0 commit comments

Comments
 (0)