Skip to content

Commit 3a71689

Browse files
authored
Mamba2 Chunk Scan & State (#950)
1 parent d182808 commit 3a71689

File tree

3 files changed

+482
-0
lines changed

3 files changed

+482
-0
lines changed

benchmarks/run.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,16 @@ class RunResult:
312312
"input_id": 1,
313313
},
314314
),
315+
"mamba2_chunk_scan": (
316+
"tritonbench.operators.mamba2_chunk_scan.operator",
317+
"examples.mamba2_chunk_scan",
318+
"helion_mamba2_chunk_scan_kernel",
319+
),
320+
"mamba2_chunk_state": (
321+
"tritonbench.operators.mamba2_chunk_state.operator",
322+
"examples.mamba2_chunk_state",
323+
"helion_mamba2_chunk_state_kernel",
324+
),
315325
}
316326

317327

@@ -596,6 +606,20 @@ class RunResult:
596606
"helion_blackwell_attention_tritonbench-speedup": "helion_speedup",
597607
"helion_blackwell_attention_tritonbench-accuracy": "helion_accuracy",
598608
},
609+
"mamba2_chunk_scan": {
610+
"eager": "baseline",
611+
"compile_speedup": "torch_compile_speedup",
612+
"compile_accuracy": "torch_compile_accuracy",
613+
"helion_mamba2_chunk_scan_kernel_speedup": "helion_speedup",
614+
"helion_mamba2_chunk_scan_kernel_accuracy": "helion_accuracy",
615+
},
616+
"mamba2_chunk_state": {
617+
"eager": "baseline",
618+
"compile_speedup": "torch_compile_speedup",
619+
"compile_accuracy": "torch_compile_accuracy",
620+
"helion_mamba2_chunk_state_kernel_speedup": "helion_speedup",
621+
"helion_mamba2_chunk_state_kernel_accuracy": "helion_accuracy",
622+
},
599623
}
600624

601625

