diff --git a/deeplink_ext/internevo_ops/_flash_attention_npu.py b/deeplink_ext/internevo_ops/_flash_attention_npu.py
index 37110b9..2929ee7 100644
--- a/deeplink_ext/internevo_ops/_flash_attention_npu.py
+++ b/deeplink_ext/internevo_ops/_flash_attention_npu.py
@@ -25,6 +25,11 @@ def flash_attn_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
@@ -32,17 +37,13 @@ def flash_attn_func(
     seqlen_k = k.shape[1]
     head_num = q.shape[-2]
-    if seqlen_q == seqlen_k and seqlen_q < 2048 and seqlen_k < 2048:
-        sparse_mode = 0
-    else:
-        sparse_mode = 2
-
-    seqlen_q = min(seqlen_q, 2048)
-    seqlen_k = min(seqlen_k, 2048)
+    assert seqlen_q == seqlen_k, "Npu currently only supports seqlen_q = seqlen_k."
+    sparse_mode = 2 if seqlen_q > 2048 else 0
+    seqlen = min(seqlen_q, 2048)
     attention_mask = (
         torch.triu(
-            torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
+            torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
             diagonal=1,
         )
         if causal
@@ -81,25 +82,28 @@ def flash_attn_varlen_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
-    block_table=None,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     head_num = q.shape[-2]
     cu_seqlens_q = cu_seqlens_q[1:].tolist()
     cu_seqlens_k = cu_seqlens_k[1:].tolist()
-    seqlen_q = min(max_seqlen_q, 2048)
-    seqlen_k = min(max_seqlen_k, 2048)
-    if max_seqlen_q < 2048:
-        sparse_mode = 0
-    else:
-        sparse_mode = 2
+    assert (
+        max_seqlen_q == max_seqlen_k
+    ), "Npu currently only supports max_seqlen_q = max_seqlen_k."
+    sparse_mode = 2 if max_seqlen_q > 2048 else 0
+    max_seqlen = min(max_seqlen_q, 2048)
     attention_mask = (
         torch.triu(
-            torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
+            torch.ones([max_seqlen, max_seqlen], dtype=torch.bool, device=q.device),
             diagonal=1,
         )
         if causal
@@ -114,8 +118,8 @@ def flash_attn_varlen_func(
         "TND",
         atten_mask=attention_mask,
         scale=softmax_scale,
-        pre_tockens=q.shape[0],  # seq_len
-        next_tockens=0,  # 0
+        pre_tockens=q.shape[0],
+        next_tockens=0,
         keep_prob=1 - dropout_p,
         sparse_mode=sparse_mode,
         actual_seq_qlen=cu_seqlens_q,
@@ -134,6 +138,11 @@ def flash_attn_qkvpacked_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = qkv.shape[-1] ** (-0.5)
     q = qkv[:, :, 0]
@@ -143,16 +152,12 @@ def flash_attn_qkvpacked_func(
     seqlen_qkv = qkv.shape[1]
     head_num = q.shape[-2]
-    if seqlen_qkv < 2048:
-        sparse_mode = 0
-    else:
-        sparse_mode = 2
-
-    seqlen_qkv = min(qkv.shape[1], 2048)
+    sparse_mode = 2 if seqlen_qkv > 2048 else 0
+    seqlen = min(seqlen_qkv, 2048)
     attention_mask = (
         torch.triu(
-            torch.ones([seqlen_qkv, seqlen_qkv], dtype=torch.bool, device=q.device),
+            torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
             diagonal=1,
         )
         if causal
@@ -187,26 +192,27 @@ def flash_attn_kvpacked_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     k = kv[:, :, 0]
     v = kv[:, :, 1]
-    s0 = q.shape[1]
-    s1 = kv.shape[1]
+    seqlen_q = q.shape[1]
+    seqlen_kv = kv.shape[1]
     head_num = q.shape[-2]
-    if s0 == s1 and s0 < 2048 and s1 < 2048:
-        sparse_mode = 0
-    else:
-        sparse_mode = 2
-
-    seqlen_q = min(s0, 2048)
-    seqlen_k = min(s1, 2048)
+    assert seqlen_q == seqlen_kv, "Npu currently only supports seqlen_q = seqlen_kv."
+    sparse_mode = 2 if seqlen_q > 2048 else 0
+    seqlen = min(seqlen_q, 2048)
     attention_mask = (
         torch.triu(
-            torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
+            torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
             diagonal=1,
         )
         if causal
@@ -222,7 +228,7 @@ def flash_attn_kvpacked_func(
         atten_mask=attention_mask,
         scale=softmax_scale,
         keep_prob=1 - dropout_p,
-        pre_tockens=seqlen_k,
+        pre_tockens=seqlen_q,
         next_tockens=0,
         sparse_mode=sparse_mode,
     )[0]
@@ -242,37 +248,42 @@ def flash_attn_varlen_qkvpacked_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = qkv.shape[-1] ** (-0.5)
     q = qkv[:, 0]
     k = qkv[:, 1]
     v = qkv[:, 2]
-    n = q.shape[1]
-    if max_seqlen > 2048:
-        sparse_mode = 2
-    else:
-        sparse_mode = 0
+    head_num = q.shape[1]
+
     cu_seqlens_q = cu_seqlens[1:].tolist()
     cu_seqlens_k = cu_seqlens[1:].tolist()
-    seqlen = min(max_seqlen, 2048)
+
+    sparse_mode = 2 if max_seqlen > 2048 else 0
+    max_seqlen = min(max_seqlen, 2048)
     attention_mask = (
         torch.triu(
-            torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device),
+            torch.ones([max_seqlen, max_seqlen], dtype=torch.bool, device=q.device),
             diagonal=1,
         )
         if causal
         else None
     )
+
     out = torch_npu.npu_fusion_attention(
         q,
         k,
         v,
-        n,
+        head_num,
         "TND",
         atten_mask=attention_mask,
         scale=softmax_scale,
-        pre_tockens=q.shape[0],  # seq_len
-        next_tockens=0,  # 0
+        pre_tockens=q.shape[0],
+        next_tockens=0,
         keep_prob=1 - dropout_p,
         sparse_mode=sparse_mode,
         actual_seq_qlen=cu_seqlens_q,
@@ -296,39 +307,44 @@ def flash_attn_varlen_kvpacked_func(
     deterministic=False,
     return_attn_probs=False,
 ):
+    assert window_size == (
+        -1,
+        -1,
+    ), "Npu currently does not support sliding window attention"
+    assert alibi_slopes is None, "Npu currently does not support ALiBi."
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
     k = kv[:, 0]
     v = kv[:, 1]
-    n = q.shape[1]
+    head_num = q.shape[1]
     cu_seqlens_q = cu_seqlens_q[1:].tolist()
     cu_seqlens_k = cu_seqlens_k[1:].tolist()
-    seqlen_q = min(max_seqlen_q, 2048)
-    seqlen_k = min(max_seqlen_k, 2048)
-    if max_seqlen_q > 2048:
-        sparse_mode = 2
-    else:
-        sparse_mode = 0
+    assert (
+        max_seqlen_q == max_seqlen_k
+    ), "Npu currently only supports max_seqlen_q = max_seqlen_k."
+    sparse_mode = 2 if max_seqlen_q > 2048 else 0
+    max_seqlen = min(max_seqlen_q, 2048)
     attention_mask = (
         torch.triu(
-            torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device),
+            torch.ones([max_seqlen, max_seqlen], dtype=torch.bool, device=q.device),
             diagonal=1,
         )
         if causal
         else None
     )
+
     out = torch_npu.npu_fusion_attention(
         q,
         k,
         v,
-        n,
+        head_num,
         "TND",
         atten_mask=attention_mask,
         scale=softmax_scale,
-        pre_tockens=q.shape[0],  # seq_len
-        next_tockens=0,  # 0
+        pre_tockens=q.shape[0],
+        next_tockens=0,
         keep_prob=1 - dropout_p,
         sparse_mode=sparse_mode,
         actual_seq_qlen=cu_seqlens_q,
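Reviewer note: the new asserts narrow every NPU flash-attention entry point to plain causal attention with equal query/key lengths, no sliding window, and no ALiBi, and the causal mask is capped at 2048x2048 with sparse_mode switching to 2 for longer sequences. Below is a minimal PyTorch sketch of the attention these constrained paths are expected to reproduce; the helper name is illustrative and not part of this patch, and it is only meant for spot-checking npu_fusion_attention outputs on small shapes.

import torch

def naive_causal_attention(q, k, v, softmax_scale=None, causal=True):
    # q, k, v: [batch, seqlen, num_heads, head_dim]; identical seqlen for q and k,
    # mirroring the seqlen_q == seqlen_k assert in the patched functions.
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    scores = torch.einsum("bshd,bthd->bhst", q, k) * softmax_scale
    if causal:
        seqlen = q.shape[1]
        # True entries above the diagonal are masked out, matching the
        # torch.triu(..., diagonal=1) mask built in the patch.
        mask = torch.triu(
            torch.ones(seqlen, seqlen, dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", probs, v)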
1 (2 d)") - rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin) - ctx.save_for_backward(cat_cos, cat_sin) + ctx.save_for_backward(cos, sin) ctx.interleaved = interleaved ctx.in_place = in_place - if in_place: - x[..., :rotary_dim].copy_(rot) - return x + + if interleaved: + x_ro = x[..., :rotary_dim] + out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 1) + if in_place: + x[..., :rotary_dim].copy_(out_ro) + return x + if rotary_dim < head_dim: + out = torch.empty_like(x) + out[..., :rotary_dim].copy_(out_ro) + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + return out + return out_ro else: - out = x.detach().clone() - if rotary_dim < head_dim and not in_place: + x_ro = x[..., :rotary_dim] + out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 0) + if in_place: + x[..., :rotary_dim].copy_(out_ro) + return x + if rotary_dim < head_dim: + out = torch.empty_like(x) + out[..., :rotary_dim].copy_(out_ro) out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - return out + return out + return out_ro @staticmethod - def backward(ctx, do): - cat_cos, cat_sin = ctx.saved_tensors - *_, seqlen, _, head_dim = do.shape - rotary_dim = cat_cos.shape[-1] + def backward(ctx, grad_out): + cos, sin = ctx.saved_tensors + rotary_dim = cos.shape[-1] + head_dim = grad_out.shape[-1] - dx_out = torch_npu.npu_rotary_mul( - do[..., :rotary_dim], cat_cos, torch.neg(cat_sin) - ) - if ctx.in_place: - do[..., :rotary_dim].copy_(dx_out) - return do, None, None, None, None + if ctx.interleaved: + grad_out_ro = grad_out[..., :rotary_dim] + grad_input_ro = npu_rotary_position_embedding( + grad_out_ro, cos, torch.neg(sin), 1 + ) + if ctx.in_place: + grad_out[..., :rotary_dim].copy_(grad_input_ro) + return grad_out, None, None, None, None + if rotary_dim < head_dim: + grad_input = torch.empty_like(grad_out) + grad_input[..., :rotary_dim].copy_(grad_input_ro) + grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:]) + return grad_input, None, None, None, None + return grad_input_ro, None, None, None, None else: - dx = do.detach().clone() - dx[..., :rotary_dim].copy_(dx_out) - return dx, None, None, None, None + grad_out_ro = grad_out[..., :rotary_dim] + grad_input_ro = npu_rotary_position_embedding( + grad_out_ro, cos, torch.neg(sin), 0 + ) + if ctx.in_place: + grad_out[..., :rotary_dim].copy_(grad_input_ro) + return grad_out, None, None, None, None + if rotary_dim < head_dim: + grad_input = torch.empty_like(grad_out) + grad_input[..., :rotary_dim].copy_(grad_input_ro) + grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:]) + return grad_input, None, None, None, None + return grad_input_ro, None, None, None, None diff --git a/deeplink_ext/internevo_ops/rotary_embedding.py b/deeplink_ext/internevo_ops/rotary_embedding.py index 1a2a36d..7764b9b 100644 --- a/deeplink_ext/internevo_ops/rotary_embedding.py +++ b/deeplink_ext/internevo_ops/rotary_embedding.py @@ -4,8 +4,7 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - # from ._rotary_embedding_npu import ApplyRotaryEmb - from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb + from ._rotary_embedding_npu import ApplyRotaryEmb elif platform_type == PlatformType.TORCH_DIPU: from ._rotary_embedding_dipu import ApplyRotaryEmb else: diff --git a/deeplink_ext/interntrain_ops/rms_norm.py b/deeplink_ext/interntrain_ops/rms_norm.py index 301ab9e..e6834cb 100644 --- a/deeplink_ext/interntrain_ops/rms_norm.py +++ b/deeplink_ext/interntrain_ops/rms_norm.py @@ -4,13 +4,9 @@ platform_type 
diff --git a/deeplink_ext/interntrain_ops/rms_norm.py b/deeplink_ext/interntrain_ops/rms_norm.py
index 301ab9e..e6834cb 100644
--- a/deeplink_ext/interntrain_ops/rms_norm.py
+++ b/deeplink_ext/interntrain_ops/rms_norm.py
@@ -4,13 +4,9 @@ platform_type = deeplink_ext_get_platform_type()
 if platform_type == PlatformType.TORCH_NPU:
-    # from ._mixed_rms_norm_npu import MixedFusedRMSNorm
-    # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative.
-    from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
+    from ._mixed_rms_norm_npu import MixedFusedRMSNorm
 elif platform_type == PlatformType.TORCH_DIPU:
-    # from ._mixed_rms_norm_dipu import MixedFusedRMSNorm
-    # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative.
-    from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
+    from ._mixed_rms_norm_dipu import MixedFusedRMSNorm
 else:
     raise ImportError
diff --git a/tests/internevo/test_rotary_embedding.py b/tests/internevo/test_rotary_embedding.py
index 981c2f0..a03bc95 100644
--- a/tests/internevo/test_rotary_embedding.py
+++ b/tests/internevo/test_rotary_embedding.py
@@ -8,40 +8,41 @@ def test_ApplyRotaryEmb():
     input_dtype_list = [torch.float16, torch.bfloat16]
-    interleaved = False
     in_place_options = [False, True]
+    interleaved_options = [False, True]
     for input_dtype in input_dtype_list:
         for in_place in in_place_options:
-            input_ref = torch.randn(
-                1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
-            )
-            input_ext = input_ref.clone().detach().requires_grad_()
-            cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
-            sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")
+            for interleaved in interleaved_options:
+                input_ref = torch.randn(
+                    1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True
+                )
+                input_ext = input_ref.clone().detach().requires_grad_()
+                cos = torch.randn(64, 32, dtype=input_dtype, device="cuda")
+                sin = torch.randn(64, 32, dtype=input_dtype, device="cuda")
 
-            output_ref, grad_ref = call_autograd_func(
-                ApplyRotaryEmbTorch,
-                "cuda",
-                input_dtype,
-                input_ref,
-                cos,
-                sin,
-                interleaved,
-                in_place,
-            )
-            output_ext, grad_ext = call_autograd_func(
-                ApplyRotaryEmb,
-                "cuda",
-                input_dtype,
-                input_ext,
-                cos,
-                sin,
-                interleaved,
-                in_place,
-            )
-            assert allclose(
-                output_ref, output_ext, rtol=1e-2, atol=5e-2
-            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!"
-            assert allclose(
-                grad_ref, grad_ext
-            ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!"
+                output_ref, grad_ref = call_autograd_func(
+                    ApplyRotaryEmbTorch,
+                    "cuda",
+                    input_dtype,
+                    input_ref,
+                    cos,
+                    sin,
+                    interleaved,
+                    in_place,
+                )
+                output_ext, grad_ext = call_autograd_func(
+                    ApplyRotaryEmb,
+                    "cuda",
+                    input_dtype,
+                    input_ext,
+                    cos,
+                    sin,
+                    interleaved,
+                    in_place,
+                )
+                assert allclose(
+                    output_ref, output_ext
+                ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!"
+                assert allclose(
+                    grad_ref, grad_ext
+                ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!"
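Reviewer note: rms_norm.py now dispatches to the fused MixedFusedRMSNorm implementations instead of the torch fallback that had been kept around for accuracy reasons. If that accuracy question comes up again, a compact RMSNorm reference like the sketch below is enough to re-check the fused operator against the fallback; the function name and eps value are illustrative, not taken from this patch.

import torch

def rms_norm_reference(x, weight, eps=1e-5):
    # Reference RMSNorm: normalize by the root-mean-square over the last
    # dimension (computed in float32 for stability), then scale by weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = x.float() * torch.rsqrt(variance + eps)
    return (normed * weight.float()).to(x.dtype)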