From 49d1a9da35b696572212adbfe8d8b3006258bae6 Mon Sep 17 00:00:00 2001 From: POI-WX Date: Mon, 21 Oct 2024 07:28:36 +0000 Subject: [PATCH 1/9] fix bug of rope for npu and support the interleaved option --- .../internevo_ops/_rotary_embedding_npu.py | 90 +++++++++++++------ .../internevo_ops/rotary_embedding.py | 3 +- tests/internevo/test_rotary_embedding.py | 67 +++++++------- 3 files changed, 96 insertions(+), 64 deletions(-) diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py index 4e27d04..d9cc45a 100644 --- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py +++ b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py @@ -2,7 +2,6 @@ import torch import torch_npu -from einops import rearrange __all__ = ["ApplyRotaryEmb"] @@ -38,38 +37,71 @@ def forward( assert seqlen <= rotary_seqlen assert sin.shape == (rotary_seqlen, rotary_dim // 2) - re_cos = rearrange(cos[:seqlen], "s d -> s 1 d") - re_sin = rearrange(sin[:seqlen], "s d -> s 1 d") - - cat_cos = torch.cat([re_cos, re_cos], -1) - cat_sin = torch.cat([re_sin, re_sin], -1) - - rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin) - ctx.save_for_backward(cat_cos, cat_sin) + # "s d -> 1 s 1 d" + cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) + sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) + x_ro = x[..., :rotary_dim] + ctx.save_for_backward(cos, sin) ctx.interleaved = interleaved ctx.in_place = in_place - if in_place: - x[..., :rotary_dim].copy_(rot) - return x - else: - out = x.detach().clone() - if rotary_dim < head_dim and not in_place: + if interleaved: + x_in = torch.cat([x_ro[..., ::2], x_ro[..., 1::2]], dim=-1) + out_ro = torch_npu.npu_rotary_mul(x_in, cos, sin) + if in_place: + x_ro[..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)]) + x_ro[..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :]) + return x + out = torch.empty_like(x) + out[..., :rotary_dim][..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)]) + out[..., :rotary_dim][..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :]) + if rotary_dim < head_dim: out[..., rotary_dim:].copy_(x[..., rotary_dim:]) return out + else: + out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin) + if in_place: + x[..., :rotary_dim].copy_(out_ro) + return x + if rotary_dim < head_dim: + out = torch.empty_like(x) + out[..., :rotary_dim].copy_(out_ro) + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + return out + return out_ro @staticmethod - def backward(ctx, do): - cat_cos, cat_sin = ctx.saved_tensors - *_, seqlen, _, head_dim = do.shape - rotary_dim = cat_cos.shape[-1] - - dx_out = torch_npu.npu_rotary_mul( - do[..., :rotary_dim], cat_cos, torch.neg(cat_sin) - ) - if ctx.in_place: - do[..., :rotary_dim].copy_(dx_out) - return do, None, None, None, None + def backward(ctx, grad_out): + cos, sin = ctx.saved_tensors + rotary_dim = cos.shape[-1] + head_dim = grad_out.shape[-1] + grad_out_ro = grad_out[..., :rotary_dim] + if ctx.interleaved: + grad_out_in = torch.cat( + [grad_out_ro[..., ::2], grad_out_ro[..., 1::2]], dim=-1 + ) + grad_input_ro = torch_npu.npu_rotary_mul(grad_out_in, cos, torch.neg(sin)) + if ctx.in_place: + grad_out_ro[..., ::2].copy_(grad_input_ro[..., : int(rotary_dim / 2)]) + grad_out_ro[..., 1::2].copy_(grad_input_ro[..., int(rotary_dim / 2) :]) + return grad_out, None, None, None, None + grad_input = torch.empty_like(grad_out) + grad_input[..., :rotary_dim][..., ::2].copy_( + grad_input_ro[..., : int(rotary_dim / 2)] + ) + grad_input[..., :rotary_dim][..., 
1::2].copy_( + grad_input_ro[..., int(rotary_dim / 2) :] + ) + if rotary_dim < head_dim: + grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:]) + return grad_input, None, None, None, None else: - dx = do.detach().clone() - dx[..., :rotary_dim].copy_(dx_out) - return dx, None, None, None, None + grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin)) + if ctx.in_place: + grad_out[..., :rotary_dim].copy_(grad_input_ro) + return grad_out, None, None, None, None + if rotary_dim < head_dim: + grad_input = torch.empty_like(grad_out) + grad_input[..., :rotary_dim].copy_(grad_input_ro) + grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:]) + return grad_input, None, None, None, None + return grad_input_ro, None, None, None, None diff --git a/deeplink_ext/internevo_ops/rotary_embedding.py b/deeplink_ext/internevo_ops/rotary_embedding.py index 1a2a36d..7764b9b 100644 --- a/deeplink_ext/internevo_ops/rotary_embedding.py +++ b/deeplink_ext/internevo_ops/rotary_embedding.py @@ -4,8 +4,7 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - # from ._rotary_embedding_npu import ApplyRotaryEmb - from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb + from ._rotary_embedding_npu import ApplyRotaryEmb elif platform_type == PlatformType.TORCH_DIPU: from ._rotary_embedding_dipu import ApplyRotaryEmb else: diff --git a/tests/internevo/test_rotary_embedding.py b/tests/internevo/test_rotary_embedding.py index 981c2f0..722577e 100644 --- a/tests/internevo/test_rotary_embedding.py +++ b/tests/internevo/test_rotary_embedding.py @@ -8,40 +8,41 @@ def test_ApplyRotaryEmb(): input_dtype_list = [torch.float16, torch.bfloat16] - interleaved = False in_place_options = [False, True] + interleaved_options = [False, True] for input_dtype in input_dtype_list: for in_place in in_place_options: - input_ref = torch.randn( - 1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True - ) - input_ext = input_ref.clone().detach().requires_grad_() - cos = torch.randn(64, 32, dtype=input_dtype, device="cuda") - sin = torch.randn(64, 32, dtype=input_dtype, device="cuda") + for interleaved in interleaved_options: + input_ref = torch.randn( + 1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True + ) + input_ext = input_ref.clone().detach().requires_grad_() + cos = torch.randn(64, 32, dtype=input_dtype, device="cuda") + sin = torch.randn(64, 32, dtype=input_dtype, device="cuda") - output_ref, grad_ref = call_autograd_func( - ApplyRotaryEmbTorch, - "cuda", - input_dtype, - input_ref, - cos, - sin, - interleaved, - in_place, - ) - output_ext, grad_ext = call_autograd_func( - ApplyRotaryEmb, - "cuda", - input_dtype, - input_ext, - cos, - sin, - interleaved, - in_place, - ) - assert allclose( - output_ref, output_ext, rtol=1e-2, atol=5e-2 - ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!" - assert allclose( - grad_ref, grad_ext - ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!" 
+ output_ref, grad_ref = call_autograd_func( + ApplyRotaryEmbTorch, + "cuda", + input_dtype, + input_ref, + cos, + sin, + interleaved, + in_place, + ) + output_ext, grad_ext = call_autograd_func( + ApplyRotaryEmb, + "cuda", + input_dtype, + input_ext, + cos, + sin, + interleaved, + in_place, + ) + assert allclose( + output_ref, output_ext, rtol=1e-2, atol=5e-2 + ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!" + assert allclose( + grad_ref, grad_ext + ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!" From 0eebc4b41bdf595aee565ae5c9ad397dad8ba918 Mon Sep 17 00:00:00 2001 From: POI-WX Date: Tue, 22 Oct 2024 04:19:22 +0000 Subject: [PATCH 2/9] optimize when interleaved is True --- .../internevo_ops/_rotary_embedding_npu.py | 82 ++++++++++++------- 1 file changed, 52 insertions(+), 30 deletions(-) diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py index d9cc45a..28df9aa 100644 --- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py +++ b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py @@ -2,10 +2,44 @@ import torch import torch_npu +from einops import rearrange, repeat __all__ = ["ApplyRotaryEmb"] +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 class ApplyRotaryEmb(torch.autograd.Function): """ @@ -37,27 +71,25 @@ def forward( assert seqlen <= rotary_seqlen assert sin.shape == (rotary_seqlen, rotary_dim // 2) - # "s d -> 1 s 1 d" - cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) - sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) - x_ro = x[..., :rotary_dim] + if interleaved: + cos = cos[:seqlen] + sin = sin[:seqlen] + else: + # "s d -> 1 s 1 d" + cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) + sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) ctx.save_for_backward(cos, sin) ctx.interleaved = interleaved ctx.in_place = in_place if interleaved: - x_in = torch.cat([x_ro[..., ::2], x_ro[..., 1::2]], dim=-1) - out_ro = torch_npu.npu_rotary_mul(x_in, cos, sin) + out = apply_rotary_emb_torch(x, cos, sin, interleaved) if in_place: - x_ro[..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)]) - x_ro[..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :]) + x.copy_(out) return x - out = torch.empty_like(x) - out[..., :rotary_dim][..., ::2].copy_(out_ro[..., : int(rotary_dim / 2)]) - out[..., :rotary_dim][..., 1::2].copy_(out_ro[..., int(rotary_dim / 2) :]) - if rotary_dim < head_dim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - return out + else: + return out else: + x_ro = x[..., :rotary_dim] out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin) if in_place: x[..., :rotary_dim].copy_(out_ro) @@ -74,27 +106,17 @@ def backward(ctx, grad_out): cos, sin = ctx.saved_tensors rotary_dim = cos.shape[-1] head_dim = grad_out.shape[-1] - grad_out_ro = grad_out[..., :rotary_dim] if ctx.interleaved: - grad_out_in = torch.cat( - [grad_out_ro[..., ::2], grad_out_ro[..., 1::2]], dim=-1 + grad_input = apply_rotary_emb_torch( + grad_out, cos, torch.neg(sin), ctx.interleaved ) - grad_input_ro = torch_npu.npu_rotary_mul(grad_out_in, cos, torch.neg(sin)) if ctx.in_place: - grad_out_ro[..., ::2].copy_(grad_input_ro[..., : int(rotary_dim / 2)]) - grad_out_ro[..., 1::2].copy_(grad_input_ro[..., int(rotary_dim / 2) :]) + grad_out.copy_(grad_input) return grad_out, None, None, None, None - grad_input = torch.empty_like(grad_out) - grad_input[..., :rotary_dim][..., ::2].copy_( - grad_input_ro[..., : int(rotary_dim / 2)] - ) - grad_input[..., :rotary_dim][..., 1::2].copy_( - grad_input_ro[..., int(rotary_dim / 2) :] - ) - if rotary_dim < head_dim: - grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:]) - return grad_input, None, None, None, None + else: + return grad_input, None, None, None, None else: + grad_out_ro = grad_out[..., :rotary_dim] grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin)) if ctx.in_place: grad_out[..., :rotary_dim].copy_(grad_input_ro) From d69d17b3b45240279f2692dba63d70006e4e4f1b Mon Sep 17 00:00:00 2001 From: POI-WX Date: Tue, 22 Oct 2024 13:01:16 +0000 Subject: [PATCH 3/9] use fused rope op from mindspeed --- .../internevo_ops/_rotary_embedding_npu.py | 62 ++++++------------- 1 file changed, 20 insertions(+), 42 deletions(-) diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py index 28df9aa..75fb913 100644 --- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py +++ b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py @@ -3,43 
+3,11 @@ import torch import torch_npu from einops import rearrange, repeat +from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding __all__ = ["ApplyRotaryEmb"] -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 class ApplyRotaryEmb(torch.autograd.Function): """ @@ -72,8 +40,8 @@ def forward( assert sin.shape == (rotary_seqlen, rotary_dim // 2) if interleaved: - cos = cos[:seqlen] - sin = sin[:seqlen] + cos = repeat(cos[:seqlen].unsqueeze(0).unsqueeze(2), "... d -> ... (d 2)") + sin = repeat(sin[:seqlen].unsqueeze(0).unsqueeze(2), "... d -> ... (d 2)") else: # "s d -> 1 s 1 d" cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) @@ -82,12 +50,17 @@ def forward( ctx.interleaved = interleaved ctx.in_place = in_place if interleaved: - out = apply_rotary_emb_torch(x, cos, sin, interleaved) + x_ro = x[..., :rotary_dim] + out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 1) if in_place: - x.copy_(out) + x[..., :rotary_dim].copy_(out_ro) return x - else: + if rotary_dim < head_dim: + out = torch.empty_like(x) + out[..., :rotary_dim].copy_(out_ro) + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) return out + return out_ro else: x_ro = x[..., :rotary_dim] out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin) @@ -107,14 +80,19 @@ def backward(ctx, grad_out): rotary_dim = cos.shape[-1] head_dim = grad_out.shape[-1] if ctx.interleaved: - grad_input = apply_rotary_emb_torch( - grad_out, cos, torch.neg(sin), ctx.interleaved + grad_out_ro = grad_out[..., :rotary_dim] + grad_input_ro = npu_rotary_position_embedding( + grad_out_ro, cos, torch.neg(sin), 1 ) if ctx.in_place: - grad_out.copy_(grad_input) + grad_out[..., :rotary_dim].copy_(grad_input_ro) return grad_out, None, None, None, None - else: + if rotary_dim < head_dim: + grad_input = torch.empty_like(grad_out) + grad_input[..., :rotary_dim].copy_(grad_input_ro) + grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:]) return grad_input, None, None, None, None + return grad_input_ro, None, None, None, None else: grad_out_ro = grad_out[..., :rotary_dim] grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin)) From 6681c0ad09ab16271522bdbfa0d6184d39961c6c Mon Sep 17 00:00:00 2001 From: POI-WX Date: Wed, 23 Oct 2024 04:42:49 +0000 Subject: [PATCH 4/9] optimize --- deeplink_ext/internevo_ops/_rotary_embedding_npu.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py index 75fb913..3dbe507 100644 --- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py +++ 
b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py @@ -2,7 +2,7 @@ import torch import torch_npu -from einops import rearrange, repeat +from einops import repeat from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding __all__ = ["ApplyRotaryEmb"] @@ -40,12 +40,11 @@ def forward( assert sin.shape == (rotary_seqlen, rotary_dim // 2) if interleaved: - cos = repeat(cos[:seqlen].unsqueeze(0).unsqueeze(2), "... d -> ... (d 2)") - sin = repeat(sin[:seqlen].unsqueeze(0).unsqueeze(2), "... d -> ... (d 2)") + cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (d 2)") + sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (d 2)") else: - # "s d -> 1 s 1 d" - cos = cos[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) - sin = sin[:seqlen].unsqueeze(0).unsqueeze(2).repeat(1, 1, 1, 2) + cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (2 d)") + sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (2 d)") ctx.save_for_backward(cos, sin) ctx.interleaved = interleaved ctx.in_place = in_place From ee0baeb31a1bc37a6092e9a035eeeaa080280668 Mon Sep 17 00:00:00 2001 From: wangxing2 Date: Wed, 8 Jan 2025 20:56:56 +0800 Subject: [PATCH 5/9] using fused rms norm op with 8.0 RC3 software stack on npu --- deeplink_ext/interntrain_ops/rms_norm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/deeplink_ext/interntrain_ops/rms_norm.py b/deeplink_ext/interntrain_ops/rms_norm.py index 301ab9e..e6834cb 100644 --- a/deeplink_ext/interntrain_ops/rms_norm.py +++ b/deeplink_ext/interntrain_ops/rms_norm.py @@ -4,13 +4,9 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - # from ._mixed_rms_norm_npu import MixedFusedRMSNorm - # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. - from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm + from ._mixed_rms_norm_npu import MixedFusedRMSNorm elif platform_type == PlatformType.TORCH_DIPU: - # from ._mixed_rms_norm_dipu import MixedFusedRMSNorm - # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. - from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm + from ._mixed_rms_norm_dipu import MixedFusedRMSNorm else: raise ImportError From 5565fbc3032e9ffd9f115f93fa0f296c73d78dd7 Mon Sep 17 00:00:00 2001 From: wangxing2 Date: Thu, 9 Jan 2025 12:11:55 +0800 Subject: [PATCH 6/9] optimize --- deeplink_ext/internevo_ops/_rotary_embedding_npu.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py index 3dbe507..ba2237f 100644 --- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py +++ b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py @@ -1,7 +1,6 @@ # Copyright (c) 2024, DeepLink. import torch -import torch_npu from einops import repeat from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding @@ -45,9 +44,11 @@ def forward( else: cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (2 d)") sin = repeat(sin[:seqlen], "... d -> 1 ... 
1 (2 d)") + ctx.save_for_backward(cos, sin) ctx.interleaved = interleaved ctx.in_place = in_place + if interleaved: x_ro = x[..., :rotary_dim] out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 1) @@ -62,7 +63,7 @@ def forward( return out_ro else: x_ro = x[..., :rotary_dim] - out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin) + out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 0) if in_place: x[..., :rotary_dim].copy_(out_ro) return x @@ -78,6 +79,7 @@ def backward(ctx, grad_out): cos, sin = ctx.saved_tensors rotary_dim = cos.shape[-1] head_dim = grad_out.shape[-1] + if ctx.interleaved: grad_out_ro = grad_out[..., :rotary_dim] grad_input_ro = npu_rotary_position_embedding( @@ -94,7 +96,9 @@ def backward(ctx, grad_out): return grad_input_ro, None, None, None, None else: grad_out_ro = grad_out[..., :rotary_dim] - grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin)) + grad_input_ro = npu_rotary_position_embedding( + grad_out_ro, cos, torch.neg(sin), 0 + ) if ctx.in_place: grad_out[..., :rotary_dim].copy_(grad_input_ro) return grad_out, None, None, None, None From c39b8c522e77c5fec66f4e851ad5cfe3aa2c7082 Mon Sep 17 00:00:00 2001 From: wangxing2 Date: Thu, 9 Jan 2025 13:47:11 +0800 Subject: [PATCH 7/9] modify atol and rtol --- tests/internevo/test_rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/internevo/test_rotary_embedding.py b/tests/internevo/test_rotary_embedding.py index 722577e..a03bc95 100644 --- a/tests/internevo/test_rotary_embedding.py +++ b/tests/internevo/test_rotary_embedding.py @@ -41,7 +41,7 @@ def test_ApplyRotaryEmb(): in_place, ) assert allclose( - output_ref, output_ext, rtol=1e-2, atol=5e-2 + output_ref, output_ext ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!" 
assert allclose( grad_ref, grad_ext From a346f6c4936212e4edb2eca0ebe6e371a3a465bd Mon Sep 17 00:00:00 2001 From: wangxing2 Date: Tue, 21 Jan 2025 20:09:46 +0800 Subject: [PATCH 8/9] use the global attention mask and support sliding window local attention --- .../internevo_ops/_flash_attention_npu.py | 227 +++++++++--------- 1 file changed, 120 insertions(+), 107 deletions(-) diff --git a/deeplink_ext/internevo_ops/_flash_attention_npu.py b/deeplink_ext/internevo_ops/_flash_attention_npu.py index 37110b9..f540d70 100644 --- a/deeplink_ext/internevo_ops/_flash_attention_npu.py +++ b/deeplink_ext/internevo_ops/_flash_attention_npu.py @@ -12,6 +12,47 @@ "flash_attn_varlen_kvpacked_func", ] +_GLOBAL_ATTN_MASK = None + + +def set_attention_mask(attn_mask): + global _GLOBAL_ATTN_MASK + _GLOBAL_ATTN_MASK = attn_mask + + +def get_attention_mask(seqlen, causal, window_size): + global _GLOBAL_ATTN_MASK + + if _GLOBAL_ATTN_MASK is not None: + return _GLOBAL_ATTN_MASK + + # causal attention + if causal: + if seqlen > 2048: + _GLOBAL_ATTN_MASK = torch.triu( + torch.ones([2048, 2048], dtype=bool, device=torch.npu.current_device()), + diagonal=1, + ) + else: + _GLOBAL_ATTN_MASK = torch.triu( + torch.ones( + [seqlen, seqlen], dtype=bool, device=torch.npu.current_device() + ), + diagonal=1, + ) + + # sliding window attention + if window_size[0] >= 0 or window_size[1] >= 0: + _GLOBAL_ATTN_MASK = torch.tril( + torch.ones([seqlen, seqlen], dtype=bool, device=torch.npu.current_device()), + diagonal=-((seqlen - 1 if window_size[0] < 0 else window_size[0]) + 1), + ) + torch.triu( + torch.ones([seqlen, seqlen], dtype=bool, device=torch.npu.current_device()), + diagonal=(seqlen - 1 if window_size[1] < 0 else window_size[1]) + 1, + ) + + return _GLOBAL_ATTN_MASK + def flash_attn_func( q, @@ -32,22 +73,16 @@ def flash_attn_func( seqlen_k = k.shape[1] head_num = q.shape[-2] - if seqlen_q == seqlen_k and seqlen_q < 2048 and seqlen_k < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 + assert seqlen_q == seqlen_k + set_attention_mask(None) + attention_mask = get_attention_mask(seqlen_q, causal, window_size) + sparse_mode = 0 if attention_mask is None or seqlen_q <= 2048 else 4 - seqlen_q = min(seqlen_q, 2048) - seqlen_k = min(seqlen_k, 2048) - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + pre_tokens = seqlen_q - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = seqlen_q - 1 if window_size[0] < 0 else window_size[0] + next_tokens = seqlen_q - 1 if window_size[1] < 0 else window_size[1] out = torch_npu.npu_fusion_attention( q, @@ -58,8 +93,8 @@ def flash_attn_func( atten_mask=attention_mask, scale=softmax_scale, keep_prob=1 - dropout_p, - pre_tockens=seqlen_q, - next_tockens=0, + pre_tockens=pre_tokens, + next_tockens=next_tokens, sparse_mode=sparse_mode, )[0] @@ -89,22 +124,17 @@ def flash_attn_varlen_func( cu_seqlens_q = cu_seqlens_q[1:].tolist() cu_seqlens_k = cu_seqlens_k[1:].tolist() - seqlen_q = min(max_seqlen_q, 2048) - seqlen_k = min(max_seqlen_k, 2048) - - if max_seqlen_q < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + + assert max_seqlen_q == max_seqlen_k + set_attention_mask(None) + attention_mask = get_attention_mask(max_seqlen_q, causal, window_size) + sparse_mode = 0 if attention_mask is None or max_seqlen_q <= 
2048 else 4 + + pre_tokens = max_seqlen_q - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = max_seqlen_q - 1 if window_size[0] < 0 else window_size[0] + next_tokens = max_seqlen_q - 1 if window_size[1] < 0 else window_size[1] out = torch_npu.npu_fusion_attention( q, @@ -114,8 +144,8 @@ def flash_attn_varlen_func( "TND", atten_mask=attention_mask, scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 + pre_tockens=pre_tokens, + next_tockens=next_tokens, keep_prob=1 - dropout_p, sparse_mode=sparse_mode, actual_seq_qlen=cu_seqlens_q, @@ -143,21 +173,15 @@ def flash_attn_qkvpacked_func( seqlen_qkv = qkv.shape[1] head_num = q.shape[-2] - if seqlen_qkv < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - seqlen_qkv = min(qkv.shape[1], 2048) + set_attention_mask(None) + attention_mask = get_attention_mask(seqlen_qkv, causal, window_size) + sparse_mode = 0 if attention_mask is None or seqlen_qkv <= 2048 else 4 - attention_mask = ( - torch.triu( - torch.ones([seqlen_qkv, seqlen_qkv], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + pre_tokens = seqlen_qkv - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = seqlen_qkv - 1 if window_size[0] < 0 else window_size[0] + next_tokens = seqlen_qkv - 1 if window_size[1] < 0 else window_size[1] out = torch_npu.npu_fusion_attention( q, @@ -168,8 +192,8 @@ def flash_attn_qkvpacked_func( atten_mask=attention_mask, scale=softmax_scale, keep_prob=1 - dropout_p, - pre_tockens=seqlen_qkv, - next_tockens=0, + pre_tockens=pre_tokens, + next_tockens=next_tokens, sparse_mode=sparse_mode, )[0] @@ -192,26 +216,20 @@ def flash_attn_kvpacked_func( k = kv[:, :, 0] v = kv[:, :, 1] - s0 = q.shape[1] - s1 = kv.shape[1] + seqlen_q = q.shape[1] + seqlen_kv = kv.shape[1] head_num = q.shape[-2] - if s0 == s1 and s0 < 2048 and s1 < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - seqlen_q = min(s0, 2048) - seqlen_k = min(s1, 2048) + assert seqlen_q == seqlen_kv + set_attention_mask(None) + attention_mask = get_attention_mask(seqlen_q, causal, window_size) + sparse_mode = 0 if attention_mask is None or seqlen_q <= 2048 else 4 - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + pre_tokens = seqlen_q - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = seqlen_q - 1 if window_size[0] < 0 else window_size[0] + next_tokens = seqlen_q - 1 if window_size[1] < 0 else window_size[1] out = torch_npu.npu_fusion_attention( q, @@ -222,8 +240,8 @@ def flash_attn_kvpacked_func( atten_mask=attention_mask, scale=softmax_scale, keep_prob=1 - dropout_p, - pre_tockens=seqlen_k, - next_tockens=0, + pre_tockens=pre_tokens, + next_tockens=next_tokens, sparse_mode=sparse_mode, )[0] @@ -247,32 +265,31 @@ def flash_attn_varlen_qkvpacked_func( q = qkv[:, 0] k = qkv[:, 1] v = qkv[:, 2] - n = q.shape[1] - if max_seqlen > 2048: - sparse_mode = 2 - else: - sparse_mode = 0 + head_num = q.shape[1] + cu_seqlens_q = cu_seqlens[1:].tolist() cu_seqlens_k = cu_seqlens[1:].tolist() - seqlen = min(max_seqlen, 2048) - attention_mask = ( - torch.triu( - torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + + set_attention_mask(None) + attention_mask = get_attention_mask(max_seqlen, causal, window_size) + sparse_mode = 0 if attention_mask is None or max_seqlen <= 2048 else 4 + + pre_tokens = 
max_seqlen - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = max_seqlen - 1 if window_size[0] < 0 else window_size[0] + next_tokens = max_seqlen - 1 if window_size[1] < 0 else window_size[1] + out = torch_npu.npu_fusion_attention( q, k, v, - n, + head_num, "TND", atten_mask=attention_mask, scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 + pre_tockens=pre_tokens, + next_tockens=next_tokens, keep_prob=1 - dropout_p, sparse_mode=sparse_mode, actual_seq_qlen=cu_seqlens_q, @@ -300,35 +317,31 @@ def flash_attn_varlen_kvpacked_func( softmax_scale = q.shape[-1] ** (-0.5) k = kv[:, 0] v = kv[:, 1] - n = q.shape[1] + head_num = q.shape[1] cu_seqlens_q = cu_seqlens_q[1:].tolist() cu_seqlens_k = cu_seqlens_k[1:].tolist() - seqlen_q = min(max_seqlen_q, 2048) - seqlen_k = min(max_seqlen_k, 2048) - - if max_seqlen_q > 2048: - sparse_mode = 2 - else: - sparse_mode = 0 - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + + assert max_seqlen_q == max_seqlen_k + set_attention_mask(None) + attention_mask = get_attention_mask(max_seqlen_q, causal, window_size) + sparse_mode = 0 if attention_mask is None or max_seqlen_q <= 2048 else 4 + + pre_tokens = max_seqlen_q - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = max_seqlen_q - 1 if window_size[0] < 0 else window_size[0] + next_tokens = max_seqlen_k - 1 if window_size[1] < 0 else window_size[1] + out = torch_npu.npu_fusion_attention( q, k, v, - n, + head_num, "TND", atten_mask=attention_mask, scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 + pre_tockens=pre_tokens, + next_tockens=next_tokens, keep_prob=1 - dropout_p, sparse_mode=sparse_mode, actual_seq_qlen=cu_seqlens_q, From 4b133df5be085e672ac32478f95add6b1144f5c4 Mon Sep 17 00:00:00 2001 From: wangxing2 Date: Tue, 21 Jan 2025 21:08:32 +0800 Subject: [PATCH 9/9] optimize code for npu --- .../internevo_ops/_flash_attention_npu.py | 24 +++++++------------ tests/internevo/test_flash_attention.py | 12 ++++++++++ .../internevo/test_varlen_flash_attention.py | 15 ++++++++++++ 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/deeplink_ext/internevo_ops/_flash_attention_npu.py b/deeplink_ext/internevo_ops/_flash_attention_npu.py index f540d70..c68ffd6 100644 --- a/deeplink_ext/internevo_ops/_flash_attention_npu.py +++ b/deeplink_ext/internevo_ops/_flash_attention_npu.py @@ -12,14 +12,10 @@ "flash_attn_varlen_kvpacked_func", ] +# construct a global attention mask for npu _GLOBAL_ATTN_MASK = None -def set_attention_mask(attn_mask): - global _GLOBAL_ATTN_MASK - _GLOBAL_ATTN_MASK = attn_mask - - def get_attention_mask(seqlen, causal, window_size): global _GLOBAL_ATTN_MASK @@ -73,8 +69,7 @@ def flash_attn_func( seqlen_k = k.shape[1] head_num = q.shape[-2] - assert seqlen_q == seqlen_k - set_attention_mask(None) + assert seqlen_q == seqlen_k, "Npu currently only supports seqlen_q = seqlen_k." attention_mask = get_attention_mask(seqlen_q, causal, window_size) sparse_mode = 0 if attention_mask is None or seqlen_q <= 2048 else 4 @@ -125,8 +120,9 @@ def flash_attn_varlen_func( cu_seqlens_q = cu_seqlens_q[1:].tolist() cu_seqlens_k = cu_seqlens_k[1:].tolist() - assert max_seqlen_q == max_seqlen_k - set_attention_mask(None) + assert ( + max_seqlen_q == max_seqlen_k + ), "Npu currently only supports max_seqlen_q = max_seqlen_k." 
attention_mask = get_attention_mask(max_seqlen_q, causal, window_size) sparse_mode = 0 if attention_mask is None or max_seqlen_q <= 2048 else 4 @@ -173,7 +169,6 @@ def flash_attn_qkvpacked_func( seqlen_qkv = qkv.shape[1] head_num = q.shape[-2] - set_attention_mask(None) attention_mask = get_attention_mask(seqlen_qkv, causal, window_size) sparse_mode = 0 if attention_mask is None or seqlen_qkv <= 2048 else 4 @@ -220,8 +215,7 @@ def flash_attn_kvpacked_func( seqlen_kv = kv.shape[1] head_num = q.shape[-2] - assert seqlen_q == seqlen_kv - set_attention_mask(None) + assert seqlen_q == seqlen_kv, "Npu currently only supports seqlen_q = seqlen_kv." attention_mask = get_attention_mask(seqlen_q, causal, window_size) sparse_mode = 0 if attention_mask is None or seqlen_q <= 2048 else 4 @@ -270,7 +264,6 @@ def flash_attn_varlen_qkvpacked_func( cu_seqlens_q = cu_seqlens[1:].tolist() cu_seqlens_k = cu_seqlens[1:].tolist() - set_attention_mask(None) attention_mask = get_attention_mask(max_seqlen, causal, window_size) sparse_mode = 0 if attention_mask is None or max_seqlen <= 2048 else 4 @@ -321,8 +314,9 @@ def flash_attn_varlen_kvpacked_func( cu_seqlens_q = cu_seqlens_q[1:].tolist() cu_seqlens_k = cu_seqlens_k[1:].tolist() - assert max_seqlen_q == max_seqlen_k - set_attention_mask(None) + assert ( + max_seqlen_q == max_seqlen_k + ), "Npu currently only supports max_seqlen_q = max_seqlen_k." attention_mask = get_attention_mask(max_seqlen_q, causal, window_size) sparse_mode = 0 if attention_mask is None or max_seqlen_q <= 2048 else 4 diff --git a/tests/internevo/test_flash_attention.py b/tests/internevo/test_flash_attention.py index 5126551..b4c4771 100644 --- a/tests/internevo/test_flash_attention.py +++ b/tests/internevo/test_flash_attention.py @@ -14,6 +14,15 @@ flash_attn_func, ) +def clear_global_attn_mask_for_npu(): + # clear the global attention mask set by the latest test case + from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + platform_type = deeplink_ext_get_platform_type() + if platform_type == PlatformType.TORCH_NPU: + import deeplink_ext.internevo_ops._flash_attention_npu + deeplink_ext.internevo_ops._flash_attention_npu._GLOBAL_ATTN_MASK = None + else: + pass def test_flash_attn_qkvpacked_func_mha(): batch, seqlen, num_heads, headdim = [8, 32, 32, 64] @@ -46,6 +55,7 @@ def test_flash_attn_qkvpacked_func_mha(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) + clear_global_attn_mask_for_npu() def test_flash_attn_kvpacked_func_gqa(): @@ -83,6 +93,7 @@ def test_flash_attn_kvpacked_func_gqa(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) + clear_global_attn_mask_for_npu() def test_flash_attn_func_gqa(): @@ -128,3 +139,4 @@ def test_flash_attn_func_gqa(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) + clear_global_attn_mask_for_npu() diff --git a/tests/internevo/test_varlen_flash_attention.py b/tests/internevo/test_varlen_flash_attention.py index 97b8d64..d127241 100644 --- a/tests/internevo/test_varlen_flash_attention.py +++ b/tests/internevo/test_varlen_flash_attention.py @@ -14,6 +14,15 @@ flash_attn_varlen_func, ) +def clear_global_attn_mask_for_npu(): + # clear the global attention mask set by the latest test case + from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + 
platform_type = deeplink_ext_get_platform_type() + if platform_type == PlatformType.TORCH_NPU: + import deeplink_ext.internevo_ops._flash_attention_npu + deeplink_ext.internevo_ops._flash_attention_npu._GLOBAL_ATTN_MASK = None + else: + pass # fmt: off # latest sequence length is 20206-16110=4096 @@ -65,6 +74,7 @@ def test_flash_attn_varlen_qkvpacked_func_mha(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_qkvpacked_func_mha_long_max_seqlen(): @@ -109,6 +119,7 @@ def test_flash_attn_varlen_qkvpacked_func_mha_long_max_seqlen(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_kvpacked_func_gqa(): @@ -165,6 +176,7 @@ def test_flash_attn_varlen_kvpacked_func_gqa(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_kvpacked_func_gqa_long_max_seqlen(): @@ -223,6 +235,7 @@ def test_flash_attn_varlen_kvpacked_func_gqa_long_max_seqlen(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_func_gqa(): @@ -287,6 +300,7 @@ def test_flash_attn_varlen_func_gqa(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_func_gqa_long_max_seqlen(): @@ -353,3 +367,4 @@ def test_flash_attn_varlen_func_gqa_long_max_seqlen(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) + clear_global_attn_mask_for_npu()
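
For readers who want to sanity-check the mask convention introduced in PATCH 8/9, below is a minimal CPU-only sketch that mirrors the logic of get_attention_mask: True marks positions that are masked out, the causal branch masks the strict upper triangle, and a non-negative window_size entry switches to a sliding-window pattern that overrides the causal mask (as in the patch). The helper name build_npu_style_mask is illustrative only, and the sketch deliberately omits the 2048-row cap and the sparse_mode=4 handling that the NPU fused kernel relies on.

import torch


def build_npu_style_mask(seqlen, causal, window_size=(-1, -1)):
    """Return a bool mask where True marks positions that are masked out."""
    mask = None
    if causal:
        # Causal: mask the strict upper triangle (future positions).
        # The patch caps the materialized mask at 2048x2048 and uses
        # sparse_mode=4 for longer sequences; this sketch omits that cap.
        mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=1)
    if window_size[0] >= 0 or window_size[1] >= 0:
        left = seqlen - 1 if window_size[0] < 0 else window_size[0]
        right = seqlen - 1 if window_size[1] < 0 else window_size[1]
        # Sliding window: mask anything more than `left` tokens behind or
        # `right` tokens ahead of the current position (this overrides the
        # causal mask, matching the patch's assignment).
        mask = torch.tril(
            torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=-(left + 1)
        ) | torch.triu(
            torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=right + 1
        )
    return mask


if __name__ == "__main__":
    # Causal mask, seqlen 4: row i may attend to columns 0..i.
    print(build_npu_style_mask(4, causal=True))
    # Sliding window (left=1, right=0): row i may attend to columns i-1..i.
    print(build_npu_style_mask(4, causal=True, window_size=(1, 0)))

Under these assumptions, build_npu_style_mask(4, causal=True, window_size=(1, 0)) lets each query attend only to itself and the previous token, which is consistent with the pre_tokens/next_tokens values the patch passes to npu_fusion_attention for that window.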