91 changes: 54 additions & 37 deletions python/sgl_kernel_npu/sgl_kernel_npu/fla/layernorm_gated.py
@@ -5,10 +5,49 @@


import torch
import torch.nn.functional as F
import torch_npu
import triton
import triton.language as tl
from sgl_kernel_npu.utils.triton_utils import get_device_properties
from einops import rearrange

def rms_norm(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    upcast=True,
):
    dtype = x.dtype
    N = x.shape[-1]
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    mean = None
    if upcast:
        x = x.float()
        z = z.float() if z is not None else z
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    if group_size is None:
        weight = weight.to(x.dtype)
        out, inv_rms = torch_npu.npu_rms_norm(x, weight, eps)
        if bias is not None:
            out = out + bias
        rstd_flat = inv_rms.reshape(-1)
    else:
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
        rstd_flat = rstd.squeeze(-1).transpose(0, 1).contiguous().view(-1)
        if bias is not None:
            out = out + bias
    if z is not None and norm_before_gate:
        out *= F.silu(z)
    return out.to(dtype), mean, rstd_flat

# TODO:
# - Convert int32 comparison to fp32
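For orientation, here is a minimal sketch (not part of the diff) of how the new rms_norm fallback can be exercised through its grouped branch, which only needs torch and einops and therefore runs without an NPU; the shapes and the import path are assumptions based on the file location shown above.

import torch
from sgl_kernel_npu.fla.layernorm_gated import rms_norm  # assumed module path

# Hypothetical shapes: M tokens, 4 groups of 64 features each.
M, groups, d = 8, 4, 64
x = torch.randn(M, groups * d, dtype=torch.float16)
z = torch.randn_like(x)          # gate input, applied via SiLU
weight = torch.ones(groups * d)

out, mean, rstd = rms_norm(
    x,
    weight,
    bias=None,
    z=z,
    eps=1e-6,
    group_size=d,                # grouped branch, avoids npu_rms_norm
    norm_before_gate=True,
    upcast=True,
)
# Output keeps the input dtype, mean is None for RMSNorm,
# and rstd is flattened group-major to (groups * M,).
assert out.dtype == x.dtype and mean is None and rstd.shape == (groups * M,)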
@@ -106,43 +145,21 @@ def layer_norm_fwd_npu(
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = (
        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    _, num_vectorcore = get_device_properties()
    grid = (triton.cdiv(num_vectorcore, ngroups), ngroups)
    _layer_norm_fwd_1pass_kernel_npu_smid[grid](
        x,
        out,
        weight,
        bias,
        z,
        mean,
        rstd,
        x.stride(0),
        out.stride(0),
        z.stride(0) if z is not None else 0,
        M,
        group_size,
        eps,
        BLOCK_N=BLOCK_N,
        NORM_BEFORE_GATE=norm_before_gate,
        IS_RMS_NORM=is_rms_norm,
        multibuffer=True,
    if not is_rms_norm:
        raise NotImplementedError("LayerNorm not implemented yet")
    out_native, mean, rstd = rms_norm(
        x=x,
        weight=weight,
        bias=bias,
        z=z,
        eps=eps,
        group_size=None if group_size == N else group_size,
        norm_before_gate=norm_before_gate,
        upcast=True,
    )
    if out is not None:
        out.copy_(out_native)
    else:
        out = out_native
    return out, mean, rstd
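The old Triton path allocated rstd as a flat (ngroups * M,) buffer, and the fallback's rstd_flat reproduces that group-major layout through the transpose(0, 1). A small CPU-only sketch of that layout (shapes are illustrative, not from the diff):

import torch
from einops import rearrange

M, g, d, eps = 4, 2, 8, 1e-6
x = torch.randn(M, g * d)

x_group = rearrange(x, "m (g d) -> m g d", d=d)
rstd = 1 / torch.sqrt(x_group.square().mean(dim=-1, keepdim=True) + eps)  # (M, g, 1)

# Same flattening as rstd_flat in rms_norm: the group index varies slowest,
# so entries [0:M] belong to group 0, [M:2*M] to group 1, and so on.
rstd_flat = rstd.squeeze(-1).transpose(0, 1).contiguous().view(-1)        # (g * M,)
assert torch.allclose(rstd_flat[:M], rstd[:, 0, 0])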
48 changes: 30 additions & 18 deletions python/sgl_kernel_npu/sgl_kernel_npu/mamba/causal_conv1d.py
@@ -43,8 +43,12 @@ def causal_conv1d_fn_native(
    if initial_states is None:
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        if x.ndim == 2:
            x = x.unsqueeze(0)
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
        if out.ndim == 3:
            out = out.squeeze(0)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
@@ -105,26 +109,34 @@ def causal_conv1d_fn_npu(
    x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    out_ref = []
    out_ref_b = []
    out_ref_b.append(
        causal_conv1d_fn_native(
            x,
            weight,
            bias,
            activation=activation,
            return_final_states=True,
            final_states_out=conv_states[cache_indices[0]].unsqueeze(0),
            initial_states=(
                conv_states[cache_indices[0]].unsqueeze(0)
                if has_initial_state[0]
                else None
            ),
    assert query_start_loc[-1] <= x.shape[-1], f"{query_start_loc=}, {x.shape=}"
    for i in range(query_start_loc.numel() - 1):
        out_ref_b.append(
            causal_conv1d_fn_native(
                x[..., query_start_loc[i] : query_start_loc[i + 1]],
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
                initial_states=(
                    conv_states[cache_indices[i]].unsqueeze(0)
                    if has_initial_state[0]
                    else None
                ),
            )
        )
    out_ref_tensor = torch.cat([t[0] for t in out_ref_b], dim=-1)
    if x.shape[-1] > query_start_loc[-1]:
        pad_seqlen = x.shape[-1] - query_start_loc[-1]
        out_ref_tensor = torch.cat(
            [
                out_ref_tensor,
                out_ref_tensor.new_zeros([*out_ref_tensor.shape[:-1], pad_seqlen]),
            ],
            dim=-1,
        )
    )

    out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
    out_ref_tensor = torch.cat(out_ref, dim=0)
    return out_ref_tensor
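For context, the per-sequence reference loop above assumes the usual varlen packing: all sequences are concatenated along the time axis and query_start_loc holds cumulative offsets into that axis. A short sketch of that layout and of the slicing the loop performs (lengths and dim are made up):

import torch
import torch.nn.functional as F

# Three sequences of lengths 5, 3 and 7, packed along the time axis.
seqlens = torch.tensor([5, 3, 7])
query_start_loc = F.pad(seqlens.cumsum(0), (1, 0))   # tensor([0, 5, 8, 15])

dim = 4
x = torch.randn(dim, int(query_start_loc[-1]))       # (dim, total_seqlen)

# The loop slices one sequence at a time, exactly as in the diff above.
chunks = [
    x[..., query_start_loc[i] : query_start_loc[i + 1]]
    for i in range(query_start_loc.numel() - 1)
]
assert [c.shape[-1] for c in chunks] == seqlens.tolist()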

