@@ -427,6 +427,9 @@ def __init__(self,
427
427
self .page_size = self .kv_cache_spec .block_size
428
428
429
429
if self .chunked_prefill_enabled :
430
+ workspace_dtype = self .model_config .dtype
431
+ if cache_config .cache_dtype .startswith ("fp8" ):
432
+ workspace_dtype = current_platform .fp8_dtype ()
430
433
self .chunked_prefill_workspace_size = min (
431
434
# Max sure there is enough for 8 full length request or at least
432
435
# 4 pages of cache per request
@@ -447,7 +450,7 @@ def __init__(self,
447
450
self .chunked_prefill_workspace = torch .empty (
448
451
(self .chunked_prefill_workspace_size ,
449
452
self .model_config .get_head_size ()),
450
- dtype = self . model_config . dtype ,
453
+ dtype = workspace_dtype ,
451
454
device = device ,
452
455
)
453
456
@@ -1022,6 +1025,8 @@ def _compute_prefill_context(
1022
1025
iters = len (prefill_metadata .chunked_context .seq_tot )
1023
1026
workspace = prefill_metadata .chunked_context .workspace
1024
1027
1028
+ fp8_attention = self .kv_cache_dtype .startswith ("fp8" )
1029
+
1025
1030
for i in range (iters ):
1026
1031
toks = prefill_metadata .chunked_context .seq_tot [i ]
1027
1032
@@ -1039,6 +1044,16 @@ def _compute_prefill_context(
1039
1044
k_pe = workspace [:toks ]\
1040
1045
[..., self .kv_lora_rank :].unsqueeze (1 )
1041
1046
1047
+ if fp8_attention :
1048
+ target_dtype = self .kv_b_proj .weight .dtype
1049
+ kv_c_normed_dequant = torch .empty_like (kv_c_normed ,
1050
+ dtype = target_dtype )
1051
+ k_pe_dequant = torch .empty_like (k_pe , dtype = target_dtype )
1052
+ ops .convert_fp8 (kv_c_normed_dequant , kv_c_normed )
1053
+ ops .convert_fp8 (k_pe_dequant , k_pe )
1054
+ kv_c_normed = kv_c_normed_dequant
1055
+ k_pe = k_pe_dequant
1056
+
1042
1057
kv_nope = self .kv_b_proj (kv_c_normed )[0 ].view ( \
1043
1058
- 1 , self .num_heads , self .qk_nope_head_dim + self .v_head_dim )
1044
1059
k_nope , v = kv_nope \
@@ -1155,7 +1170,7 @@ def forward(
1155
1170
# same expert outputs.
1156
1171
return output .fill_ (0 )
1157
1172
1158
- fp8_attention = self .kv_cache_dtype == "fp8"
1173
+ fp8_attention = self .kv_cache_dtype . startswith ( "fp8" )
1159
1174
1160
1175
num_actual_toks = attn_metadata .num_actual_tokens
1161
1176
@@ -1191,6 +1206,9 @@ def forward(
1191
1206
scale = layer ._k_scale ,
1192
1207
)
1193
1208
1209
+ if fp8_attention :
1210
+ kv_cache = kv_cache .view (current_platform .fp8_dtype ())
1211
+
1194
1212
if has_prefill :
1195
1213
output [num_decode_tokens :] = self ._forward_prefill (
1196
1214
prefill_q , prefill_k_c_normed , prefill_k_pe , kv_cache ,
@@ -1208,7 +1226,6 @@ def forward(
1208
1226
decode_ql_nope = decode_ql_nope .transpose (0 , 1 )
1209
1227
1210
1228
if fp8_attention :
1211
- kv_cache = kv_cache .view (torch .float8_e4m3fn )
1212
1229
ql_nope_shape = decode_ql_nope .shape
1213
1230
decode_ql_nope , _ = ops .scaled_fp8_quant (
1214
1231
decode_ql_nope .reshape ([
0 commit comments