Skip to content

Commit 0ac7d39

Browse files
[PixArt-Alpha] Support non-square images (#5672)
* debug
* support non-square images
* add: test
* fix: test

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent d190959 commit 0ac7d39

File tree

2 files changed

+20 -2 lines changed (20 additions, 2 deletions)

2 files changed

+20
-2
lines changed

src/diffusers/models/transformer_2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def forward(
339339
elif self.is_input_vectorized:
340340
hidden_states = self.latent_image_embedding(hidden_states)
341341
elif self.is_input_patches:
342+
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
342343
hidden_states = self.pos_embed(hidden_states)
343344

344345
if self.adaln_single is not None:
@@ -425,7 +426,8 @@ def forward(
425426
hidden_states = hidden_states.squeeze(1)
426427

427428
# unpatchify
428-
height = width = int(hidden_states.shape[1] ** 0.5)
429+
if self.adaln_single is None:
430+
height = width = int(hidden_states.shape[1] ** 0.5)
429431
hidden_states = hidden_states.reshape(
430432
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
431433
)

tests/pipelines/pixart/test_pixart.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,29 @@ def test_inference(self):
174174
inputs = self.get_dummy_inputs(device)
175175
image = pipe(**inputs).images
176176
image_slice = image[0, -3:, -3:, -1]
177-
print(torch.from_numpy(image_slice.flatten()))
178177

179178
self.assertEqual(image.shape, (1, 8, 8, 3))
180179
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
181180
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
182181
self.assertLessEqual(max_diff, 1e-3)
183182

183+
def test_inference_non_square_images(self):
184+
device = "cpu"
185+
186+
components = self.get_dummy_components()
187+
pipe = self.pipeline_class(**components)
188+
pipe.to(device)
189+
pipe.set_progress_bar_config(disable=None)
190+
191+
inputs = self.get_dummy_inputs(device)
192+
image = pipe(**inputs, height=32, width=48).images
193+
image_slice = image[0, -3:, -3:, -1]
194+
195+
self.assertEqual(image.shape, (1, 32, 48, 3))
196+
expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
197+
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
198+
self.assertLessEqual(max_diff, 1e-3)
199+
184200
def test_inference_with_embeddings_and_multiple_images(self):
185201
components = self.get_dummy_components()
186202
pipe = self.pipeline_class(**components)

0 commit comments

Comments
 (0)