examples/mamba2_chunk_scan.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""
2+
Mamba2 Chunk Scan Kernel
3+
========================
4+
5+
This code implements a chunked scan kernel as used for Mamba2
6+
"""
7+
8+
# %%
9+
# Imports
10+
# -------
11+
from __future__ import annotations
12+
13+
import functools
14+
15+
import torch
16+
17+
import helion
18+
from helion._testing import DEVICE
19+
from helion._testing import run_example
20+
import helion.language as hl
21+
22+
23+
# %%
24+
# Helion Kernel Implementation
25+
# ----------------------------
26+
@helion.kernel()
27+
def helion_mamba2_chunk_scan_kernel(
28+
cb: torch.Tensor,
29+
x: torch.Tensor,
30+
dt: torch.Tensor,
31+
dA_cumsum: torch.Tensor,
32+
C: torch.Tensor,
33+
prev_states: torch.Tensor,
34+
D: torch.Tensor,
35+
) -> torch.Tensor:
36+
"""
37+
Argument:
38+
cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
39+
x: (batch, seqlen, nheads, headdim)
40+
dt: (batch, nheads, nchunks, chunk_size)
41+
dA_cumsum: (batch, nheads, nchunks, chunk_size)
42+
C: (batch, seqlen, ngroups, dstate)
43+
prev_states: (batch, nchunks, nheads, headdim, dstate)
44+
D: (nheads,)
45+
Return:
46+
out: (batch, seqlen, nheads, headdim)
47+
"""
48+
49+
batch, nchunks, ngroups, chunk_size, _ = cb.shape
50+
_, seqlen, nheads, headdim = x.shape
51+
_, _, _, dstate = C.shape
52+
assert nchunks == (seqlen + chunk_size - 1) // chunk_size
53+
54+
block_m = hl.register_block_size(chunk_size)
55+
block_n = hl.register_block_size(headdim)
56+
block_k = hl.register_block_size(64, 64)
57+
dstate = hl.specialize(dstate)
58+
59+
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
60+
assert x.shape == (batch, seqlen, nheads, headdim)
61+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
62+
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
63+
assert C.shape == (batch, seqlen, ngroups, dstate)
64+
assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
65+
assert D.shape == (nheads,)
66+
67+
dtype = cb.dtype
68+
accum_dtype = torch.float32
69+
assert (
70+
x.dtype
71+
== dt.dtype
72+
== dA_cumsum.dtype
73+
== C.dtype
74+
== prev_states.dtype
75+
== D.dtype
76+
== dtype
77+
)
78+
79+
out = torch.empty_like(x)
80+
81+
p = 1.44269504
82+
83+
for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
84+
[nheads, chunk_size, headdim, batch, nchunks],
85+
block_size=[1, block_m, block_n, 1, 1],
86+
):
87+
acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
88+
dA_cumsum_local_m = dA_cumsum[
89+
tile_b.begin, tile_h.begin, tile_c.begin, tile_m
90+
].to(torch.float32)
91+
scale_m_local = torch.exp2(dA_cumsum_local_m * p)
92+
93+
C_local = C[
94+
tile_b.begin,
95+
tile_m.index + tile_c.begin * chunk_size,
96+
tile_h.begin // (nheads // ngroups),
97+
:,
98+
]
99+
prev_states_local = prev_states[
100+
tile_b.begin, tile_c.begin, tile_h.begin, tile_n, :
101+
]
102+
acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o)
103+
acc_o *= scale_m_local[:, None]
104+
105+
for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k):
106+
cb_local = cb[
107+
tile_b.begin,
108+
tile_c.begin,
109+
tile_h.begin // (nheads // ngroups),
110+
tile_m,
111+
tile_k,
112+
]
113+
dA_cumsum_local_k = dA_cumsum[
114+
tile_b.begin, tile_h.begin, tile_c.begin, tile_k
115+
].to(torch.float32)
116+
cb_local *= torch.exp2(
117+
dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p
118+
)
119+
dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(
120+
torch.float32
121+
)
122+
cb_local = (cb_local * dt_local[None, :]).to(dtype)
123+
pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
124+
cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local))
125+
x_local = x[
126+
tile_b.begin,
127+
tile_c.begin * chunk_size + tile_k.index,
128+
tile_h.begin,
129+
tile_n,
130+
]
131+
acc_o = hl.dot(cb_local, x_local, acc=acc_o)
132+
133+
D_local = D[tile_h.begin].to(torch.float32)
134+
x_residual = x[
135+
tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n
136+
].to(torch.float32)
137+
acc_o += x_residual * D_local
138+
out[
139+
tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n
140+
] = acc_o.to(dtype=dtype)
141+
142+
return out
143+
144+
145+
# %%
146+
# Reference Function
147+
# -------------
148+
def ref_chunk_scan(
149+
cb: torch.Tensor,
150+
x: torch.Tensor,
151+
dt: torch.Tensor,
152+
dA_cumsum: torch.Tensor,
153+
C: torch.Tensor,
154+
prev_states: torch.Tensor,
155+
D: torch.Tensor,
156+
) -> torch.Tensor:
157+
"""
158+
Argument:
159+
cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
160+
x: (batch, seqlen, nheads, dhead)
161+
dt: (batch, nheads, nchunks, chunk_size)
162+
dA_cumsum: (batch, nheads, nchunks, chunk_size)
163+
C: (batch, seqlen, ngroups, dstate)
164+
prev_states: (batch, nchunks, nheads, dhead, dstate)
165+
D: (nheads,)
166+
Return:
167+
out: (batch, seqlen, nheads, dhead)
168+
"""
169+
_, _, ngroups, _, _ = cb.shape
170+
batch, seqlen, nheads, dhead = x.shape
171+
# _, _, ngroups, dstate = B.shape
172+
# assert B.shape == (batch, seqlen, ngroups, dstate)
173+
_, _, nchunks, chunk_size = dt.shape
174+
dstate = C.shape[-1]
175+
assert seqlen == nchunks * chunk_size
176+
# assert C.shape == B.shape
177+
# B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
178+
C = torch.repeat_interleave(C, nheads // ngroups, dim=2)
179+
cb = torch.repeat_interleave(cb, nheads // ngroups, dim=2)
180+
# CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
181+
# rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
182+
# (batch, nheads, nchunks, chunksize, chunksize)
183+
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
184+
decay = torch.exp(dt_segment_sum)
185+
scores_decay = cb * decay.permute(0, 2, 1, 3, 4)
186+
causal_mask = torch.tril(
187+
torch.ones(chunk_size, chunk_size, device=x.device, dtype=torch.bool),
188+
diagonal=0,
189+
)
190+
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
191+
out = torch.einsum(
192+
"bchls,bhcs,bcshp->bclhp",
193+
scores_decay.to(x.dtype),
194+
dt.to(x.dtype),
195+
x.reshape(batch, nchunks, chunk_size, nheads, dhead),
196+
)
197+
# state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
198+
state_decay_out = torch.exp(dA_cumsum.permute(0, 2, 3, 1).unsqueeze(-1))
199+
out_prev = (
200+
torch.einsum(
201+
"bclhn,bchpn->bclhp",
202+
C.reshape(batch, nchunks, chunk_size, nheads, dstate),
203+
prev_states.to(C.dtype),
204+
)
205+
* state_decay_out
206+
)
207+
out = out + out_prev
208+
out = out.reshape(batch, seqlen, nheads, dhead)
209+
if D is not None:
210+
if D.dim() == 1:
211+
D = D.unsqueeze(-1)
212+
out = out + x * D
213+
return out
214+
215+
216+
# %%
217+
# Testing Function
218+
# -------------
219+
def test(
220+
init: str,
221+
batch: int,
222+
nheads: int,
223+
ngroups: int,
224+
seqlen: int,
225+
chunk_size: int,
226+
headdim: int,
227+
dstate: int,
228+
dtype: torch.dtype = torch.float16,
229+
) -> None:
230+
INIT = {
231+
"r": functools.partial(torch.randn, dtype=dtype, device=DEVICE),
232+
"u": functools.partial(torch.rand, dtype=dtype, device=DEVICE),
233+
"z": functools.partial(torch.zeros, dtype=dtype, device=DEVICE),
234+
"o": functools.partial(torch.ones, dtype=dtype, device=DEVICE),
235+
}
236+
nchunks = (seqlen + chunk_size - 1) // chunk_size
237+
idx = 0
238+
239+
def fn(*args: int) -> torch.Tensor:
240+
nonlocal idx
241+
ret = INIT[init[idx]](*args)
242+
idx += 1
243+
return ret
244+
245+
cb = fn(batch, nchunks, ngroups, chunk_size, chunk_size)
246+
x = fn(batch, seqlen, nheads, headdim)
247+
dt = fn(batch, nheads, nchunks, chunk_size)
248+
dA_cumsum = fn(batch, nheads, nchunks, chunk_size) # init range is too large
249+
C = fn(batch, seqlen, ngroups, dstate)
250+
prev_states = fn(batch, nchunks, nheads, headdim, dstate)
251+
D = fn(nheads)
252+
args = (cb, x, dt, dA_cumsum, C, prev_states, D)
253+
run_example(helion_mamba2_chunk_scan_kernel, ref_chunk_scan, args)
254+
255+
256+
# %%
257+
# Main Function
258+
# -----------
259+
def main() -> None:
260+
"""
261+
Main entry point that runs the attention kernel test with specific parameters.
262+
Tests with batch size 2, 32 heads, 1024 sequence length, and 64-dimensional heads using float16.
263+
"""
264+
test("zzzzzzz", 8, 80, 1, 4096, 256, 64, 128)
265+
test("zrzzzzr", 8, 80, 1, 4096, 256, 64, 128) # D * x
266+
test("zzzzrrz", 8, 80, 1, 4096, 256, 64, 128) # C * prev_state
267+
test("zzzrrrz", 8, 80, 1, 4096, 256, 64, 128) # C * prev_state * dA
268+
test("rrrzzzz", 8, 80, 1, 4096, 256, 64, 128) # cb * x * dt
269+
test("rrruzzz", 8, 80, 1, 4096, 256, 64, 128) # cb * x * dt * dA
270+
271+
272+
if __name__ == "__main__":
273+
main()

0 commit comments

Comments
 (0)