Skip to content

Commit 8959a73

Browse files
committed
Don't zero out conditioning for SD1.5
- it does not like it.
1 parent 6a80e3b commit 8959a73

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

ai_diffusion/workflow.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
153153
if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
154154
model = w.apply_self_attention_guidance(model)
155155

156+
return model, Clip(clip, arch), vae
157+
156158

157159
def vae_decode(w: ComfyWorkflow, vae: Output, latent: Output, tiled: bool):
158160
if tiled:
@@ -253,26 +255,31 @@ def from_input(i: ControlInput):
253255
return Control(i.mode, ImageOutput(i.image), None, i.strength, i.range)
254256

255257

258+
class Clip(NamedTuple):
259+
model: Output
260+
arch: Arch
261+
262+
256263
class TextPrompt:
257264
text: str
258265
language: str
259266
# Cached values to avoid re-encoding the same text for multiple regions and passes
260267
_output: Output | None = None
261-
_clip: Output | None = None # can be different due to Lora hooks
268+
_clip: Clip | None = None # can be different due to Lora hooks
262269

263270
def __init__(self, text: str, language: str):
264271
self.text = text
265272
self.language = language
266273

267-
def encode(self, w: ComfyWorkflow, clip: Output, style_prompt: str | None = None):
274+
def encode(self, w: ComfyWorkflow, clip: Clip, style_prompt: str | None = None):
268275
text = self.text
269276
if text != "" and style_prompt:
270277
text = merge_prompt(text, style_prompt, self.language)
271278
if self._output is None or self._clip != clip:
272279
if text and self.language:
273280
text = w.translate(text)
274-
self._output = w.clip_text_encode(clip, text)
275-
if text == "":
281+
self._output = w.clip_text_encode(clip.model, text)
282+
if text == "" and clip.arch is not Arch.sd15:
276283
self._output = w.conditioning_zero_out(self._output)
277284
self._clip = clip
278285
return self._output
@@ -286,7 +293,7 @@ class Region:
286293
control: list[Control] = field(default_factory=list)
287294
loras: list[LoraInput] = field(default_factory=list)
288295
is_background: bool = False
289-
clip: Output | None = None
296+
clip: Clip | None = None
290297

291298
@staticmethod
292299
def from_input(i: RegionInput, index: int, language: str):
@@ -301,15 +308,15 @@ def from_input(i: RegionInput, index: int, language: str):
301308
is_background=index == 0,
302309
)
303310

304-
def patch_clip(self, w: ComfyWorkflow, clip: Output):
311+
def patch_clip(self, w: ComfyWorkflow, clip: Clip):
305312
if self.clip is None:
306313
self.clip = clip
307314
if len(self.loras) > 0:
308315
hooks = w.create_hook_lora([(lora.name, lora.strength) for lora in self.loras])
309-
self.clip = w.set_clip_hooks(clip, hooks)
316+
self.clip = Clip(w.set_clip_hooks(clip.model, hooks), clip.arch)
310317
return self.clip
311318

312-
def encode_prompt(self, w: ComfyWorkflow, clip: Output, style_prompt: str | None = None):
319+
def encode_prompt(self, w: ComfyWorkflow, clip: Clip, style_prompt: str | None = None):
313320
return self.positive.encode(w, self.patch_clip(w, clip), style_prompt)
314321

315322
def copy(self):
@@ -384,7 +391,7 @@ def downscale_all_control_images(cond: ConditioningInput, original: Extent, targ
384391
def encode_text_prompt(
385392
w: ComfyWorkflow,
386393
cond: Conditioning,
387-
clip: Output,
394+
clip: Clip,
388395
regions: Output | None,
389396
):
390397
if len(cond.regions) <= 1 or all(len(r.loras) == 0 for r in cond.regions):
@@ -413,7 +420,7 @@ def apply_attention_mask(
413420
w: ComfyWorkflow,
414421
model: Output,
415422
cond: Conditioning,
416-
clip: Output,
423+
clip: Clip,
417424
shape: Extent | ImageReshape = no_reshape,
418425
):
419426
if len(cond.regions) == 0:
@@ -643,7 +650,7 @@ def scale_refine_and_decode(
643650
sampling: SamplingInput,
644651
latent: Output,
645652
model: Output,
646-
clip: Output,
653+
clip: Clip,
647654
vae: Output,
648655
models: ModelDict,
649656
tiled_vae: bool,
@@ -1240,7 +1247,7 @@ def get_param(node: ComfyNode, expected_type: type | tuple[type, type] | None =
12401247
sampling = _sampling_from_style(style, 1.0, is_live)
12411248
model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models)
12421249
outputs[node.output(0)] = model
1243-
outputs[node.output(1)] = clip
1250+
outputs[node.output(1)] = clip.model
12441251
outputs[node.output(2)] = vae
12451252
outputs[node.output(3)] = style.style_prompt
12461253
outputs[node.output(4)] = style.negative_prompt

0 commit comments

Comments
 (0)