Commit 8d5a25d

perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state
This is a first attempt at optimizing the Metal kernel. The changes here are:

- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on prefill and a 15% speedup on decode.

Signed-off-by: Gabe Goodhart <[email protected]>
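For reference, the simd-group-plus-shared-memory summation described above follows a standard two-level reduction. The sketch below is a minimal standalone Metal kernel illustrating the pattern, not the ggml kernel itself; the kernel name, buffer layout, and the assumption that each threadgroup reduces one row of ntg values (at most 32 simdgroups, i.e. 1024 threads) are illustrative.

```metal
#include <metal_stdlib>
using namespace metal;

// Two-level sum reduction over one row of values:
//   1. simd_sum collapses each 32-wide simdgroup to a single partial sum
//   2. the partials are staged in threadgroup memory (one float per simdgroup)
//   3. a final simd_sum over the staged partials yields the row total
kernel void reduce_row_sum(
        device const float * vals   [[buffer(0)]],      // ntg values per row (illustrative layout)
        device       float * out    [[buffer(1)]],      // one sum per row
        threadgroup  float * shared [[threadgroup(0)]], // 32 floats, one slot per simdgroup
        uint   row   [[threadgroup_position_in_grid]],
        uint   tid   [[thread_position_in_threadgroup]],
        uint   ntg   [[threads_per_threadgroup]],
        ushort sgitg [[simdgroup_index_in_threadgroup]],
        ushort tiisg [[thread_index_in_simdgroup]]) {
    // each thread owns one element of the row
    float sumf = vals[row*ntg + tid];

    // stage 1: reduce within the simdgroup (32 threads wide on Apple GPUs)
    sumf = simd_sum(sumf);

    // write one partial sum per simdgroup to threadgroup memory
    if (tiisg == 0) {
        shared[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // stage 2: every simdgroup re-reduces the staged partials (all arrive at the
    // same total); guard against slots that no simdgroup wrote when ntg < 1024
    const ushort nsg = (ushort)((ntg + 31) / 32);
    sumf = (tiisg < nsg) ? shared[tiisg] : 0.0f;
    sumf = simd_sum(sumf);

    if (tid == 0) {
        out[row] = sumf;
    }
}
```

In the kernel_ssm_scan_f32_group change further down, each thread computes its own state element in place of the vals load, the same simd_sum / shared[sgitg] / simd_sum sequence produces the dot product with C, and the result is written to y[0].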
1 parent: ba74a24

Showing 2 changed files with 32 additions and 11 deletions.

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 3 additions & 1 deletion
@@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node(
             /*.n_group      =*/ n_group,
             /*.n_seq_tokens =*/ n_seq_tokens,
             /*.n_seqs       =*/ n_seqs,
+            /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
             /*.nb01         =*/ nb01,
             /*.nb02         =*/ nb02,
             /*.nb03         =*/ nb03,
@@ -3016,7 +3017,8 @@ static bool ggml_metal_encode_node(
 
             if (ne30 == 1) {
                 // Mamba-2
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size
+                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
             } else {
                 GGML_ASSERT(d_inner == 1);
                 [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
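A sizing note on the host-side change above (the d_state value is an assumed typical figure, not something the diff states): Apple GPU simdgroups are 32 threads wide, so a threadgroup of width d_state = 128, a common Mamba-2 state size, contains 128 / 32 = 4 simdgroups. The 32*sizeof(float) threadgroup allocation therefore reserves one partial-sum slot per possible simdgroup and covers any d_state up to 32 * 32 = 1024 threads per threadgroup.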

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 29 additions & 10 deletions
@@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32(
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
@@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        uint3  ntg[[threads_per_threadgroup]]) {
+
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
@@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group(
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
         float sumf = 0.0f;
 
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
+        const int64_t i = tpitg.x + i1*nc;
+        const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
+        sumf += state * C[tpitg.x];
+        s[i] = state;
+
+        sumf = simd_sum(sumf);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Use the shared buffer to hold the sum of each simd group
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
         }
 
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Sum the simd buckets
+        sumf = shared[tiisg];
+        sumf = simd_sum(sumf);
+
         y[0] = sumf;
 
         // recurse
