Commit 93f2984

ci fix

1 parent 18dbd84 commit 93f2984

32 files changed, +62 -26 lines changed

src/transformers/modeling_flash_attention_utils.py

Lines changed: 16 additions & 7 deletions

@@ -30,6 +30,7 @@
 
 
 logger = logging.get_logger(__name__)
+flash_attn_func = None
 
 
 def _index_first_axis(tensor, indices):
@@ -92,6 +93,7 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     output[indices] = hidden_states
     return output.view(batch, seqlen, *dim)
 
+
 FA_VERSION = None
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func as flash_attn_2_func
@@ -135,10 +137,19 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
 
 # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
 if is_torch_npu_available():
-    from .integrations.npu_flash_attention import pad_input, unpad_input
-    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa
-    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
-    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+    from .integrations.npu_flash_attention import (
+        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
+    )
+    from .integrations.npu_flash_attention import (
+        npu_flash_attn_func as flash_attn_func,
+    )
+    from .integrations.npu_flash_attention import (
+        npu_flash_attn_varlen_func as flash_attn_varlen_func,
+    )
+    from .integrations.npu_flash_attention import (
+        pad_input,
+        unpad_input,
+    )
 
 
 _flash_supports_window_size = False
@@ -279,9 +290,7 @@ def _upad_input(
     else:
         # The -q_len: slice assumes left padding.
         attention_mask = attention_mask[:, -query_length:]
-        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(
-            query_layer, attention_mask
-        )
+        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
 
     return (
         query_layer,
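
The new module-level default `flash_attn_func = None` guarantees the name exists even when no flash-attention backend can be imported, so callers can test the sentinel instead of catching an ImportError. Below is a minimal standalone sketch of that guarded-import pattern; it simplifies the file's actual control flow, and the `run_attention` wrapper and its error message are illustrative, not part of this commit.

# Sketch of the guarded-import pattern: bind a None fallback first,
# then overwrite it only when a real backend can be imported.
from transformers.utils import is_flash_attn_2_available

flash_attn_func = None
FA_VERSION = None

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func as flash_attn_2_func

    flash_attn_func = flash_attn_2_func
    FA_VERSION = 2


def run_attention(q, k, v, causal=True):
    # Illustrative caller: check the sentinel instead of guessing whether the import ran.
    if flash_attn_func is None:
        raise RuntimeError("No flash-attention backend is installed.")
    return flash_attn_func(q, k, v, causal=causal)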

src/transformers/modeling_utils.py

Lines changed: 2 additions & 4 deletions

@@ -2549,7 +2549,7 @@ def _check_and_enable_flash_attn_3(
         else:
             raise ValueError(
                 f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
-                )
+            )
 
         _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
 
@@ -2570,9 +2570,7 @@ def _check_and_enable_flash_attn_3(
             )
 
         if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
-            raise ValueError(
-                "Model is configured to use ALiBi, which is not supported by Flash Attention 3."
-            )
+            raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
 
         # Check for attention dropout, which is incompatible with FA3
         if hasattr(config, "attention_dropout") and config.attention_dropout > 0:
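
Both hunks above are pure reformatting; the validation logic of `_check_and_enable_flash_attn_3` is unchanged. For readers following along, here is a condensed, hypothetical sketch of the kind of precondition checking visible in these hunks; the standalone function name and the dropout error message are assumptions, not taken from this diff.

import torch


def check_flash_attn_3_config(config, preface="FlashAttention-3:"):
    # Condensed sketch of the FA3 precondition checks shown in the hunks above.
    if not torch.cuda.is_available():
        raise ValueError(
            f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
        )
    if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
        raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
    # Attention dropout is incompatible with FA3 (illustrative error message).
    if getattr(config, "attention_dropout", 0) > 0:
        raise ValueError(f"{preface} Flash Attention 3 does not support attention_dropout > 0.")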

src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 0 deletions

@@ -667,6 +667,7 @@ class AriaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["AriaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
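
The remaining model files below repeat this one-line change: each `*PreTrainedModel` class advertises FA3 support by adding `_supports_flash_attn_3` alongside the existing `_supports_flash_attn_2` flag. A hypothetical minimal class showing the opt-in pattern follows (the class name is made up; assuming the selector mirrors the FA2 naming, a user would then pass something like `attn_implementation="flash_attention_3"` to `from_pretrained`, though that string does not appear in this diff).

from transformers import PreTrainedModel


class MyPreTrainedModel(PreTrainedModel):
    # Capability flags consulted by the attention-implementation checks in modeling_utils.py.
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True  # new in this commit: declare FA3 support
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True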

src/transformers/models/bitnet/modeling_bitnet.py

Lines changed: 1 addition & 0 deletions

@@ -318,6 +318,7 @@ class BitNetPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["BitNetDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 0 deletions

@@ -355,6 +355,7 @@ class CoherePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["CohereDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 1 addition & 0 deletions

@@ -334,6 +334,7 @@ class Cohere2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Cohere2DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Lines changed: 1 addition & 0 deletions

@@ -504,6 +504,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DeepseekV3DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

src/transformers/models/diffllama/modeling_diffllama.py

Lines changed: 1 addition & 0 deletions

@@ -556,6 +556,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DiffLlamaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = False

src/transformers/models/gemma/modeling_gemma.py

Lines changed: 1 addition & 0 deletions

@@ -318,6 +318,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["GemmaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 1 addition & 0 deletions

@@ -339,6 +339,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["Gemma2DecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_flex_attn = True
