|
| 1 | +""" |
| 2 | +Mamba2 Chunk Scan Kernel |
| 3 | +======================== |
| 4 | +
|
| 5 | +This code implements a chunked scan kernel as used for Mamba2 |
| 6 | +""" |
| 7 | + |
| 8 | +# %% |
| 9 | +# Imports |
| 10 | +# ------- |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import functools |
| 14 | + |
| 15 | +import torch |
| 16 | + |
| 17 | +import helion |
| 18 | +from helion._testing import DEVICE |
| 19 | +from helion._testing import run_example |
| 20 | +import helion.language as hl |
| 21 | + |
| 22 | + |
| 23 | +# %% |
| 24 | +# Helion Kernel Implementation |
| 25 | +# ---------------------------- |
| 26 | +@helion.kernel() |
| 27 | +def helion_mamba2_chunk_scan_kernel( |
| 28 | + cb: torch.Tensor, |
| 29 | + x: torch.Tensor, |
| 30 | + dt: torch.Tensor, |
| 31 | + dA_cumsum: torch.Tensor, |
| 32 | + C: torch.Tensor, |
| 33 | + prev_states: torch.Tensor, |
| 34 | + D: torch.Tensor, |
| 35 | +) -> torch.Tensor: |
| 36 | + """ |
| 37 | + Argument: |
| 38 | + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) |
| 39 | + x: (batch, seqlen, nheads, headdim) |
| 40 | + dt: (batch, nheads, nchunks, chunk_size) |
| 41 | + dA_cumsum: (batch, nheads, nchunks, chunk_size) |
| 42 | + C: (batch, seqlen, ngroups, dstate) |
| 43 | + prev_states: (batch, nchunks, nheads, headdim, dstate) |
| 44 | + D: (nheads,) |
| 45 | + Return: |
| 46 | + out: (batch, seqlen, nheads, headdim) |
| 47 | + """ |
| 48 | + |
| 49 | + batch, nchunks, ngroups, chunk_size, _ = cb.shape |
| 50 | + _, seqlen, nheads, headdim = x.shape |
| 51 | + _, _, _, dstate = C.shape |
| 52 | + assert nchunks == (seqlen + chunk_size - 1) // chunk_size |
| 53 | + |
| 54 | + block_m = hl.register_block_size(chunk_size) |
| 55 | + block_n = hl.register_block_size(headdim) |
| 56 | + block_k = hl.register_block_size(64, 64) |
| 57 | + dstate = hl.specialize(dstate) |
| 58 | + |
| 59 | + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) |
| 60 | + assert x.shape == (batch, seqlen, nheads, headdim) |
| 61 | + assert dt.shape == (batch, nheads, nchunks, chunk_size) |
| 62 | + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) |
| 63 | + assert C.shape == (batch, seqlen, ngroups, dstate) |
| 64 | + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) |
| 65 | + assert D.shape == (nheads,) |
| 66 | + |
| 67 | + dtype = cb.dtype |
| 68 | + accum_dtype = torch.float32 |
| 69 | + assert ( |
| 70 | + x.dtype |
| 71 | + == dt.dtype |
| 72 | + == dA_cumsum.dtype |
| 73 | + == C.dtype |
| 74 | + == prev_states.dtype |
| 75 | + == D.dtype |
| 76 | + == dtype |
| 77 | + ) |
| 78 | + |
| 79 | + out = torch.empty_like(x) |
| 80 | + |
| 81 | + p = 1.44269504 |
| 82 | + |
| 83 | + for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile( |
| 84 | + [nheads, chunk_size, headdim, batch, nchunks], |
| 85 | + block_size=[1, block_m, block_n, 1, 1], |
| 86 | + ): |
| 87 | + acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype) |
| 88 | + dA_cumsum_local_m = dA_cumsum[ |
| 89 | + tile_b.begin, tile_h.begin, tile_c.begin, tile_m |
| 90 | + ].to(torch.float32) |
| 91 | + scale_m_local = torch.exp2(dA_cumsum_local_m * p) |
| 92 | + |
| 93 | + C_local = C[ |
| 94 | + tile_b.begin, |
| 95 | + tile_m.index + tile_c.begin * chunk_size, |
| 96 | + tile_h.begin // (nheads // ngroups), |
| 97 | + :, |
| 98 | + ] |
| 99 | + prev_states_local = prev_states[ |
| 100 | + tile_b.begin, tile_c.begin, tile_h.begin, tile_n, : |
| 101 | + ] |
| 102 | + acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o) |
| 103 | + acc_o *= scale_m_local[:, None] |
| 104 | + |
| 105 | + for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k): |
| 106 | + cb_local = cb[ |
| 107 | + tile_b.begin, |
| 108 | + tile_c.begin, |
| 109 | + tile_h.begin // (nheads // ngroups), |
| 110 | + tile_m, |
| 111 | + tile_k, |
| 112 | + ] |
| 113 | + dA_cumsum_local_k = dA_cumsum[ |
| 114 | + tile_b.begin, tile_h.begin, tile_c.begin, tile_k |
| 115 | + ].to(torch.float32) |
| 116 | + cb_local *= torch.exp2( |
| 117 | + dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p |
| 118 | + ) |
| 119 | + dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to( |
| 120 | + torch.float32 |
| 121 | + ) |
| 122 | + cb_local = (cb_local * dt_local[None, :]).to(dtype) |
| 123 | + pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :] |
| 124 | + cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local)) |
| 125 | + x_local = x[ |
| 126 | + tile_b.begin, |
| 127 | + tile_c.begin * chunk_size + tile_k.index, |
| 128 | + tile_h.begin, |
| 129 | + tile_n, |
| 130 | + ] |
| 131 | + acc_o = hl.dot(cb_local, x_local, acc=acc_o) |
| 132 | + |
| 133 | + D_local = D[tile_h.begin].to(torch.float32) |
| 134 | + x_residual = x[ |
| 135 | + tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n |
| 136 | + ].to(torch.float32) |
| 137 | + acc_o += x_residual * D_local |
| 138 | + out[ |
| 139 | + tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n |
| 140 | + ] = acc_o.to(dtype=dtype) |
| 141 | + |
| 142 | + return out |
| 143 | + |
| 144 | + |
| 145 | +# %% |
| 146 | +# Reference Function |
| 147 | +# ------------- |
| 148 | +def ref_chunk_scan( |
| 149 | + cb: torch.Tensor, |
| 150 | + x: torch.Tensor, |
| 151 | + dt: torch.Tensor, |
| 152 | + dA_cumsum: torch.Tensor, |
| 153 | + C: torch.Tensor, |
| 154 | + prev_states: torch.Tensor, |
| 155 | + D: torch.Tensor, |
| 156 | +) -> torch.Tensor: |
| 157 | + """ |
| 158 | + Argument: |
| 159 | + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) |
| 160 | + x: (batch, seqlen, nheads, dhead) |
| 161 | + dt: (batch, nheads, nchunks, chunk_size) |
| 162 | + dA_cumsum: (batch, nheads, nchunks, chunk_size) |
| 163 | + C: (batch, seqlen, ngroups, dstate) |
| 164 | + prev_states: (batch, nchunks, nheads, dhead, dstate) |
| 165 | + D: (nheads,) |
| 166 | + Return: |
| 167 | + out: (batch, seqlen, nheads, dhead) |
| 168 | + """ |
| 169 | + _, _, ngroups, _, _ = cb.shape |
| 170 | + batch, seqlen, nheads, dhead = x.shape |
| 171 | + # _, _, ngroups, dstate = B.shape |
| 172 | + # assert B.shape == (batch, seqlen, ngroups, dstate) |
| 173 | + _, _, nchunks, chunk_size = dt.shape |
| 174 | + dstate = C.shape[-1] |
| 175 | + assert seqlen == nchunks * chunk_size |
| 176 | + # assert C.shape == B.shape |
| 177 | + # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) |
| 178 | + C = torch.repeat_interleave(C, nheads // ngroups, dim=2) |
| 179 | + cb = torch.repeat_interleave(cb, nheads // ngroups, dim=2) |
| 180 | + # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), |
| 181 | + # rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) |
| 182 | + # (batch, nheads, nchunks, chunksize, chunksize) |
| 183 | + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] |
| 184 | + decay = torch.exp(dt_segment_sum) |
| 185 | + scores_decay = cb * decay.permute(0, 2, 1, 3, 4) |
| 186 | + causal_mask = torch.tril( |
| 187 | + torch.ones(chunk_size, chunk_size, device=x.device, dtype=torch.bool), |
| 188 | + diagonal=0, |
| 189 | + ) |
| 190 | + scores_decay = scores_decay.masked_fill(~causal_mask, 0) |
| 191 | + out = torch.einsum( |
| 192 | + "bchls,bhcs,bcshp->bclhp", |
| 193 | + scores_decay.to(x.dtype), |
| 194 | + dt.to(x.dtype), |
| 195 | + x.reshape(batch, nchunks, chunk_size, nheads, dhead), |
| 196 | + ) |
| 197 | + # state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) |
| 198 | + state_decay_out = torch.exp(dA_cumsum.permute(0, 2, 3, 1).unsqueeze(-1)) |
| 199 | + out_prev = ( |
| 200 | + torch.einsum( |
| 201 | + "bclhn,bchpn->bclhp", |
| 202 | + C.reshape(batch, nchunks, chunk_size, nheads, dstate), |
| 203 | + prev_states.to(C.dtype), |
| 204 | + ) |
| 205 | + * state_decay_out |
| 206 | + ) |
| 207 | + out = out + out_prev |
| 208 | + out = out.reshape(batch, seqlen, nheads, dhead) |
| 209 | + if D is not None: |
| 210 | + if D.dim() == 1: |
| 211 | + D = D.unsqueeze(-1) |
| 212 | + out = out + x * D |
| 213 | + return out |
| 214 | + |
| 215 | + |
| 216 | +# %% |
| 217 | +# Testing Function |
| 218 | +# ------------- |
| 219 | +def test( |
| 220 | + init: str, |
| 221 | + batch: int, |
| 222 | + nheads: int, |
| 223 | + ngroups: int, |
| 224 | + seqlen: int, |
| 225 | + chunk_size: int, |
| 226 | + headdim: int, |
| 227 | + dstate: int, |
| 228 | + dtype: torch.dtype = torch.float16, |
| 229 | +) -> None: |
| 230 | + INIT = { |
| 231 | + "r": functools.partial(torch.randn, dtype=dtype, device=DEVICE), |
| 232 | + "u": functools.partial(torch.rand, dtype=dtype, device=DEVICE), |
| 233 | + "z": functools.partial(torch.zeros, dtype=dtype, device=DEVICE), |
| 234 | + "o": functools.partial(torch.ones, dtype=dtype, device=DEVICE), |
| 235 | + } |
| 236 | + nchunks = (seqlen + chunk_size - 1) // chunk_size |
| 237 | + idx = 0 |
| 238 | + |
| 239 | + def fn(*args: int) -> torch.Tensor: |
| 240 | + nonlocal idx |
| 241 | + ret = INIT[init[idx]](*args) |
| 242 | + idx += 1 |
| 243 | + return ret |
| 244 | + |
| 245 | + cb = fn(batch, nchunks, ngroups, chunk_size, chunk_size) |
| 246 | + x = fn(batch, seqlen, nheads, headdim) |
| 247 | + dt = fn(batch, nheads, nchunks, chunk_size) |
| 248 | + dA_cumsum = fn(batch, nheads, nchunks, chunk_size) # init range is too large |
| 249 | + C = fn(batch, seqlen, ngroups, dstate) |
| 250 | + prev_states = fn(batch, nchunks, nheads, headdim, dstate) |
| 251 | + D = fn(nheads) |
| 252 | + args = (cb, x, dt, dA_cumsum, C, prev_states, D) |
| 253 | + run_example(helion_mamba2_chunk_scan_kernel, ref_chunk_scan, args) |
| 254 | + |
| 255 | + |
| 256 | +# %% |
| 257 | +# Main Function |
| 258 | +# ----------- |
| 259 | +def main() -> None: |
| 260 | + """ |
| 261 | + Main entry point that runs the attention kernel test with specific parameters. |
| 262 | + Tests with batch size 2, 32 heads, 1024 sequence length, and 64-dimensional heads using float16. |
| 263 | + """ |
| 264 | + test("zzzzzzz", 8, 80, 1, 4096, 256, 64, 128) |
| 265 | + test("zrzzzzr", 8, 80, 1, 4096, 256, 64, 128) # D * x |
| 266 | + test("zzzzrrz", 8, 80, 1, 4096, 256, 64, 128) # C * prev_state |
| 267 | + test("zzzrrrz", 8, 80, 1, 4096, 256, 64, 128) # C * prev_state * dA |
| 268 | + test("rrrzzzz", 8, 80, 1, 4096, 256, 64, 128) # cb * x * dt |
| 269 | + test("rrruzzz", 8, 80, 1, 4096, 256, 64, 128) # cb * x * dt * dA |
| 270 | + |
| 271 | + |
| 272 | +if __name__ == "__main__": |
| 273 | + main() |
0 commit comments