 logger = logging.get_logger(__name__)
+flash_attn_func = None


 def _index_first_axis(tensor, indices):

@@ -92,6 +93,7 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     output[indices] = hidden_states
     return output.view(batch, seqlen, *dim)

+
 FA_VERSION = None
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func as flash_attn_2_func

@@ -135,10 +137,19 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):

 # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
 if is_torch_npu_available():
-    from .integrations.npu_flash_attention import pad_input, unpad_input
-    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa
-    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
-    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+    from .integrations.npu_flash_attention import (
+        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
+    )
+    from .integrations.npu_flash_attention import (
+        npu_flash_attn_func as flash_attn_func,
+    )
+    from .integrations.npu_flash_attention import (
+        npu_flash_attn_varlen_func as flash_attn_varlen_func,
+    )
+    from .integrations.npu_flash_attention import (
+        pad_input,
+        unpad_input,
+    )


 _flash_supports_window_size = False

@@ -279,9 +290,7 @@ def _upad_input(
     else:
         # The -q_len: slice assumes left padding.
         attention_mask = attention_mask[:, -query_length:]
-    query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(
-        query_layer, attention_mask
-    )
+    query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)

     return (
         query_layer,
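The net effect of the diff is a module-level `flash_attn_func = None` default that the platform-specific branches (flash-attn 2/3 or the Ascend NPU shims) rebind when their backend is importable. Below is a minimal, self-contained sketch of that default-then-rebind dispatch pattern, not the transformers implementation itself; the helper names `_sdpa_fallback` and `run_attention` are hypothetical, and the real flash-attn kernel additionally expects half-precision CUDA tensors in (batch, seqlen, heads, head_dim) layout.

import torch.nn.functional as F

flash_attn_func = None  # stays None unless a flash-attention backend imports cleanly

try:
    from flash_attn import flash_attn_func  # real package; only bound when installed
except ImportError:
    pass


def _sdpa_fallback(q, k, v):
    # Hypothetical fallback: plain PyTorch scaled-dot-product attention,
    # which expects (batch, heads, seqlen, head_dim) tensors.
    return F.scaled_dot_product_attention(q, k, v)


def run_attention(q, k, v):
    # Hypothetical dispatcher: branch once on the module-level symbol
    # instead of re-importing inside every forward pass.
    if flash_attn_func is not None:
        # flash-attn wants (batch, seqlen, heads, head_dim) fp16/bf16 tensors on CUDA.
        return flash_attn_func(q, k, v)
    return _sdpa_fallback(q, k, v)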