Skip to content

Commit dd7977d

Browse files
Dequant in chunked prefill
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent 860f3e0 commit dd7977d

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,9 @@ def __init__(self,
427427
self.page_size = self.kv_cache_spec.block_size
428428

429429
if self.chunked_prefill_enabled:
430+
workspace_dtype = self.model_config.dtype
431+
if cache_config.cache_dtype.startswith("fp8"):
432+
workspace_dtype = current_platform.fp8_dtype()
430433
self.chunked_prefill_workspace_size = min(
431434
# Make sure there is enough for 8 full-length requests or at least
432435
# 4 pages of cache per request
@@ -447,7 +450,7 @@ def __init__(self,
447450
self.chunked_prefill_workspace = torch.empty(
448451
(self.chunked_prefill_workspace_size,
449452
self.model_config.get_head_size()),
450-
dtype=self.model_config.dtype,
453+
dtype=workspace_dtype,
451454
device=device,
452455
)
453456

@@ -1022,6 +1025,8 @@ def _compute_prefill_context(
10221025
iters = len(prefill_metadata.chunked_context.seq_tot)
10231026
workspace = prefill_metadata.chunked_context.workspace
10241027

1028+
fp8_attention = self.kv_cache_dtype.startswith("fp8")
1029+
10251030
for i in range(iters):
10261031
toks = prefill_metadata.chunked_context.seq_tot[i]
10271032

@@ -1039,6 +1044,16 @@ def _compute_prefill_context(
10391044
k_pe = workspace[:toks]\
10401045
[..., self.kv_lora_rank:].unsqueeze(1)
10411046

1047+
if fp8_attention:
1048+
target_dtype = self.kv_b_proj.weight.dtype
1049+
kv_c_normed_dequant = torch.empty_like(kv_c_normed,
1050+
dtype=target_dtype)
1051+
k_pe_dequant = torch.empty_like(k_pe, dtype=target_dtype)
1052+
ops.convert_fp8(kv_c_normed_dequant, kv_c_normed)
1053+
ops.convert_fp8(k_pe_dequant, k_pe)
1054+
kv_c_normed = kv_c_normed_dequant
1055+
k_pe = k_pe_dequant
1056+
10421057
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
10431058
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
10441059
k_nope, v = kv_nope\
@@ -1155,7 +1170,7 @@ def forward(
11551170
# same expert outputs.
11561171
return output.fill_(0)
11571172

1158-
fp8_attention = self.kv_cache_dtype == "fp8"
1173+
fp8_attention = self.kv_cache_dtype.startswith("fp8")
11591174

11601175
num_actual_toks = attn_metadata.num_actual_tokens
11611176

@@ -1191,6 +1206,9 @@ def forward(
11911206
scale=layer._k_scale,
11921207
)
11931208

1209+
if fp8_attention:
1210+
kv_cache = kv_cache.view(current_platform.fp8_dtype())
1211+
11941212
if has_prefill:
11951213
output[num_decode_tokens:] = self._forward_prefill(
11961214
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
@@ -1208,7 +1226,6 @@ def forward(
12081226
decode_ql_nope = decode_ql_nope.transpose(0, 1)
12091227

12101228
if fp8_attention:
1211-
kv_cache = kv_cache.view(torch.float8_e4m3fn)
12121229
ql_nope_shape = decode_ql_nope.shape
12131230
decode_ql_nope, _ = ops.scaled_fp8_quant(
12141231
decode_ql_nope.reshape([

0 commit comments

Comments
 (0)