1313// limitations under the License.
1414#pragma once
1515
16- #include " helper.h"
1716#include " mla_cache_kernel.cuh"
17+ #include " helper.h"
18+ #include " remote_cache_kv_ipc.h"
1819
1920template <paddle::DataType T>
2021std::vector<paddle::Tensor> PrefillMLAWriteCache (
21- const AppendAttnMetaData& meta_data,
22- const paddle::Tensor& kv_nope,
23- const paddle::Tensor& kv_pe,
24- const paddle::Tensor& seq_lens,
25- const paddle::Tensor& seq_lens_decoder,
26- const paddle::Tensor& batch_id_per_token,
27- const paddle::Tensor& cu_seqlens_q,
28- const paddle::Tensor& block_tables,
29- const int max_seq_len,
30- cudaStream_t& stream,
31- paddle::Tensor* kv_cache) {
22+ const AppendAttnMetaData& meta_data,
23+ const paddle::Tensor& kv_nope,
24+ const paddle::Tensor& kv_pe,
25+ const paddle::Tensor& seq_lens,
26+ const paddle::Tensor& seq_lens_decoder,
27+ const paddle::Tensor& batch_id_per_token,
28+ const paddle::Tensor& cu_seqlens_q,
29+ const paddle::Tensor& block_tables,
30+ const paddle::optional<paddle::Tensor>& kv_signal_data,
31+ const int max_seq_len,
32+ cudaStream_t& stream,
33+ paddle::Tensor* kv_cache) {
3234 typedef PDTraits<T> traits_;
3335 typedef typename traits_::DataType DataType_;
3436 typedef typename traits_::data_t data_t ;
@@ -50,8 +52,10 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
5052
5153 prefill_absorb_cache_kernel<DataType_, PackSize>
5254 <<<grid_size, blocksize, 0 , stream>>> (
53- reinterpret_cast <DataType_*>(const_cast <data_t *>(kv_nope.data <data_t >())),
54- reinterpret_cast <DataType_*>(const_cast <data_t *>(kv_pe.data <data_t >())),
55+ reinterpret_cast <DataType_*>(
56+ const_cast <data_t *>(kv_nope.data <data_t >())),
57+ reinterpret_cast <DataType_*>(
58+ const_cast <data_t *>(kv_pe.data <data_t >())),
5559 reinterpret_cast <DataType_*>(kv_cache->data <data_t >()),
5660 block_tables.data <int >(),
5761 batch_id_per_token.data <int >(),
@@ -65,6 +69,33 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
6569 pe_size,
6670 block_size,
6771 elem_nums);
72+
73+ const char * fmt_write_cache_completed_signal_str =
74+ std::getenv (" FLAGS_fmt_write_cache_completed_signal" );
75+ const char * FLAGS_use_pd_disaggregation_per_chunk =
76+ std::getenv (" FLAGS_use_pd_disaggregation_per_chunk" );
77+
78+ if (fmt_write_cache_completed_signal_str &&
79+ (std::strcmp (fmt_write_cache_completed_signal_str, " true" ) == 0 ||
80+ std::strcmp (fmt_write_cache_completed_signal_str, " 1" ) == 0 )) {
81+ if (FLAGS_use_pd_disaggregation_per_chunk &&
82+ (std::strcmp (FLAGS_use_pd_disaggregation_per_chunk, " true" ) == 0 ||
83+ std::strcmp (FLAGS_use_pd_disaggregation_per_chunk, " 1" ) == 0 )) {
84+ cudaLaunchHostFunc (
85+ stream,
86+ &(RemoteCacheKvIpc::
87+ save_cache_kv_complete_signal_layerwise_per_query),
88+ (void *)nullptr );
89+ } else {
90+ if (kv_signal_data) {
91+ cudaLaunchHostFunc (
92+ stream,
93+ &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
94+ (void *)(const_cast <int64_t *>(
95+ kv_signal_data.get ().data <int64_t >())));
96+ }
97+ }
98+ }
6899 return {};
69100}
70101
@@ -77,6 +108,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
77108 const paddle::Tensor& batch_id_per_token,
78109 const paddle::Tensor& cu_seqlens_q,
79110 const paddle::Tensor& block_tables,
111+ const paddle::optional<paddle::Tensor>& kv_signal_data,
80112 const std::string& cache_quant_type_str,
81113 const int max_seq_len) {
82114 cudaStream_t stream = kv_pe.stream ();
@@ -85,7 +117,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
85117 const auto & kv_pe_dims = kv_pe.dims ();
86118 const auto & kv_cache_dims = kv_cache.dims ();
87119 meta_data.kv_num_heads = kv_cache_dims[1 ];
88- const auto nope_size = kv_nope_dims[kv_nope_dims.size () - 1 ] / meta_data.kv_num_heads ;
120+ const auto nope_size =
121+ kv_nope_dims[kv_nope_dims.size () - 1 ] / meta_data.kv_num_heads ;
89122 meta_data.token_nums = kv_nope_dims[0 ];
90123 meta_data.head_dims = kv_cache_dims[3 ];
91124 meta_data.head_dims_v = nope_size;
@@ -95,49 +128,53 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
95128 meta_data.batch_size = seq_lens_decoder.dims ()[0 ];
96129 switch (kv_pe.dtype ()) {
97130 case paddle::DataType::BFLOAT16: {
98- return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
99- kv_nope,
100- kv_pe,
101- seq_lens,
102- seq_lens_decoder,
103- batch_id_per_token,
104- cu_seqlens_q,
105- block_tables,
106- max_seq_len,
107- stream,
108- const_cast <paddle::Tensor*>(&kv_cache));
131+ return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
132+ meta_data,
133+ kv_nope,
134+ kv_pe,
135+ seq_lens,
136+ seq_lens_decoder,
137+ batch_id_per_token,
138+ cu_seqlens_q,
139+ block_tables,
140+ kv_signal_data,
141+ max_seq_len,
142+ stream,
143+ const_cast <paddle::Tensor*>(&kv_cache));
109144 }
110145 case paddle::DataType::FLOAT16: {
111- return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
112- kv_nope,
113- kv_pe,
114- seq_lens,
115- seq_lens_decoder,
116- batch_id_per_token,
117- cu_seqlens_q,
118- block_tables,
119- max_seq_len,
120- stream,
121- const_cast <paddle::Tensor*>(&kv_cache));
146+ return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
147+ meta_data,
148+ kv_nope,
149+ kv_pe,
150+ seq_lens,
151+ seq_lens_decoder,
152+ batch_id_per_token,
153+ cu_seqlens_q,
154+ block_tables,
155+ kv_signal_data,
156+ max_seq_len,
157+ stream,
158+ const_cast <paddle::Tensor*>(&kv_cache));
122159 }
123160 }
124161 return {};
125162}
126163
127164template <paddle::DataType T>
128165std::vector<paddle::Tensor> DecodeMLAWriteCache (
129- const AppendAttnMetaData& meta_data,
130- const paddle::Tensor& kv_nope,
131- const paddle::Tensor& kv_pe,
132- const paddle::Tensor& seq_lens,
133- const paddle::Tensor& seq_lens_encoder,
134- const paddle::Tensor& batch_id_per_token,
135- const paddle::Tensor& cu_seqlens_q,
136- const paddle::Tensor& block_tables,
137- const int max_seq_len,
138- const bool speculate_decoder,
139- cudaStream_t& stream,
140- paddle::Tensor* kv_cache) {
166+ const AppendAttnMetaData& meta_data,
167+ const paddle::Tensor& kv_nope,
168+ const paddle::Tensor& kv_pe,
169+ const paddle::Tensor& seq_lens,
170+ const paddle::Tensor& seq_lens_encoder,
171+ const paddle::Tensor& batch_id_per_token,
172+ const paddle::Tensor& cu_seqlens_q,
173+ const paddle::Tensor& block_tables,
174+ const int max_seq_len,
175+ const bool speculate_decoder,
176+ cudaStream_t& stream,
177+ paddle::Tensor* kv_cache) {
141178 typedef PDTraits<T> traits_;
142179 typedef typename traits_::DataType DataType_;
143180 typedef typename traits_::data_t data_t ;
@@ -154,15 +191,16 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
154191 const int blocksize = 128 ;
155192 int grid_size = 1 ;
156193
157-
158194 if (speculate_decoder) {
159195 const uint32_t elem_nums = token_num * kv_num_heads * all_size;
160196 const int pack_num = elem_nums / PackSize;
161197 GetNumBlocks<128 >(pack_num, &grid_size);
162198 speculate_decode_absorb_cache_kernel<DataType_, PackSize>
163199 <<<grid_size, blocksize, 0 , stream>>> (
164- reinterpret_cast <DataType_*>(const_cast <data_t *>(kv_nope.data <data_t >())),
165- reinterpret_cast <DataType_*>(const_cast <data_t *>(kv_pe.data <data_t >())),
200+ reinterpret_cast <DataType_*>(
201+ const_cast <data_t *>(kv_nope.data <data_t >())),
202+ reinterpret_cast <DataType_*>(
203+ const_cast <data_t *>(kv_pe.data <data_t >())),
166204 reinterpret_cast <DataType_*>(kv_cache->data <data_t >()),
167205 block_tables.data <int >(),
168206 batch_id_per_token.data <int >(),
@@ -182,8 +220,10 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
182220 GetNumBlocks<128 >(pack_num, &grid_size);
183221 decode_absorb_cache_kernel<DataType_, PackSize>
184222 <<<grid_size, blocksize, 0 , stream>>> (
185- reinterpret_cast <DataType_*>(const_cast <data_t *>(kv_nope.data <data_t >())),
186- reinterpret_cast <DataType_*>(const_cast <data_t *>(kv_pe.data <data_t >())),
223+ reinterpret_cast <DataType_*>(
224+ const_cast <data_t *>(kv_nope.data <data_t >())),
225+ reinterpret_cast <DataType_*>(
226+ const_cast <data_t *>(kv_pe.data <data_t >())),
187227 reinterpret_cast <DataType_*>(kv_cache->data <data_t >()),
188228 block_tables.data <int >(),
189229 cu_seqlens_q.data <int >(),
@@ -218,7 +258,8 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
218258 const auto & kv_pe_dims = kv_pe.dims ();
219259 const auto & kv_cache_dims = kv_cache.dims ();
220260 meta_data.kv_num_heads = kv_cache_dims[1 ];
221- const auto nope_size = kv_nope_dims[kv_nope_dims.size () - 1 ] / meta_data.kv_num_heads ;
261+ const auto nope_size =
262+ kv_nope_dims[kv_nope_dims.size () - 1 ] / meta_data.kv_num_heads ;
222263 meta_data.token_nums = kv_nope_dims[0 ];
223264 meta_data.head_dims = kv_cache_dims[3 ];
224265 meta_data.head_dims_v = nope_size;
@@ -228,38 +269,39 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
228269 meta_data.batch_size = seq_lens_encoder.dims ()[0 ];
229270 switch (kv_pe.dtype ()) {
230271 case paddle::DataType::BFLOAT16: {
231- return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
232- kv_nope,
233- kv_pe,
234- seq_lens,
235- seq_lens_encoder,
236- batch_id_per_token,
237- cu_seqlens_q,
238- block_tables,
239- max_seq_len,
240- speculate_decoder,
241- stream,
242- const_cast <paddle::Tensor*>(&kv_cache));
272+ return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
273+ meta_data,
274+ kv_nope,
275+ kv_pe,
276+ seq_lens,
277+ seq_lens_encoder,
278+ batch_id_per_token,
279+ cu_seqlens_q,
280+ block_tables,
281+ max_seq_len,
282+ speculate_decoder,
283+ stream,
284+ const_cast <paddle::Tensor*>(&kv_cache));
243285 }
244286 case paddle::DataType::FLOAT16: {
245- return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
246- kv_nope,
247- kv_pe,
248- seq_lens,
249- seq_lens_encoder,
250- batch_id_per_token,
251- cu_seqlens_q,
252- block_tables,
253- max_seq_len,
254- speculate_decoder,
255- stream,
256- const_cast <paddle::Tensor*>(&kv_cache));
287+ return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
288+ meta_data,
289+ kv_nope,
290+ kv_pe,
291+ seq_lens,
292+ seq_lens_encoder,
293+ batch_id_per_token,
294+ cu_seqlens_q,
295+ block_tables,
296+ max_seq_len,
297+ speculate_decoder,
298+ stream,
299+ const_cast <paddle::Tensor*>(&kv_cache));
257300 }
258301 }
259302 return {};
260303}
261304
262-
263305PD_BUILD_STATIC_OP (prefill_mla_write_cache)
264306 .Inputs({" kv_nope" ,
265307 " kv_pe" ,
@@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
268310 " seq_lens_decoder" ,
269311 " batch_id_per_token" ,
270312 " cu_seqlens_q" ,
271- " block_tables" })
313+ " block_tables" ,
314+ paddle::Optional (" kv_signal_data" )})
272315 .Outputs({" kv_cache_out" })
273316 .SetInplaceMap({{" kv_cache" , " kv_cache_out" }})
274- .Attrs({" cache_quant_type_str: std::string" ,
275- " max_seq_len: int" })
317+ .Attrs({" cache_quant_type_str: std::string" , " max_seq_len: int" })
276318 .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
277319
278320PD_BUILD_STATIC_OP (decode_mla_write_cache)