@@ -132,7 +132,12 @@ def qwen2_model_forward(
         else:
             position_ids = position_ids.view(-1, seq_length).long()
 
-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+        if (
+            not shard_config.enable_flash_attention
+            and attention_mask is not None
+            and self._attn_implementation == "flash_attention_2"
+            and use_cache
+        ):
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -144,7 +149,6 @@ def qwen2_model_forward(
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = None
         else:
             if self._attn_implementation == "flash_attention_2":
@@ -616,7 +620,7 @@ def forward(
 
         attn_output = self.o_proj(attn_output)
 
-        return attn_output, None, past_key_value
+        return attn_output, None
 
     return forward
 
@@ -805,15 +809,7 @@ def forward(
         hidden_states = inputs_embeds
 
         if shard_config.enable_flash_attention:
-            # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            attention_mask = None
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,