Commit d5ff8f8

sayakpaul authored and patrickvonplaten committed

[PixArt-Alpha] fix mask_feature so that precomputed embeddings work with a batch size > 1 (#5677)

* fix embeds
* remove todo
* add: test
* better name

1 parent b4ca05f commit d5ff8f8
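
For context, a minimal sketch of the user-facing scenario this commit fixes: precomputing the text embeddings once and reusing them while requesting more than one image per prompt. The checkpoint id, dtype, and prompt below are illustrative assumptions, not part of the commit; encode_prompt returning (prompt_embeds, negative_prompt_embeds) matches its use in the test added here.

import torch
from diffusers import PixArtAlphaPipeline

# Checkpoint id and dtype are illustrative assumptions.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Precompute the text embeddings once...
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt("an astronaut riding a horse")

# ...then reuse them while asking for more than one image per prompt,
# the combination this commit fixes.
images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt=None,
    negative_prompt_embeds=negative_prompt_embeds,
    num_images_per_prompt=2,
).images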

File tree

2 files changed: +67 −2 lines


src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
Lines changed: 1 addition & 1 deletion

@@ -253,7 +253,7 @@ def encode_prompt(
             negative_prompt_embeds = None
 
         # Perform additional masking.
-        if mask_feature:
+        if mask_feature and prompt_embeds is None and negative_prompt_embeds is None:
             prompt_embeds = prompt_embeds.unsqueeze(1)
             masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
             masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
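
The guard's intent, sketched outside the library with hypothetical names (apply_extra_masking and embeds_were_provided are not diffusers APIs): the extra masking depends on the attention mask produced by the pipeline's own tokenization, which does not exist for caller-supplied embeddings, so those must bypass the branch.

import torch

def apply_extra_masking(prompt_embeds, attention_mask, mask_feature, embeds_were_provided):
    # Illustrative stand-in for the pipeline's masking step: trim padded
    # token positions using the tokenizer's attention mask. The mask is only
    # available when the pipeline encoded the prompt itself, so precomputed
    # embeddings skip this branch.
    if mask_feature and not embeds_were_provided:
        keep = int(attention_mask.sum(dim=-1).max())  # longest unpadded prompt in the batch
        return prompt_embeds[:, :keep, :]
    return prompt_embeds  # precomputed embeddings pass through unchanged

# Two prompts padded to 8 tokens, embedding dim 4.
embeds = torch.randn(2, 8, 4)
mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0, 0, 0]])
print(apply_extra_masking(embeds, mask, True, False).shape)  # torch.Size([2, 5, 4])
print(apply_extra_masking(embeds, mask, True, True).shape)   # torch.Size([2, 8, 4])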

tests/pipelines/pixart/test_pixart.py
Lines changed: 66 additions & 1 deletion

@@ -181,11 +181,76 @@ def test_inference(self):
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
 
+    def test_inference_with_embeddings_and_multiple_images(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        prompt = inputs["prompt"]
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "negative_prompt": None,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "num_images_per_prompt": 2,
+        }
+
+        # set all optional components to None
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        generator = inputs["generator"]
+        num_inference_steps = inputs["num_inference_steps"]
+        output_type = inputs["output_type"]
+
+        # inputs with prompt converted to embeddings
+        inputs = {
+            "prompt_embeds": prompt_embeds,
+            "negative_prompt": None,
+            "negative_prompt_embeds": negative_prompt_embeds,
+            "generator": generator,
+            "num_inference_steps": num_inference_steps,
+            "output_type": output_type,
+            "num_images_per_prompt": 2,
+        }
+
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+        self.assertLess(max_diff, 1e-4)
+
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)
 
 
-# TODO: needs to be updated.
 @slow
 @require_torch_gpu
 class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
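
The new test round-trips the pipeline through save_pretrained/from_pretrained with precomputed embeddings and num_images_per_prompt=2, then asserts the outputs match. As a usage hint, it should be runnable in isolation from a diffusers development checkout with something like: python -m pytest tests/pipelines/pixart/test_pixart.py -k test_inference_with_embeddings_and_multiple_images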
