@@ -636,6 +636,11 @@ def __call__(
636
636
if self .attention_kwargs is None :
637
637
self ._attention_kwargs = {}
638
638
639
+ txt_seq_lens = prompt_embeds_mask .sum (dim = 1 ).tolist () if prompt_embeds_mask is not None else None
640
+ negative_txt_seq_lens = (
641
+ negative_prompt_embeds_mask .sum (dim = 1 ).tolist () if negative_prompt_embeds_mask is not None else None
642
+ )
643
+
639
644
# 6. Denoising loop
640
645
self .scheduler .set_begin_index (0 )
641
646
with self .progress_bar (total = num_inference_steps ) as progress_bar :
@@ -654,7 +659,7 @@ def __call__(
654
659
encoder_hidden_states_mask = prompt_embeds_mask ,
655
660
encoder_hidden_states = prompt_embeds ,
656
661
img_shapes = img_shapes ,
657
- txt_seq_lens = prompt_embeds_mask . sum ( dim = 1 ). tolist () ,
662
+ txt_seq_lens = txt_seq_lens ,
658
663
attention_kwargs = self .attention_kwargs ,
659
664
return_dict = False ,
660
665
)[0 ]
@@ -668,7 +673,7 @@ def __call__(
668
673
encoder_hidden_states_mask = negative_prompt_embeds_mask ,
669
674
encoder_hidden_states = negative_prompt_embeds ,
670
675
img_shapes = img_shapes ,
671
- txt_seq_lens = negative_prompt_embeds_mask . sum ( dim = 1 ). tolist () ,
676
+ txt_seq_lens = negative_txt_seq_lens ,
672
677
attention_kwargs = self .attention_kwargs ,
673
678
return_dict = False ,
674
679
)[0 ]
0 commit comments