Commit 9ac2929

Author: K11OntheBoat
Support deepseekv3 cache transfer for PD deploy
1 parent 2099708 commit 9ac2929

File tree

18 files changed (+711, -404 lines)


custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu

Lines changed: 125 additions & 83 deletions
@@ -13,22 +13,24 @@
 // limitations under the License.
 #pragma once
 
-#include "helper.h"
 #include "mla_cache_kernel.cuh"
+#include "helper.h"
+#include "remote_cache_kv_ipc.h"
 
 template <paddle::DataType T>
 std::vector<paddle::Tensor> PrefillMLAWriteCache(
-                                                 const AppendAttnMetaData& meta_data,
-                                                 const paddle::Tensor& kv_nope,
-                                                 const paddle::Tensor& kv_pe,
-                                                 const paddle::Tensor& seq_lens,
-                                                 const paddle::Tensor& seq_lens_decoder,
-                                                 const paddle::Tensor& batch_id_per_token,
-                                                 const paddle::Tensor& cu_seqlens_q,
-                                                 const paddle::Tensor& block_tables,
-                                                 const int max_seq_len,
-                                                 cudaStream_t& stream,
-                                                 paddle::Tensor* kv_cache) {
+    const AppendAttnMetaData& meta_data,
+    const paddle::Tensor& kv_nope,
+    const paddle::Tensor& kv_pe,
+    const paddle::Tensor& seq_lens,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
+    const paddle::Tensor& block_tables,
+    const paddle::optional<paddle::Tensor>& kv_signal_data,
+    const int max_seq_len,
+    cudaStream_t& stream,
+    paddle::Tensor* kv_cache) {
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
@@ -50,8 +52,10 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
 
   prefill_absorb_cache_kernel<DataType_, PackSize>
       <<<grid_size, blocksize, 0, stream>>>(
-          reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
-          reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
+          reinterpret_cast<DataType_*>(
+              const_cast<data_t*>(kv_nope.data<data_t>())),
+          reinterpret_cast<DataType_*>(
+              const_cast<data_t*>(kv_pe.data<data_t>())),
           reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
           block_tables.data<int>(),
           batch_id_per_token.data<int>(),
@@ -65,6 +69,33 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
           pe_size,
           block_size,
           elem_nums);
+
+  const char* fmt_write_cache_completed_signal_str =
+      std::getenv("FLAGS_fmt_write_cache_completed_signal");
+  const char* FLAGS_use_pd_disaggregation_per_chunk =
+      std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
+
+  if (fmt_write_cache_completed_signal_str &&
+      (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
+       std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
+    if (FLAGS_use_pd_disaggregation_per_chunk &&
+        (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
+         std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
+      cudaLaunchHostFunc(
+          stream,
+          &(RemoteCacheKvIpc::
+                save_cache_kv_complete_signal_layerwise_per_query),
+          (void*)nullptr);
+    } else {
+      if (kv_signal_data) {
+        cudaLaunchHostFunc(
+            stream,
+            &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
+            (void*)(const_cast<int64_t*>(
+                kv_signal_data.get().data<int64_t>())));
+      }
+    }
+  }
   return {};
 }
 
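Note on the new signaling path above: both FLAGS_* switches are read with std::getenv and count as enabled only for the values "true" or "1", and the completion callback is enqueued with cudaLaunchHostFunc on the same CUDA stream as the cache-write kernel, so stream ordering guarantees the signal fires only after the write has finished. Below is a standalone sketch of that gate-then-enqueue pattern; it is not from this commit, and env_flag_enabled, fake_cache_write, and notify_cache_written are illustrative names.

// Standalone sketch (not from this commit) of the gate-then-enqueue pattern.
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda_runtime.h>

// Mirrors the getenv/strcmp checks in PrefillMLAWriteCache: a flag counts
// as enabled only when set to "true" or "1".
static bool env_flag_enabled(const char* name) {
  const char* v = std::getenv(name);
  return v && (std::strcmp(v, "true") == 0 || std::strcmp(v, "1") == 0);
}

__global__ void fake_cache_write(int* cache, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) cache[i] = i;  // stand-in for the real cache-write kernel
}

// Host callback: runs on a CPU thread only after all prior work on the
// stream has completed. CUDA API calls are not allowed inside a host func.
static void CUDART_CB notify_cache_written(void* user_data) {
  *static_cast<bool*>(user_data) = true;
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  int* cache = nullptr;
  cudaMalloc(&cache, 1024 * sizeof(int));

  bool signaled = false;
  fake_cache_write<<<4, 256, 0, stream>>>(cache, 1024);
  if (env_flag_enabled("FLAGS_fmt_write_cache_completed_signal")) {
    // Ordered after the kernel on the same stream, like the
    // save_cache_kv_complete_signal_* callbacks in the commit.
    cudaLaunchHostFunc(stream, notify_cache_written, &signaled);
  }
  cudaStreamSynchronize(stream);
  std::printf("signaled = %s\n", signaled ? "true" : "false");

  cudaFree(cache);
  cudaStreamDestroy(stream);
  return 0;
}

Compiled with nvcc, this prints signaled = true only when the environment variable is set to "true" or "1", which is the same gating the commit applies before sending the cache-transfer signal.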
@@ -77,6 +108,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
     const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
+    const paddle::optional<paddle::Tensor>& kv_signal_data,
     const std::string& cache_quant_type_str,
     const int max_seq_len) {
   cudaStream_t stream = kv_pe.stream();
@@ -85,7 +117,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
   const auto& kv_pe_dims = kv_pe.dims();
   const auto& kv_cache_dims = kv_cache.dims();
   meta_data.kv_num_heads = kv_cache_dims[1];
-  const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
+  const auto nope_size =
+      kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
   meta_data.token_nums = kv_nope_dims[0];
   meta_data.head_dims = kv_cache_dims[3];
   meta_data.head_dims_v = nope_size;
@@ -95,49 +128,53 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
   meta_data.batch_size = seq_lens_decoder.dims()[0];
   switch (kv_pe.dtype()) {
     case paddle::DataType::BFLOAT16: {
-      return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
-                                                              kv_nope,
-                                                              kv_pe,
-                                                              seq_lens,
-                                                              seq_lens_decoder,
-                                                              batch_id_per_token,
-                                                              cu_seqlens_q,
-                                                              block_tables,
-                                                              max_seq_len,
-                                                              stream,
-                                                              const_cast<paddle::Tensor*>(&kv_cache));
+      return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
+          meta_data,
+          kv_nope,
+          kv_pe,
+          seq_lens,
+          seq_lens_decoder,
+          batch_id_per_token,
+          cu_seqlens_q,
+          block_tables,
+          kv_signal_data,
+          max_seq_len,
+          stream,
+          const_cast<paddle::Tensor*>(&kv_cache));
     }
     case paddle::DataType::FLOAT16: {
-      return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
-                                                             kv_nope,
-                                                             kv_pe,
-                                                             seq_lens,
-                                                             seq_lens_decoder,
-                                                             batch_id_per_token,
-                                                             cu_seqlens_q,
-                                                             block_tables,
-                                                             max_seq_len,
-                                                             stream,
-                                                             const_cast<paddle::Tensor*>(&kv_cache));
+      return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
+          meta_data,
+          kv_nope,
+          kv_pe,
+          seq_lens,
+          seq_lens_decoder,
+          batch_id_per_token,
+          cu_seqlens_q,
+          block_tables,
+          kv_signal_data,
+          max_seq_len,
+          stream,
+          const_cast<paddle::Tensor*>(&kv_cache));
     }
   }
   return {};
 }
 
 template <paddle::DataType T>
 std::vector<paddle::Tensor> DecodeMLAWriteCache(
-                                                const AppendAttnMetaData& meta_data,
-                                                const paddle::Tensor& kv_nope,
-                                                const paddle::Tensor& kv_pe,
-                                                const paddle::Tensor& seq_lens,
-                                                const paddle::Tensor& seq_lens_encoder,
-                                                const paddle::Tensor& batch_id_per_token,
-                                                const paddle::Tensor& cu_seqlens_q,
-                                                const paddle::Tensor& block_tables,
-                                                const int max_seq_len,
-                                                const bool speculate_decoder,
-                                                cudaStream_t& stream,
-                                                paddle::Tensor* kv_cache) {
+    const AppendAttnMetaData& meta_data,
+    const paddle::Tensor& kv_nope,
+    const paddle::Tensor& kv_pe,
+    const paddle::Tensor& seq_lens,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& batch_id_per_token,
+    const paddle::Tensor& cu_seqlens_q,
+    const paddle::Tensor& block_tables,
+    const int max_seq_len,
+    const bool speculate_decoder,
+    cudaStream_t& stream,
+    paddle::Tensor* kv_cache) {
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
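Aside: the dtype switches in PrefillMLAWriteCacheKernel above (and in DecodeMLAWriteCacheKernel below) are the usual bridge from a runtime dtype to a compile-time template instantiation. A minimal Paddle-free sketch of the pattern follows; Dtype, HalfTag, Bf16Tag, and write_cache_impl are illustrative stand-ins for paddle::DataType and the PDTraits machinery, not names from this commit.

// Minimal sketch: runtime dtype -> compile-time template dispatch.
#include <cstdio>

enum class Dtype { kFloat16, kBfloat16 };

struct HalfTag { static constexpr const char* name = "float16"; };
struct Bf16Tag { static constexpr const char* name = "bfloat16"; };

template <typename Traits>
void write_cache_impl() {
  // The real code launches a CUDA kernel specialized on Traits::DataType.
  std::printf("writing cache with %s specialization\n", Traits::name);
}

void write_cache(Dtype dtype) {
  switch (dtype) {
    case Dtype::kFloat16:
      return write_cache_impl<HalfTag>();
    case Dtype::kBfloat16:
      return write_cache_impl<Bf16Tag>();
  }
}

int main() {
  write_cache(Dtype::kBfloat16);  // selects the bf16 instantiation at runtime
  return 0;
}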
@@ -154,15 +191,16 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
   const int blocksize = 128;
   int grid_size = 1;
 
-
   if (speculate_decoder) {
     const uint32_t elem_nums = token_num * kv_num_heads * all_size;
     const int pack_num = elem_nums / PackSize;
     GetNumBlocks<128>(pack_num, &grid_size);
     speculate_decode_absorb_cache_kernel<DataType_, PackSize>
         <<<grid_size, blocksize, 0, stream>>>(
-            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
-            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
+            reinterpret_cast<DataType_*>(
+                const_cast<data_t*>(kv_nope.data<data_t>())),
+            reinterpret_cast<DataType_*>(
+                const_cast<data_t*>(kv_pe.data<data_t>())),
             reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
             block_tables.data<int>(),
             batch_id_per_token.data<int>(),
@@ -182,8 +220,10 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
     GetNumBlocks<128>(pack_num, &grid_size);
     decode_absorb_cache_kernel<DataType_, PackSize>
         <<<grid_size, blocksize, 0, stream>>>(
-            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
-            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
+            reinterpret_cast<DataType_*>(
+                const_cast<data_t*>(kv_nope.data<data_t>())),
+            reinterpret_cast<DataType_*>(
+                const_cast<data_t*>(kv_pe.data<data_t>())),
             reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
             block_tables.data<int>(),
             cu_seqlens_q.data<int>(),
@@ -218,7 +258,8 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
   const auto& kv_pe_dims = kv_pe.dims();
   const auto& kv_cache_dims = kv_cache.dims();
   meta_data.kv_num_heads = kv_cache_dims[1];
-  const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
+  const auto nope_size =
+      kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
   meta_data.token_nums = kv_nope_dims[0];
   meta_data.head_dims = kv_cache_dims[3];
   meta_data.head_dims_v = nope_size;
@@ -228,38 +269,39 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
   meta_data.batch_size = seq_lens_encoder.dims()[0];
   switch (kv_pe.dtype()) {
    case paddle::DataType::BFLOAT16: {
-      return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
-                                                             kv_nope,
-                                                             kv_pe,
-                                                             seq_lens,
-                                                             seq_lens_encoder,
-                                                             batch_id_per_token,
-                                                             cu_seqlens_q,
-                                                             block_tables,
-                                                             max_seq_len,
-                                                             speculate_decoder,
-                                                             stream,
-                                                             const_cast<paddle::Tensor*>(&kv_cache));
+      return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
+          meta_data,
+          kv_nope,
+          kv_pe,
+          seq_lens,
+          seq_lens_encoder,
+          batch_id_per_token,
+          cu_seqlens_q,
+          block_tables,
+          max_seq_len,
+          speculate_decoder,
+          stream,
+          const_cast<paddle::Tensor*>(&kv_cache));
     }
     case paddle::DataType::FLOAT16: {
-      return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
-                                                            kv_nope,
-                                                            kv_pe,
-                                                            seq_lens,
-                                                            seq_lens_encoder,
-                                                            batch_id_per_token,
-                                                            cu_seqlens_q,
-                                                            block_tables,
-                                                            max_seq_len,
-                                                            speculate_decoder,
-                                                            stream,
-                                                            const_cast<paddle::Tensor*>(&kv_cache));
+      return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
+          meta_data,
+          kv_nope,
+          kv_pe,
+          seq_lens,
+          seq_lens_encoder,
+          batch_id_per_token,
+          cu_seqlens_q,
+          block_tables,
+          max_seq_len,
+          speculate_decoder,
+          stream,
+          const_cast<paddle::Tensor*>(&kv_cache));
     }
   }
   return {};
 }
 
-
 PD_BUILD_STATIC_OP(prefill_mla_write_cache)
     .Inputs({"kv_nope",
              "kv_pe",
@@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
              "seq_lens_decoder",
              "batch_id_per_token",
              "cu_seqlens_q",
-             "block_tables"})
+             "block_tables",
+             paddle::Optional("kv_signal_data")})
     .Outputs({"kv_cache_out"})
     .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
-    .Attrs({"cache_quant_type_str: std::string",
-            "max_seq_len: int"})
+    .Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
     .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
 
 PD_BUILD_STATIC_OP(decode_mla_write_cache)
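With paddle::Optional("kv_signal_data") in the op registration above, the prefill op accepts a signal buffer only when the caller wires one in; inside PrefillMLAWriteCache the paddle::optional is checked for presence before any completion signal is sent. Below is a standalone analogy using std::optional to show the same present-or-absent control flow; write_cache_stub is illustrative and not part of the commit.

// Standalone analogy (std::optional in place of paddle::optional).
#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Stand-in for the op body: signal only when a buffer was provided,
// mirroring the `if (kv_signal_data)` guard in PrefillMLAWriteCache.
static void write_cache_stub(
    const std::optional<std::vector<int64_t>>& kv_signal_data) {
  // ... cache-write work would happen here ...
  if (kv_signal_data) {
    std::printf("signal buffer of %zu int64s; would enqueue completion signal\n",
                kv_signal_data->size());
  } else {
    std::printf("no signal buffer; completion signal skipped\n");
  }
}

int main() {
  write_cache_stub(std::nullopt);                      // optional input omitted
  write_cache_stub(std::vector<int64_t>{0, 0, 0, 0});  // optional input present
  return 0;
}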

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 1 addition & 0 deletions
@@ -524,6 +524,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
     const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
+    const paddle::optional<paddle::Tensor>& kv_signal_data,
     const std::string& cache_quant_type_str,
     const int max_seq_len);
 