
Commit 2a82cf0

make fixup (#39661)

1 parent e376050 · commit 2a82cf0

File tree

1 file changed: +3 −1 lines changed


src/transformers/modeling_flash_attention_utils.py

Lines changed: 3 additions & 1 deletion
@@ -393,7 +393,9 @@ def _flash_attention_forward(
     query_states, key_states, value_states = fa_peft_integration_check(
         query_states, key_states, value_states, target_dtype
     )
-    use_mask = position_ids is not None or all(k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])
+    use_mask = position_ids is not None or all(
+        k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]
+    )
     if attention_mask is not None:
         q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
             query_states, key_states, value_states, attention_mask, query_length, unpad_fn
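For context, the commit only re-wraps the long boolean expression; the logic is unchanged. The following is a minimal, self-contained sketch of that condition, assuming a hypothetical helper named should_use_mask and made-up sample values (neither appears in the transformers codebase): the expression is true when position_ids is supplied, or when all of the varlen metadata (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) is provided.

# Illustrative sketch of the reformatted condition; `should_use_mask` and the
# sample values are hypothetical and not part of the transformers codebase.
def should_use_mask(position_ids, cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k):
    # True if position_ids are given, or if every piece of varlen metadata is present.
    return position_ids is not None or all(
        k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]
    )

# All varlen metadata supplied, no position_ids -> True
print(should_use_mask(None, [0, 3, 7], [0, 3, 7], 4, 4))
# No position_ids and no metadata -> False
print(should_use_mask(None, None, None, None, None))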
