
Commit 9896f0d

jeffrey-dot-lizaristei authored and committed
Fix: explicit not none check for tensors in flash attention (huggingface#39639)
fix: explicit not none check for tensors
1 parent ccca85e · commit 9896f0d

File tree

1 file changed: +1, -1 lines changed


src/transformers/modeling_flash_attention_utils.py

Lines changed: 1 addition & 1 deletion
@@ -398,7 +398,7 @@ def _flash_attention_forward(
     query_states, key_states, value_states = fa_peft_integration_check(
         query_states, key_states, value_states, target_dtype
     )
-    use_mask = position_ids is not None or all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])
+    use_mask = position_ids is not None or all(k is not None for k in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])
     if attention_mask is not None:
         q, k, v, idx, (cu_q, cu_k), (mq, mk) = _upad_input(
             query_states, key_states, value_states, attention_mask, query_length, unpad_fn
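Why the explicit is-not-None check matters: all([...]) evaluates the truthiness of each element, and calling bool() on a multi-element tensor raises a RuntimeError ("Boolean value of Tensor with more than one element is ambiguous"); an integer value of 0 would also wrongly count as missing. Below is a minimal standalone sketch (not part of the commit; the tensor values are made up for illustration) showing the failure of the old check and the behavior of the new one:

    import torch

    # Hypothetical values standing in for the real sequence-length metadata.
    cu_seq_lens_q = torch.tensor([0, 3, 7])  # cumulative sequence lengths (multi-element tensor)
    cu_seq_lens_k = torch.tensor([0, 3, 7])
    max_length_q, max_length_k = 4, 4

    # Old check: all([...]) calls bool() on each element; bool() on a
    # multi-element tensor raises a RuntimeError.
    try:
        all([cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])
    except RuntimeError as err:
        print(err)  # Boolean value of Tensor with more than one element is ambiguous

    # New check: compare each value against None only, never evaluating
    # tensor truthiness.
    use_mask = all(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k])
    print(use_mask)  # True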

0 commit comments