Skip to content

Commit fa4c0e5

Browse files
authored
optimize QwenImagePipeline to reduce unnecessary CUDA synchronization (#12072)
1 parent b793deb commit fa4c0e5

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,11 @@ def __call__(
636636
if self.attention_kwargs is None:
637637
self._attention_kwargs = {}
638638

639+
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
640+
negative_txt_seq_lens = (
641+
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
642+
)
643+
639644
# 6. Denoising loop
640645
self.scheduler.set_begin_index(0)
641646
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -654,7 +659,7 @@ def __call__(
654659
encoder_hidden_states_mask=prompt_embeds_mask,
655660
encoder_hidden_states=prompt_embeds,
656661
img_shapes=img_shapes,
657-
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
662+
txt_seq_lens=txt_seq_lens,
658663
attention_kwargs=self.attention_kwargs,
659664
return_dict=False,
660665
)[0]
@@ -668,7 +673,7 @@ def __call__(
668673
encoder_hidden_states_mask=negative_prompt_embeds_mask,
669674
encoder_hidden_states=negative_prompt_embeds,
670675
img_shapes=img_shapes,
671-
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
676+
txt_seq_lens=negative_txt_seq_lens,
672677
attention_kwargs=self.attention_kwargs,
673678
return_dict=False,
674679
)[0]

0 commit comments

Comments
 (0)