Skip to content

Commit 0817add

Browse files
committed
fix: Correctly size the shared memory buffer and assert expected size relationships
Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <[email protected]>
1 parent e16e24b commit 0817add

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3017,7 +3017,15 @@ static bool ggml_metal_encode_node(
30173017

30183018
if (ne30 == 1) {
30193019
// Mamba-2
3020-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size
3020+
3021+
// One shared memory bucket for each simd group in the threadgroup
3022+
const size_t shmem_size = d_state / 32;
3023+
GGML_ASSERT(shmem_size * 32 == d_state);
3024+
3025+
// One thread per element in d_state
3026+
GGML_ASSERT(d_state <= pipeline.maxTotalThreadsPerThreadgroup);
3027+
3028+
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
30213029
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
30223030
} else {
30233031
GGML_ASSERT(d_inner == 1);

0 commit comments

Comments
 (0)