Skip to content

Commit 699954b

Browse files
bottlerfacebook-github-bot
authored andcommitted
pad dequantized paged fp8 kv with zeros (#4780)
Summary: Pull Request resolved: #4780 X-link: facebookresearch/FBGEMM#1803 Pad zeros after the end of used sequences to avoid nans in flash attention 3, in the dequantization of fp8 paged kv-cache. This is analogous to the non-paged case which was tackled in D69522001. Reviewed By: TheEpicDolphin, Aya-ZIbra Differential Revision: D80977902 fbshipit-source-id: d95c232cdf89b0601c071951a7cc77fdfe31e405
1 parent 3870dd4 commit 699954b

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2719,8 +2719,8 @@ __global__ void dequantize_fp8_cache_kernel_paged(
27192719
auto max_t = kv_seqlen[b];
27202720
27212721
// one warp per T/H
2722-
for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
2723-
t_h += blockDim.y * gridDim.y) {
2722+
auto t_h = threadIdx.y + blockIdx.y * blockDim.y;
2723+
for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
27242724
auto h = t_h % N_KVH;
27252725
auto t = t_h / N_KVH;
27262726
@@ -2774,6 +2774,29 @@ __global__ void dequantize_fp8_cache_kernel_paged(
27742774
*reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
27752775
*reinterpret_cast<uint2*>(&kv_dq.vals[2]);
27762776
}
2777+
2778+
// zero out the rest of the page, because FA3 can be affected by
2779+
// NaN values beyond the sequence length.
2780+
max_t = (max_t + page_size - 1) / page_size * page_size;
2781+
for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
2782+
if (4 * threadIdx.x >= D_H) {
2783+
continue;
2784+
}
2785+
auto h = t_h % N_KVH;
2786+
auto t = t_h / N_KVH;
2787+
2788+
int page_logical_idx = t / page_size;
2789+
int page_offset = t % page_size;
2790+
int page_physical_idx =
2791+
block_tables[b * block_tables_b_stride + page_logical_idx];
2792+
int physical_t = page_physical_idx * page_size + page_offset;
2793+
2794+
auto* row_k_dq = &cache_K_dq[0][physical_t][h][0];
2795+
auto* row_v_dq = &cache_V_dq[0][physical_t][h][0];
2796+
2797+
memset(&row_k_dq[4 * threadIdx.x], 0, sizeof(uint2));
2798+
memset(&row_v_dq[4 * threadIdx.x], 0, sizeof(uint2));
2799+
}
27772800
}
27782801
#endif
27792802

0 commit comments

Comments
 (0)