@@ -178,7 +178,8 @@ class FMHAPrefillCached {
   static bool can_implement(Arguments const &args) {
     bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or
                               (args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
-    return mode_implementable;
+    bool valid_page_size = !PagedKV || (args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0);
+    return mode_implementable && valid_page_size;
   }
 
   static int get_workspace_size(Arguments const &args) { return 0; }
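The new `valid_page_size` predicate only accepts paged layouts whose page size is a whole multiple of the K tile (`QK_BLK_N`), so a page never splits a tile. A minimal standalone sketch of the same check, with the tile size hard-coded as an assumption:

```cpp
#include <iostream>

// Hypothetical stand-ins for the kernel's compile-time constants.
constexpr bool kPagedKV = true;
constexpr int kQkBlkN = 64;  // K/V tile size along the KV sequence dimension

// Mirrors valid_page_size: non-paged runs always pass; paged runs need
// page_size to hold a whole number of K/V tiles.
bool page_size_is_valid(int page_size) {
  return !kPagedKV || (page_size >= kQkBlkN && page_size % kQkBlkN == 0);
}

int main() {
  std::cout << page_size_is_valid(128) << '\n';  // 1: two tiles per page
  std::cout << page_size_is_valid(96) << '\n';   // 0: 96 % 64 != 0
  std::cout << page_size_is_valid(32) << '\n';   // 0: smaller than one tile
}
```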
@@ -314,10 +315,22 @@ class FMHAPrefillCached {
   }
   auto &prefetch_K = (seq_len_kv_cache == 0) ? tiled_prefetch_k : tiled_prefetch_k_cache;
   auto &pKgK1_ = (seq_len_kv_cache == 0) ? pKgK : pKgK_cache;
+
+  int cached_nblock = 0;
+  if constexpr (PagedKV) {
+    if (seq_len_kv_cache != 0) {
+      int curr_batch_pages = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord + 1] - mainloop_params.num_pages_per_seq[batch_coord]
+                                        : ceil_div(seq_len_kv_cache, mainloop_params.page_size);
+      int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages;
+      cached_nblock = mainloop_params.ptr_page_table[
+                          batch_offset        // page table for this batch
+                      ] * tiles_per_page;     // base block idx of physical page
+    }
+  }
   // The head size for both the cached and non-cached versions is the same
   for (int j = 0; j < size<4>(pKgK1_); j++) {
     CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < DispatchPolicy::Stages; i++) {
+    for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; i++) {
       prefetch(prefetch_K, pKgK1_(_, _, _, i, j));
     }
   }
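This hunk resolves the first physical page once, before the main loop, so the initial K prefetches target the right blocks instead of assuming block 0. A minimal host-side sketch of that hoisted lookup, with container types and names assumed:

```cpp
#include <vector>

// Assumed shape of the paged-KV metadata: a flattened page table of physical
// page ids, plus prefix sums over pages per sequence for the var-len path.
struct PageTable {
  std::vector<int> pages;
  std::vector<int> num_pages_per_seq;
};

// Returns the base block index of the batch's first physical page. Each page
// holds tiles_per_page = page_size / qk_blk_n K/V tiles, so a physical page
// id scales to a block index by tiles_per_page.
int first_cached_block(const PageTable &pt, int batch, bool var_len,
                       int seq_len_kv_cache, int page_size, int qk_blk_n) {
  if (seq_len_kv_cache == 0) return 0;
  int tiles_per_page = page_size / qk_blk_n;
  int pages_this_batch = var_len
      ? pt.num_pages_per_seq[batch + 1] - pt.num_pages_per_seq[batch]
      : (seq_len_kv_cache + page_size - 1) / page_size;  // ceil_div
  int batch_offset = var_len ? pt.num_pages_per_seq[batch]
                             : batch * pages_this_batch;
  // Entry 0 of this batch's table slice is its first logical page.
  return pt.pages[batch_offset] * tiles_per_page;
}
```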
@@ -345,18 +358,6 @@ class FMHAPrefillCached {
 
     bool is_KV_cache = nblock < nblock_cache;
 
-    int cached_nblock = nblock;
-    if constexpr (PagedKV) {
-      if (is_KV_cache) {
-        // get physical page idx from page table
-        cached_nblock = params.mainloop.ptr_page_table[
-                            batch_coord * params.mainloop.num_pages_per_seq +  // page table for this batch
-                            nblock * QK_BLK_N / params.mainloop.page_size      // nblock (tile idx) to logical page idx
-                        ] * tiles_per_page +                                   // base block idx of physical page
-                        nblock % tiles_per_page;                               // offset within page
-      }
-    }
-
     // 1) Load KV (performed inside mmaQK)
     auto gK_ = is_KV_cache ? gK_cache(_, _, cached_nblock, _) : gK(_, _, nblock - nblock_cache, _);
     auto gV_ = is_KV_cache ? gV_cache(_, _, cached_nblock) : gV(_, _, nblock - nblock_cache);
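The removed block recomputed the page-table translation at the top of every iteration; after this change, the translation for block `n + 1` is resolved while block `n` is being processed (see the `next_cached_nblock` logic in the next hunk), keeping the table read off the critical path. For reference, a sketch of the mapping the deleted code implemented, with a worked example under assumed sizes:

```cpp
// Logical-to-physical block mapping, names assumed. Worked example with
// page_size = 128, qk_blk_n = 64  =>  tiles_per_page = 2, nblock = 5:
//   logical page   = 5 * 64 / 128 = 2
//   within-page    = 5 % 2        = 1
// If the table maps logical page 2 of this batch to physical page 7, the
// physical block is 7 * 2 + 1 = 15.
int physical_block(const int *table, int batch_offset, int nblock,
                   int page_size, int qk_blk_n) {
  int tiles_per_page = page_size / qk_blk_n;
  int logical_page = nblock * qk_blk_n / page_size;
  return table[batch_offset + logical_page] * tiles_per_page
       + nblock % tiles_per_page;
}
```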
@@ -372,8 +373,32 @@ class FMHAPrefillCached {
     // prefetching it the same way as cutlass K matrix does not make sense
     auto &tiled_prefetch_v_ = is_KV_cache ? tiled_prefetch_v_cache : tiled_prefetch_v;
     auto &pVgV_ = is_KV_cache ? pVgV_cache : pVgV;
-    for (int i=0; i < size<1>(pVgV); i++) {
-      prefetch(tiled_prefetch_v_, pVgV_cache(_, i, _, nblock - (!is_KV_cache) * nblock_cache));
+    int v_prefetch_idx = is_KV_cache ? PagedKV ? cached_nblock : nblock
+                                     : nblock - nblock_cache;
+    for (int i = 0; i < size<1>(pVgV_); i++) {
+      prefetch(tiled_prefetch_v_, pVgV_(_, i, _, v_prefetch_idx));
+    }
+
+    int next_cached_nblock = nblock + 1;
+    bool is_next_KV_cache = next_cached_nblock < nblock_cache;
+    if constexpr (PagedKV) {
+      if (is_next_KV_cache) {
+        int curr_batch_pages = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord + 1] - mainloop_params.num_pages_per_seq[batch_coord]
+                                          : ceil_div(seq_len_kv_cache, mainloop_params.page_size);
+        int next_page_logical_idx = next_cached_nblock * QK_BLK_N / params.mainloop.page_size;
+        int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages;
+        bool valid_page = next_page_logical_idx < curr_batch_pages;
+        // get physical page idx from page table
+        if (valid_page) {
+          next_cached_nblock = params.mainloop.ptr_page_table[
+                                   batch_offset +           // page table for this batch
+                                   next_page_logical_idx    // nblock (tile idx) to logical page idx
+                               ] * tiles_per_page +         // base block idx of physical page
+                               next_cached_nblock % tiles_per_page;  // offset within page
+        } else {
+          next_cached_nblock = curr_batch_pages * tiles_per_page;  // push idx out of bounds to respect the boundary between batches
+        }
+      }
     }
 
     // 4) Fused softmax
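The look-ahead translation above also guards against reading past the batch's slice of the page table: when the next logical page is out of range, the index is parked one page past the batch's last block, so the subsequent prefetch becomes a benign out-of-bounds access rather than a read from another batch's pages. A sketch of that guard in isolation, same assumed names as before:

```cpp
// Look-ahead translation with the batch-boundary guard, names assumed.
int next_physical_block(const int *table, int batch_offset, int next_nblock,
                        int pages_this_batch, int page_size, int qk_blk_n) {
  int tiles_per_page = page_size / qk_blk_n;
  int next_logical_page = next_nblock * qk_blk_n / page_size;
  if (next_logical_page < pages_this_batch) {
    return table[batch_offset + next_logical_page] * tiles_per_page
         + next_nblock % tiles_per_page;
  }
  // Out of range: clamp one page past the end so the prefetch is harmless.
  return pages_this_batch * tiles_per_page;
}
```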
@@ -382,16 +407,26 @@ class FMHAPrefillCached {
 
     // 5) Perform GEMM O = S*V
     collective_mma.template mmaPV<VSlicer>(out_reg, tSr, gV_, out_reg, mainloop_params, is_KV_cache);
-
+
+    // Prefetch the next Q tile
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < size<3>(pQgQ); i++) {
+      prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
+    }
+
+    is_KV_cache = is_next_KV_cache;
+    cached_nblock = next_cached_nblock;
     // Prefetch the next K tile
     // there is no need to guard it with an if statement as prefetch will ignore out-of-bounds reads
 
     bool sel_prefetch_k = (nblock + DispatchPolicy::Stages) < nblock_cache;
     auto &prefetch_k_selector = sel_prefetch_k ? tiled_prefetch_k_cache : tiled_prefetch_k;
     auto &pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK;
+    int k_prefetch_idx = sel_prefetch_k ? PagedKV ? cached_nblock : nblock + DispatchPolicy::Stages
+                                        : nblock + DispatchPolicy::Stages - nblock_cache;
     CUTLASS_PRAGMA_UNROLL
     for (int j = 0; j < size<4>(pKgK_); j++) {
-      prefetch(prefetch_k_selector, pKgK_(_, _, _, (nblock + DispatchPolicy::Stages) - (!sel_prefetch_k) * nblock_cache, j));
+      prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx, j));
     }
     barrier_wait(barrier_scope);
   }
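The rewrite replaces the folded index arithmetic `(nblock + Stages) - (!sel_prefetch_k) * nblock_cache` with a named `k_prefetch_idx`, and for paged cache blocks it reuses the already translated `cached_nblock` instead of a raw tile index. A sketch of the selection logic, inputs assumed:

```cpp
// K-prefetch index selection, names assumed. When the look-ahead block still
// lies in the KV cache and paging is on, the already translated physical
// index is used; otherwise the index is rebased into the newly appended K.
int k_prefetch_index(bool paged, int nblock, int stages, int nblock_cache,
                     int cached_nblock /* physical block, paged path only */) {
  bool from_cache = (nblock + stages) < nblock_cache;
  if (from_cache) return paged ? cached_nblock : nblock + stages;
  return nblock + stages - nblock_cache;  // rebase into the new-K tensor
}
```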
@@ -406,8 +441,8 @@ class FMHAPrefillCached {
     collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_new - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, false);
     // we only need one block ahead; there is enough gap to prefetch it while doing softmax. Because the gap between the two MMAs is big,
     // prefetching it the same way as the cutlass K matrix does not make sense
-    for (int i= 0; i< size<1>(pVgV); i++) {
-      prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_new - 1));
+    for (int i = 0; i < size<1>(pVgV); i++) {
+      prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_new - 1));
     }
     // mask the elements of each tile where j > i
     const int item_id = thread_idx % SubgroupSize;
@@ -420,7 +455,7 @@ class FMHAPrefillCached {
       CUTLASS_PRAGMA_UNROLL
       for (int row = 0; row < Vec; row++, row_idx++) { // 8
         if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
-          tSr(row, m, n) = -INFINITY;
+          tSr(row, m, n) = ElementAccumulator{-INFINITY};
        }
      }
    }
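Wrapping `-INFINITY` in `ElementAccumulator{}` constructs the mask sentinel explicitly in the accumulator's element type instead of relying on an implicit conversion at the assignment. A minimal illustration, assuming a `float` accumulator (in the kernel it may be a narrower type):

```cpp
#include <cmath>
#include <iostream>

using ElementAccumulator = float;  // assumption for this sketch

int main() {
  ElementAccumulator s[4] = {0.1f, 0.7f, 0.3f, 0.9f};
  // Mask the "future" positions the way the kernel's causal mask does,
  // with the sentinel typed as the accumulator element.
  for (int j = 2; j < 4; j++) s[j] = ElementAccumulator{-INFINITY};
  for (float v : s) std::cout << v << ' ';  // 0.1 0.7 -inf -inf
  std::cout << '\n';
}
```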