Skip to content

Commit dd9a5ca

Browse files
sayakpaul, slep0v, and patrickvonplaten
authored
[Core] support for tiny autoencoder in img2img (#5636)
* support for tiny autoencoder in img2img Co-authored-by: slep0v <[email protected]> * copy fix * line space * line space * clean up * spit out expected value * spit out expected value * assertion values. * assertion values. --------- Co-authored-by: slep0v <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent a35e72b commit dd9a5ca

13 files changed

+180
-22
lines changed

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@
7575
"""
7676

7777

78+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
79+
def retrieve_latents(encoder_output, generator):
80+
if hasattr(encoder_output, "latent_dist"):
81+
return encoder_output.latent_dist.sample(generator)
82+
elif hasattr(encoder_output, "latents"):
83+
return encoder_output.latents
84+
else:
85+
raise AttributeError("Could not access latents of provided encoder_output")
86+
87+
7888
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
7989
def preprocess(image):
8090
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
@@ -561,11 +571,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
561571

562572
elif isinstance(generator, list):
563573
init_latents = [
564-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
574+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
575+
for i in range(batch_size)
565576
]
566577
init_latents = torch.cat(init_latents, dim=0)
567578
else:
568-
init_latents = self.vae.encode(image).latent_dist.sample(generator)
579+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
569580

570581
init_latents = self.vae.config.scaling_factor * init_latents
571582

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@
9191
"""
9292

9393

94+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
95+
def retrieve_latents(encoder_output, generator):
96+
if hasattr(encoder_output, "latent_dist"):
97+
return encoder_output.latent_dist.sample(generator)
98+
elif hasattr(encoder_output, "latents"):
99+
return encoder_output.latents
100+
else:
101+
raise AttributeError("Could not access latents of provided encoder_output")
102+
103+
94104
def prepare_image(image):
95105
if isinstance(image, torch.Tensor):
96106
# Batch single image
@@ -733,11 +743,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
733743

734744
elif isinstance(generator, list):
735745
init_latents = [
736-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
746+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
747+
for i in range(batch_size)
737748
]
738749
init_latents = torch.cat(init_latents, dim=0)
739750
else:
740-
init_latents = self.vae.encode(image).latent_dist.sample(generator)
751+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
741752

742753
init_latents = self.vae.config.scaling_factor * init_latents
743754

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,16 @@
103103
"""
104104

105105

106+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
107+
def retrieve_latents(encoder_output, generator):
108+
if hasattr(encoder_output, "latent_dist"):
109+
return encoder_output.latent_dist.sample(generator)
110+
elif hasattr(encoder_output, "latents"):
111+
return encoder_output.latents
112+
else:
113+
raise AttributeError("Could not access latents of provided encoder_output")
114+
115+
106116
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
107117
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
108118
"""
@@ -949,12 +959,12 @@ def prepare_mask_latents(
949959
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
950960
if isinstance(generator, list):
951961
image_latents = [
952-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
962+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
953963
for i in range(image.shape[0])
954964
]
955965
image_latents = torch.cat(image_latents, dim=0)
956966
else:
957-
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
967+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
958968

959969
image_latents = self.vae.config.scaling_factor * image_latents
960970

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,16 @@
131131
"""
132132

133133

134+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
135+
def retrieve_latents(encoder_output, generator):
136+
if hasattr(encoder_output, "latent_dist"):
137+
return encoder_output.latent_dist.sample(generator)
138+
elif hasattr(encoder_output, "latents"):
139+
return encoder_output.latents
140+
else:
141+
raise AttributeError("Could not access latents of provided encoder_output")
142+
143+
134144
class StableDiffusionXLControlNetImg2ImgPipeline(
135145
DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
136146
):
@@ -806,11 +816,12 @@ def prepare_latents(
806816

807817
elif isinstance(generator, list):
808818
init_latents = [
809-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
819+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
820+
for i in range(batch_size)
810821
]
811822
init_latents = torch.cat(init_latents, dim=0)
812823
else:
813-
init_latents = self.vae.encode(image).latent_dist.sample(generator)
824+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
814825

815826
if self.vae.config.force_upcast:
816827
self.vae.to(dtype)

src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@
4343
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4444

4545

46+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
47+
def retrieve_latents(encoder_output, generator):
48+
if hasattr(encoder_output, "latent_dist"):
49+
return encoder_output.latent_dist.sample(generator)
50+
elif hasattr(encoder_output, "latents"):
51+
return encoder_output.latents
52+
else:
53+
raise AttributeError("Could not access latents of provided encoder_output")
54+
55+
4656
EXAMPLE_DOC_STRING = """
4757
Examples:
4858
```py
@@ -426,11 +436,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
426436

427437
elif isinstance(generator, list):
428438
init_latents = [
429-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
439+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
440+
for i in range(batch_size)
430441
]
431442
init_latents = torch.cat(init_latents, dim=0)
432443
else:
433-
init_latents = self.vae.encode(image).latent_dist.sample(generator)
444+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
434445

435446
init_latents = self.vae.config.scaling_factor * init_latents
436447

src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@
3434
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3535

3636

37+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
38+
def retrieve_latents(encoder_output, generator):
39+
if hasattr(encoder_output, "latent_dist"):
40+
return encoder_output.latent_dist.sample(generator)
41+
elif hasattr(encoder_output, "latents"):
42+
return encoder_output.latents
43+
else:
44+
raise AttributeError("Could not access latents of provided encoder_output")
45+
46+
3747
def prepare_mask_and_masked_image(image, mask):
3848
"""
3949
Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be
@@ -334,12 +344,12 @@ def prepare_mask_latents(
334344
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
335345
if isinstance(generator, list):
336346
image_latents = [
337-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
347+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
338348
for i in range(image.shape[0])
339349
]
340350
image_latents = torch.cat(image_latents, dim=0)
341351
else:
342-
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
352+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
343353

344354
image_latents = self.vae.config.scaling_factor * image_latents
345355

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@
3636
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3737

3838

39+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
40+
def retrieve_latents(encoder_output, generator):
41+
if hasattr(encoder_output, "latent_dist"):
42+
return encoder_output.latent_dist.sample(generator)
43+
elif hasattr(encoder_output, "latents"):
44+
return encoder_output.latents
45+
else:
46+
raise AttributeError("Could not access latents of provided encoder_output")
47+
48+
3949
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
4050
def preprocess(image):
4151
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
@@ -466,11 +476,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
466476

467477
elif isinstance(generator, list):
468478
init_latents = [
469-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
479+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
480+
for i in range(batch_size)
470481
]
471482
init_latents = torch.cat(init_latents, dim=0)
472483
else:
473-
init_latents = self.vae.encode(image).latent_dist.sample(generator)
484+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
474485

475486
init_latents = self.vae.config.scaling_factor * init_latents
476487

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@
7373
"""
7474

7575

76+
def retrieve_latents(encoder_output, generator):
77+
if hasattr(encoder_output, "latent_dist"):
78+
return encoder_output.latent_dist.sample(generator)
79+
elif hasattr(encoder_output, "latents"):
80+
return encoder_output.latents
81+
else:
82+
raise AttributeError("Could not access latents of provided encoder_output")
83+
84+
7685
def preprocess(image):
7786
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
7887
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
@@ -555,11 +564,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
555564

556565
elif isinstance(generator, list):
557566
init_latents = [
558-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
567+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
568+
for i in range(batch_size)
559569
]
560570
init_latents = torch.cat(init_latents, dim=0)
561571
else:
562-
init_latents = self.vae.encode(image).latent_dist.sample(generator)
572+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
563573

564574
init_latents = self.vae.config.scaling_factor * init_latents
565575

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
159159
return mask, masked_image
160160

161161

162+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
163+
def retrieve_latents(encoder_output, generator):
164+
if hasattr(encoder_output, "latent_dist"):
165+
return encoder_output.latent_dist.sample(generator)
166+
elif hasattr(encoder_output, "latents"):
167+
return encoder_output.latents
168+
else:
169+
raise AttributeError("Could not access latents of provided encoder_output")
170+
171+
162172
class StableDiffusionInpaintPipeline(
163173
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
164174
):
@@ -654,12 +664,12 @@ def prepare_latents(
654664
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
655665
if isinstance(generator, list):
656666
image_latents = [
657-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
667+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
658668
for i in range(image.shape[0])
659669
]
660670
image_latents = torch.cat(image_latents, dim=0)
661671
else:
662-
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
672+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
663673

664674
image_latents = self.vae.config.scaling_factor * image_latents
665675

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
9292
return noise_cfg
9393

9494

95+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
96+
def retrieve_latents(encoder_output, generator):
97+
if hasattr(encoder_output, "latent_dist"):
98+
return encoder_output.latent_dist.sample(generator)
99+
elif hasattr(encoder_output, "latents"):
100+
return encoder_output.latents
101+
else:
102+
raise AttributeError("Could not access latents of provided encoder_output")
103+
104+
95105
class StableDiffusionXLImg2ImgPipeline(
96106
DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
97107
):
@@ -604,11 +614,12 @@ def prepare_latents(
604614

605615
elif isinstance(generator, list):
606616
init_latents = [
607-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
617+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
618+
for i in range(batch_size)
608619
]
609620
init_latents = torch.cat(init_latents, dim=0)
610621
else:
611-
init_latents = self.vae.encode(image).latent_dist.sample(generator)
622+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
612623

613624
if self.vae.config.force_upcast:
614625
self.vae.to(dtype)

0 commit comments

Comments (0)