Skip to content

Commit 4d9b822

Browse files
authored
[qwen] Qwen image edit followups (#12166)
* add docs. * more docs. * xfail full compilation for Qwen for now. * tests * up * up * up * reviewer feedback.
1 parent 76c809e commit 4d9b822

File tree

7 files changed

+280
-13
lines changed

7 files changed

+280
-13
lines changed

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616

1717
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
1818

19-
Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
19+
Qwen-Image comes in the following variants:
20+
21+
| model type | model id |
22+
|:----------:|:--------:|
23+
| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
24+
| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
2025

2126
<Tip>
2227

@@ -87,10 +92,6 @@ image.save("qwen_fewsteps.png")
8792
- all
8893
- __call__
8994

90-
## QwenImagePipelineOutput
91-
92-
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
93-
9495
## QwenImageImg2ImgPipeline
9596

9697
[[autodoc]] QwenImageImg2ImgPipeline
@@ -102,3 +103,13 @@ image.save("qwen_fewsteps.png")
102103
[[autodoc]] QwenImageInpaintPipeline
103104
- all
104105
- __call__
106+
107+
## QwenImageEditPipeline
108+
109+
[[autodoc]] QwenImageEditPipeline
110+
- all
111+
- __call__
112+
113+
## QwenImagePipelineOutput
114+
115+
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,10 @@
489489
"PixArtAlphaPipeline",
490490
"PixArtSigmaPAGPipeline",
491491
"PixArtSigmaPipeline",
492+
"QwenImageEditPipeline",
492493
"QwenImageImg2ImgPipeline",
493494
"QwenImageInpaintPipeline",
494495
"QwenImagePipeline",
495-
"QwenImageEditPipeline",
496496
"ReduxImageEncoder",
497497
"SanaControlNetPipeline",
498498
"SanaPAGPipeline",

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
219219
video_freq = self.rope_cache[rope_key]
220220
else:
221221
video_freq = self._compute_video_freqs(frame, height, width, idx)
222+
video_freq = video_freq.to(device)
222223
vid_freqs.append(video_freq)
223224

224225
if self.scale_rope:

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
else:
2525
_import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"]
2626
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
27+
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
2728
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
2829
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
29-
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
3030

3131
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
3232
try:

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,20 @@
4646
>>> import torch
4747
>>> from PIL import Image
4848
>>> from diffusers import QwenImageEditPipeline
49+
>>> from diffusers.utils import load_image
4950
5051
>>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
5152
>>> pipe.to("cuda")
52-
>>> prompt = "Change the cat to a dog"
53-
>>> image = Image.open("cat.png")
53+
>>> image = load_image(
54+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
55+
... ).convert("RGB")
56+
>>> prompt = (
57+
... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
58+
... )
5459
>>> # Depending on the variant being used, the pipeline call will slightly vary.
5560
>>> # Refer to the pipeline documentation for more details.
5661
>>> image = pipe(image, prompt, num_inference_steps=50).images[0]
57-
>>> image.save("qwenimageedit.png")
62+
>>> image.save("qwenimage_edit.png")
5863
```
5964
"""
6065
PREFERRED_QWENIMAGE_RESOLUTIONS = [
@@ -178,7 +183,7 @@ def calculate_dimensions(target_area, ratio):
178183

179184
class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
180185
r"""
181-
The QwenImage pipeline for text-to-image generation.
186+
The Qwen-Image-Edit pipeline for image editing.
182187
183188
Args:
184189
transformer ([`QwenImageTransformer2DModel`]):
@@ -217,8 +222,8 @@ def __init__(
217222
transformer=transformer,
218223
scheduler=scheduler,
219224
)
220-
self.latent_channels = 16
221225
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
226+
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
222227
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
223228
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
224229
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
@@ -635,7 +640,9 @@ def __call__(
635640
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
636641
returning a tuple, the first element is a list with the generated images.
637642
"""
638-
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image.width / image.height)
643+
image_size = image[0].size if isinstance(image, list) else image.size
644+
width, height = image_size
645+
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
639646
height = height or calculated_height
640647
width = width or calculated_width
641648

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import unittest
1717

18+
import pytest
1819
import torch
1920

2021
from diffusers import QwenImageTransformer2DModel
@@ -99,3 +100,7 @@ def prepare_init_args_and_inputs_for_common(self):
99100

100101
def prepare_dummy_input(self, height, width):
101102
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
103+
104+
@pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
105+
def test_torch_compile_recompilation_and_graph_break(self):
106+
super().test_torch_compile_recompilation_and_graph_break()
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# Copyright 2025 The HuggingFace Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import pytest
19+
import torch
20+
from PIL import Image
21+
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
22+
23+
from diffusers import (
24+
AutoencoderKLQwenImage,
25+
FlowMatchEulerDiscreteScheduler,
26+
QwenImageEditPipeline,
27+
QwenImageTransformer2DModel,
28+
)
29+
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
30+
31+
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
32+
from ..test_pipelines_common import PipelineTesterMixin, to_np
33+
34+
35+
enable_full_determinism()
36+
37+
38+
class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
39+
pipeline_class = QwenImageEditPipeline
40+
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
41+
batch_params = frozenset(["prompt", "image"])
42+
image_params = frozenset(["image"])
43+
image_latents_params = frozenset(["latents"])
44+
required_optional_params = frozenset(
45+
[
46+
"num_inference_steps",
47+
"generator",
48+
"latents",
49+
"return_dict",
50+
"callback_on_step_end",
51+
"callback_on_step_end_tensor_inputs",
52+
]
53+
)
54+
supports_dduf = False
55+
test_xformers_attention = False
56+
test_layerwise_casting = True
57+
test_group_offloading = True
58+
59+
def get_dummy_components(self):
60+
tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
61+
62+
torch.manual_seed(0)
63+
transformer = QwenImageTransformer2DModel(
64+
patch_size=2,
65+
in_channels=16,
66+
out_channels=4,
67+
num_layers=2,
68+
attention_head_dim=16,
69+
num_attention_heads=3,
70+
joint_attention_dim=16,
71+
guidance_embeds=False,
72+
axes_dims_rope=(8, 4, 4),
73+
)
74+
75+
torch.manual_seed(0)
76+
z_dim = 4
77+
vae = AutoencoderKLQwenImage(
78+
base_dim=z_dim * 6,
79+
z_dim=z_dim,
80+
dim_mult=[1, 2, 4],
81+
num_res_blocks=1,
82+
temperal_downsample=[False, True],
83+
latents_mean=[0.0] * z_dim,
84+
latents_std=[1.0] * z_dim,
85+
)
86+
87+
torch.manual_seed(0)
88+
scheduler = FlowMatchEulerDiscreteScheduler()
89+
90+
torch.manual_seed(0)
91+
config = Qwen2_5_VLConfig(
92+
text_config={
93+
"hidden_size": 16,
94+
"intermediate_size": 16,
95+
"num_hidden_layers": 2,
96+
"num_attention_heads": 2,
97+
"num_key_value_heads": 2,
98+
"rope_scaling": {
99+
"mrope_section": [1, 1, 2],
100+
"rope_type": "default",
101+
"type": "default",
102+
},
103+
"rope_theta": 1000000.0,
104+
},
105+
vision_config={
106+
"depth": 2,
107+
"hidden_size": 16,
108+
"intermediate_size": 16,
109+
"num_heads": 2,
110+
"out_hidden_size": 16,
111+
},
112+
hidden_size=16,
113+
vocab_size=152064,
114+
vision_end_token_id=151653,
115+
vision_start_token_id=151652,
116+
vision_token_id=151654,
117+
)
118+
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
119+
tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)
120+
121+
components = {
122+
"transformer": transformer,
123+
"vae": vae,
124+
"scheduler": scheduler,
125+
"text_encoder": text_encoder,
126+
"tokenizer": tokenizer,
127+
"processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
128+
}
129+
return components
130+
131+
def get_dummy_inputs(self, device, seed=0):
132+
if str(device).startswith("mps"):
133+
generator = torch.manual_seed(seed)
134+
else:
135+
generator = torch.Generator(device=device).manual_seed(seed)
136+
137+
inputs = {
138+
"prompt": "dance monkey",
139+
"image": Image.new("RGB", (32, 32)),
140+
"negative_prompt": "bad quality",
141+
"generator": generator,
142+
"num_inference_steps": 2,
143+
"true_cfg_scale": 1.0,
144+
"height": 32,
145+
"width": 32,
146+
"max_sequence_length": 16,
147+
"output_type": "pt",
148+
}
149+
150+
return inputs
151+
152+
def test_inference(self):
153+
device = "cpu"
154+
155+
components = self.get_dummy_components()
156+
pipe = self.pipeline_class(**components)
157+
pipe.to(device)
158+
pipe.set_progress_bar_config(disable=None)
159+
160+
inputs = self.get_dummy_inputs(device)
161+
image = pipe(**inputs).images
162+
generated_image = image[0]
163+
self.assertEqual(generated_image.shape, (3, 32, 32))
164+
165+
# fmt: off
166+
expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
167+
# fmt: on
168+
169+
generated_slice = generated_image.flatten()
170+
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
171+
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
172+
173+
def test_inference_batch_single_identical(self):
174+
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
175+
176+
def test_attention_slicing_forward_pass(
177+
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
178+
):
179+
if not self.test_attention_slicing:
180+
return
181+
182+
components = self.get_dummy_components()
183+
pipe = self.pipeline_class(**components)
184+
for component in pipe.components.values():
185+
if hasattr(component, "set_default_attn_processor"):
186+
component.set_default_attn_processor()
187+
pipe.to(torch_device)
188+
pipe.set_progress_bar_config(disable=None)
189+
190+
generator_device = "cpu"
191+
inputs = self.get_dummy_inputs(generator_device)
192+
output_without_slicing = pipe(**inputs)[0]
193+
194+
pipe.enable_attention_slicing(slice_size=1)
195+
inputs = self.get_dummy_inputs(generator_device)
196+
output_with_slicing1 = pipe(**inputs)[0]
197+
198+
pipe.enable_attention_slicing(slice_size=2)
199+
inputs = self.get_dummy_inputs(generator_device)
200+
output_with_slicing2 = pipe(**inputs)[0]
201+
202+
if test_max_difference:
203+
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
204+
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
205+
self.assertLess(
206+
max(max_diff1, max_diff2),
207+
expected_max_diff,
208+
"Attention slicing should not affect the inference results",
209+
)
210+
211+
def test_vae_tiling(self, expected_diff_max: float = 0.2):
212+
generator_device = "cpu"
213+
components = self.get_dummy_components()
214+
215+
pipe = self.pipeline_class(**components)
216+
pipe.to("cpu")
217+
pipe.set_progress_bar_config(disable=None)
218+
219+
# Without tiling
220+
inputs = self.get_dummy_inputs(generator_device)
221+
inputs["height"] = inputs["width"] = 128
222+
output_without_tiling = pipe(**inputs)[0]
223+
224+
# With tiling
225+
pipe.vae.enable_tiling(
226+
tile_sample_min_height=96,
227+
tile_sample_min_width=96,
228+
tile_sample_stride_height=64,
229+
tile_sample_stride_width=64,
230+
)
231+
inputs = self.get_dummy_inputs(generator_device)
232+
inputs["height"] = inputs["width"] = 128
233+
output_with_tiling = pipe(**inputs)[0]
234+
235+
self.assertLess(
236+
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
237+
expected_diff_max,
238+
"VAE tiling should not affect the inference results",
239+
)
240+
241+
@pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
242+
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
243+
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)

0 commit comments

Comments
 (0)