-
Notifications
You must be signed in to change notification settings - Fork 12.4k
metal: SSM_SCAN performance #14743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
metal: SSM_SCAN performance #14743
Changes from all commits
ba74a24
8d5a25d
e16e24b
21db0b5
a5334f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32( | |
} | ||
|
||
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part | ||
// TODO: optimize (e.g. by parallelizing over d_state) | ||
kernel void kernel_ssm_scan_f32_group( | ||
device const void * src0, | ||
device const void * src1, | ||
|
@@ -1762,10 +1761,15 @@ kernel void kernel_ssm_scan_f32_group( | |
device const void * src5, | ||
device const void * src6, | ||
device float * dst, | ||
threadgroup float * shared [[threadgroup(0)]], | ||
constant ggml_metal_kargs_ssm_scan & args, | ||
uint3 tgpig[[threadgroup_position_in_grid]], | ||
uint3 tpitg[[thread_position_in_threadgroup]], | ||
uint3 ntg[[threads_per_threadgroup]]) { | ||
uint3 tgpig[[threadgroup_position_in_grid]], | ||
uint3 tpitg[[thread_position_in_threadgroup]], | ||
ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
ushort tiisg[[thread_index_in_simdgroup]], | ||
ushort sgptg[[simdgroups_per_threadgroup]], | ||
uint3 tgpg[[threadgroups_per_grid]]) { | ||
|
||
const int64_t i1 = tgpig.x; | ||
const int64_t ir = tgpig.y; // current head | ||
const int64_t i3 = tgpig.z; // current seq | ||
|
@@ -1780,33 +1784,68 @@ kernel void kernel_ssm_scan_f32_group( | |
const int64_t ng = args.n_group; | ||
const int64_t n_t = args.n_seq_tokens; | ||
|
||
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); | ||
const int64_t s_off = args.s_off; | ||
|
||
device const int32_t * ids = (device const int32_t *) src6; | ||
|
||
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); | ||
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); | ||
|
||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} | ||
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); | ||
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); | ||
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); | ||
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); | ||
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); | ||
|
||
for (int64_t i2 = 0; i2 < n_t; ++i2) { | ||
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} | ||
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} | ||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} | ||
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} | ||
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} | ||
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} | ||
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} | ||
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} | ||
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} | ||
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} | ||
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} | ||
|
||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; | ||
const float x_dt = x[0] * dt_soft_plus; | ||
const float dA = exp(dt_soft_plus * A[0]); | ||
float sumf = 0.0f; | ||
|
||
for (int64_t i0 = 0; i0 < nc; ++i0) { | ||
const int64_t i = i0 + i1*nc; | ||
const float state = (s0[i] * dA) + (B[i0] * x_dt); | ||
sumf += state * C[i0]; | ||
s[i] = state; | ||
const int64_t i = tpitg.x + i1*nc; | ||
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); | ||
s[i] = state; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if it would be faster to store the intermediate states in a local variable instead of repeatedly in the destination buffer. I'm not very experienced with Metal (the naïve version you're starting from was pretty much my first Metal kernel), but I assume it should be possible? Unless I'm misunderstanding the memory model, each thread only handles a single state (as in a single `s[i]` element — TODO confirm the dropped inline code here), so I think this would only affect prompt processing speed, not really small batches, though. |
||
|
||
// Parallel sum: This relies on the fact that this kernel will be | ||
// dispatched with each threadgroup having (d_state, 1, 1) threads which | ||
// are subdivided into `sgptg` SIMD groups. The goal is to | ||
// compute y = sum({state * C[i] for i in range(d_state)}). | ||
// To parallelize this effectively, we first use simd_sum over each SIMD | ||
// group to compute the sum of each SIMD group, then place the result in | ||
// the SIMD group's indexed bucket in the shared memory. We then sum | ||
// over the individual group sums to compute the final sum. | ||
|
||
// Computed for each thread | ||
float sumf = state * C[tpitg.x]; | ||
|
||
// Sum the threads in the simd group => simd sum | ||
sumf = simd_sum(sumf); | ||
|
||
// Once per simd group, place the group sum into the shared buffer | ||
if (tiisg == 0) { | ||
shared[sgitg] = sumf; | ||
} | ||
|
||
// Wait for all threads in the threadgroup to reach this point. This | ||
// ensures that all elements of the shared buffer are populated with the | ||
// sum of the individual simd groups. | ||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
|
||
// Sum the simd buckets => threadgroup sum | ||
sumf = 0.0f; | ||
for (int64_t i0 = 0; i0 < sgptg; ++i0) { | ||
sumf += shared[i0]; | ||
} | ||
|
||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
|
||
y[0] = sumf; | ||
|
||
// recurse | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is actually necessary. I did this when initially trying to emulate the CUDA kernel better, which passes this as an arg. It does seem like it might be slightly faster to avoid computing this in the kernel, though that could be offset by the latency of an additional `int64_t` being passed to the device.