@@ -181,11 +181,76 @@ def test_inference(self):
181
181
max_diff = np .abs (image_slice .flatten () - expected_slice ).max ()
182
182
self .assertLessEqual (max_diff , 1e-3 )
183
183
184
+ def test_inference_with_embeddings_and_multiple_images (self ):
185
+ components = self .get_dummy_components ()
186
+ pipe = self .pipeline_class (** components )
187
+ pipe .to (torch_device )
188
+ pipe .set_progress_bar_config (disable = None )
189
+
190
+ inputs = self .get_dummy_inputs (torch_device )
191
+
192
+ prompt = inputs ["prompt" ]
193
+ generator = inputs ["generator" ]
194
+ num_inference_steps = inputs ["num_inference_steps" ]
195
+ output_type = inputs ["output_type" ]
196
+
197
+ prompt_embeds , negative_prompt_embeds = pipe .encode_prompt (prompt )
198
+
199
+ # inputs with prompt converted to embeddings
200
+ inputs = {
201
+ "prompt_embeds" : prompt_embeds ,
202
+ "negative_prompt" : None ,
203
+ "negative_prompt_embeds" : negative_prompt_embeds ,
204
+ "generator" : generator ,
205
+ "num_inference_steps" : num_inference_steps ,
206
+ "output_type" : output_type ,
207
+ "num_images_per_prompt" : 2 ,
208
+ }
209
+
210
+ # set all optional components to None
211
+ for optional_component in pipe ._optional_components :
212
+ setattr (pipe , optional_component , None )
213
+
214
+ output = pipe (** inputs )[0 ]
215
+
216
+ with tempfile .TemporaryDirectory () as tmpdir :
217
+ pipe .save_pretrained (tmpdir )
218
+ pipe_loaded = self .pipeline_class .from_pretrained (tmpdir )
219
+ pipe_loaded .to (torch_device )
220
+ pipe_loaded .set_progress_bar_config (disable = None )
221
+
222
+ for optional_component in pipe ._optional_components :
223
+ self .assertTrue (
224
+ getattr (pipe_loaded , optional_component ) is None ,
225
+ f"`{ optional_component } ` did not stay set to None after loading." ,
226
+ )
227
+
228
+ inputs = self .get_dummy_inputs (torch_device )
229
+
230
+ generator = inputs ["generator" ]
231
+ num_inference_steps = inputs ["num_inference_steps" ]
232
+ output_type = inputs ["output_type" ]
233
+
234
+ # inputs with prompt converted to embeddings
235
+ inputs = {
236
+ "prompt_embeds" : prompt_embeds ,
237
+ "negative_prompt" : None ,
238
+ "negative_prompt_embeds" : negative_prompt_embeds ,
239
+ "generator" : generator ,
240
+ "num_inference_steps" : num_inference_steps ,
241
+ "output_type" : output_type ,
242
+ "num_images_per_prompt" : 2 ,
243
+ }
244
+
245
+ output_loaded = pipe_loaded (** inputs )[0 ]
246
+
247
+ max_diff = np .abs (to_np (output ) - to_np (output_loaded )).max ()
248
+ self .assertLess (max_diff , 1e-4 )
249
+
184
250
def test_inference_batch_single_identical (self ):
185
251
self ._test_inference_batch_single_identical (expected_max_diff = 1e-3 )
186
252
187
253
188
- # TODO: needs to be updated.
189
254
@slow
190
255
@require_torch_gpu
191
256
class PixArtAlphaPipelineIntegrationTests (unittest .TestCase ):
0 commit comments