Skip to content

metal: SSM_SCAN performance #14743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ typedef struct {
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
int64_t s_off;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is actually necessary. I did this when initially trying to emulate the CUDA kernel better which passes this as an arg. It does seem like it might be slightly faster to avoid computing this in the kernel, though that could be offset by the latency of an additional int64_t being passed to the device?

uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
Expand Down
12 changes: 11 additions & 1 deletion ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node(
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
Expand Down Expand Up @@ -3016,7 +3017,16 @@ static bool ggml_metal_encode_node(

if (ne30 == 1) {
// Mamba-2
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];

// One shared memory bucket for each simd group in the threadgroup
const int64_t shmem_size = d_state / 32;
GGML_ASSERT(shmem_size * 32 == d_state);

// One thread pre element in d_state
GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);

[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
Expand Down
73 changes: 56 additions & 17 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32(
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
Expand All @@ -1762,10 +1761,15 @@ kernel void kernel_ssm_scan_f32_group(
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {

const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
Expand All @@ -1780,33 +1784,68 @@ kernel void kernel_ssm_scan_f32_group(
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;

const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
const int64_t s_off = args.s_off;

device const int32_t * ids = (device const int32_t *) src6;

device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);

device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
float sumf = 0.0f;

for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * dA) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
const int64_t i = tpitg.x + i1*nc;
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
s[i] = state;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be faster to store the intermediate states in a local variable instead of repeatedly in the destination buffer.

I'm not very experienced with Metal (the naïve version you're starting from was pretty much my first Metal kernel), but I assume it should be possible?

Unless I'm misunderstanding the memory model, each thread only handles a single state (as in s[i] always refers to the same place, but differs between threads).

I think this would only affect prompt processing speed, not really small batches, though.


// Parallel sum: This relies on the fact that this kernel will be
// dispatched with each threadgroup having (d_state, 1, 1) threads which
// are subdivided into SIMD groups of size `sgptg`. The goal is to
// compute y = sum({state * C[i] for i in range(d_state)}).
// To parallelize this effectively, we first use simd_sum over each SIMD
// group to compute the sum of each SIMD group, then place the result in
// the SIMD group's indexed bucket in the shared memory. We then sum
// over the individual group sums to compute the final sum.

// Computed for each thread
float sumf = state * C[tpitg.x];

// Sum the threads in the simd group => simd sum
sumf = simd_sum(sumf);

// Once per simd group, place the group sum into the shared buffer
if (tiisg == 0) {
shared[sgitg] = sumf;
}

// Wait for all threads in the threadgroup to reach this point. This
// ensures that all elements of the shared buffer are populated with the
// sum of the individual simd groups.
threadgroup_barrier(mem_flags::mem_threadgroup);

// Sum the simd buckets => threadgroup sum
sumf = 0.0f;
for (int64_t i0 = 0; i0 < sgptg; ++i0) {
sumf += shared[i0];
}

threadgroup_barrier(mem_flags::mem_threadgroup);

y[0] = sumf;

// recurse
Expand Down
Loading