diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu index ad501752abb..b582c862c38 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu @@ -13,22 +13,24 @@ // limitations under the License. #pragma once -#include "helper.h" #include "mla_cache_kernel.cuh" +#include "helper.h" +#include "remote_cache_kv_ipc.h" template std::vector PrefillMLAWriteCache( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& kv_nope, - const paddle::Tensor& kv_pe, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const int max_seq_len, - cudaStream_t& stream, - paddle::Tensor* kv_cache) { + const AppendAttnMetaData& meta_data, + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& kv_signal_data, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -50,8 +52,10 @@ std::vector PrefillMLAWriteCache( prefill_absorb_cache_kernel <<>>( - reinterpret_cast(const_cast(kv_nope.data())), - reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast( + const_cast(kv_nope.data())), + reinterpret_cast( + const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), batch_id_per_token.data(), @@ -65,6 +69,33 @@ std::vector PrefillMLAWriteCache( pe_size, block_size, elem_nums); + + const char* fmt_write_cache_completed_signal_str = + std::getenv("FLAGS_fmt_write_cache_completed_signal"); + const char* FLAGS_use_pd_disaggregation_per_chunk = + std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); + + if (fmt_write_cache_completed_signal_str && + (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || + std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { + if (FLAGS_use_pd_disaggregation_per_chunk && + (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || + std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { + cudaLaunchHostFunc( + stream, + &(RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_per_query), + (void*)nullptr); + } else { + if (kv_signal_data) { + cudaLaunchHostFunc( + stream, + &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, + (void*)(const_cast( + kv_signal_data.get().data()))); + } + } + } return {}; } @@ -77,6 +108,7 @@ std::vector PrefillMLAWriteCacheKernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, + const paddle::optional& kv_signal_data, const std::string& cache_quant_type_str, const int max_seq_len) { cudaStream_t stream = kv_pe.stream(); @@ -85,7 +117,8 @@ std::vector PrefillMLAWriteCacheKernel( const auto& kv_pe_dims = kv_pe.dims(); const auto& kv_cache_dims = kv_cache.dims(); meta_data.kv_num_heads = kv_cache_dims[1]; - const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; + const auto nope_size = + kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; meta_data.token_nums = kv_nope_dims[0]; meta_data.head_dims = kv_cache_dims[3]; 
meta_data.head_dims_v = nope_size; @@ -95,30 +128,34 @@ std::vector PrefillMLAWriteCacheKernel( meta_data.batch_size = seq_lens_decoder.dims()[0]; switch (kv_pe.dtype()) { case paddle::DataType::BFLOAT16: { - return PrefillMLAWriteCache(meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - stream, - const_cast(&kv_cache)); + return PrefillMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_decoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + kv_signal_data, + max_seq_len, + stream, + const_cast(&kv_cache)); } case paddle::DataType::FLOAT16: { - return PrefillMLAWriteCache(meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - stream, - const_cast(&kv_cache)); + return PrefillMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_decoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + kv_signal_data, + max_seq_len, + stream, + const_cast(&kv_cache)); } } return {}; @@ -126,18 +163,18 @@ std::vector PrefillMLAWriteCacheKernel( template std::vector DecodeMLAWriteCache( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& kv_nope, - const paddle::Tensor& kv_pe, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const int max_seq_len, - const bool speculate_decoder, - cudaStream_t& stream, - paddle::Tensor* kv_cache) { + const AppendAttnMetaData& meta_data, + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const int max_seq_len, + const bool speculate_decoder, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -154,15 +191,16 @@ std::vector DecodeMLAWriteCache( const int blocksize = 128; int grid_size = 1; - if (speculate_decoder) { const uint32_t elem_nums = token_num * kv_num_heads * all_size; const int pack_num = elem_nums / PackSize; GetNumBlocks<128>(pack_num, &grid_size); speculate_decode_absorb_cache_kernel <<>>( - reinterpret_cast(const_cast(kv_nope.data())), - reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast( + const_cast(kv_nope.data())), + reinterpret_cast( + const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), batch_id_per_token.data(), @@ -182,8 +220,10 @@ std::vector DecodeMLAWriteCache( GetNumBlocks<128>(pack_num, &grid_size); decode_absorb_cache_kernel <<>>( - reinterpret_cast(const_cast(kv_nope.data())), - reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast( + const_cast(kv_nope.data())), + reinterpret_cast( + const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), cu_seqlens_q.data(), @@ -218,7 +258,8 @@ std::vector DecodeMLAWriteCacheKernel( const auto& kv_pe_dims = kv_pe.dims(); const auto& kv_cache_dims = kv_cache.dims(); meta_data.kv_num_heads = kv_cache_dims[1]; - const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; + const auto nope_size = + kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; meta_data.token_nums = kv_nope_dims[0]; meta_data.head_dims = 
kv_cache_dims[3]; meta_data.head_dims_v = nope_size; @@ -228,38 +269,39 @@ std::vector DecodeMLAWriteCacheKernel( meta_data.batch_size = seq_lens_encoder.dims()[0]; switch (kv_pe.dtype()) { case paddle::DataType::BFLOAT16: { - return DecodeMLAWriteCache(meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - speculate_decoder, - stream, - const_cast(&kv_cache)); + return DecodeMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + max_seq_len, + speculate_decoder, + stream, + const_cast(&kv_cache)); } case paddle::DataType::FLOAT16: { - return DecodeMLAWriteCache(meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - speculate_decoder, - stream, - const_cast(&kv_cache)); + return DecodeMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + max_seq_len, + speculate_decoder, + stream, + const_cast(&kv_cache)); } } return {}; } - PD_BUILD_STATIC_OP(prefill_mla_write_cache) .Inputs({"kv_nope", "kv_pe", @@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache) "seq_lens_decoder", "batch_id_per_token", "cu_seqlens_q", - "block_tables"}) + "block_tables", + paddle::Optional("kv_signal_data")}) .Outputs({"kv_cache_out"}) .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) - .Attrs({"cache_quant_type_str: std::string", - "max_seq_len: int"}) + .Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"}) .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel)); PD_BUILD_STATIC_OP(decode_mla_write_cache) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 93dbaad2d78..abf16db95c9 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -527,6 +527,7 @@ std::vector PrefillMLAWriteCacheKernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, + const paddle::optional& kv_signal_data, const std::string& cache_quant_type_str, const int max_seq_len); diff --git a/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh b/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh index 2d55d91e5eb..92b25aceef2 100644 --- a/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh +++ b/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh @@ -13,8 +13,8 @@ // limitations under the License. /* - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri - * Dao. Licensed under the BSD 3-Clause. + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, + * Pradeep Ramani, Tri Dao. Licensed under the BSD 3-Clause. * * Modified by the FlashInfer team. 
*/ @@ -39,8 +39,8 @@ #include "epilogue.cuh" #include "helper.h" #include "kernel_traits.cuh" -#include "mainloop_mma.cuh" #include "mainloop_load.cuh" +#include "mainloop_mma.cuh" #include "utils.cuh" #ifdef DEBUG_MLA @@ -52,76 +52,91 @@ namespace mla_attn { using namespace cute; -template +template struct Params { - using DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using IdType = IdType_; - - alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head] - alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head] - alignas(16) DTypeO *O; // [token_num, head_num, dim_head] - alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head] - alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num] - alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num] - - alignas(16) IdType *block_tables; - alignas(16) IdType *seq_lens_this_time; - alignas(16) IdType *seq_lens_decoder; - alignas(16) IdType *cumsum_q_seqlens; - alignas(16) IdType *batch_id_per_token; - - alignas(16) IdType *batch_ids; - alignas(16) IdType *tile_ids_per_batch; - alignas(16) IdType *num_blocks_x; - alignas(16) IdType *chunk_size_device; - - uint32_t q_stride_bsz; - uint32_t q_stride_head_num; - - uint32_t kv_stride_block_num; - uint32_t kv_stride_block_size; - - uint32_t o_stride_bsz; - uint32_t o_stride_head_num; - - int bsz; - int token_num; - int max_block_num; - int max_block_num_per_seq; - int q_num_head; - int qk_head_dim; - int vo_head_dim; - int block_size; - int max_draft_token_num; - int chunk_num; - - float sm_scale; + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + + alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head] + alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head] + alignas(16) DTypeO *O; // [token_num, head_num, dim_head] + alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head] + alignas( + 16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num] + alignas( + 16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num] + + alignas(16) IdType *block_tables; + alignas(16) IdType *seq_lens_this_time; + alignas(16) IdType *seq_lens_decoder; + alignas(16) IdType *cumsum_q_seqlens; + alignas(16) IdType *batch_id_per_token; + + alignas(16) IdType *batch_ids; + alignas(16) IdType *tile_ids_per_batch; + alignas(16) IdType *num_blocks_x; + alignas(16) IdType *chunk_size_device; + + uint32_t q_stride_bsz; + uint32_t q_stride_head_num; + + uint32_t kv_stride_block_num; + uint32_t kv_stride_block_size; + + uint32_t o_stride_bsz; + uint32_t o_stride_head_num; + + int bsz; + int token_num; + int max_block_num; + int max_block_num_per_seq; + int q_num_head; + int qk_head_dim; + int vo_head_dim; + int block_size; + int max_draft_token_num; + int chunk_num; + + float sm_scale; }; -#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ - if (group_size == 8) { \ - constexpr size_t GROUP_SIZE = 8; \ - __VA_ARGS__ \ - } else if (group_size == 16) { \ - constexpr size_t GROUP_SIZE = 16; \ - __VA_ARGS__ \ - } else if (group_size == 64) { \ - constexpr size_t GROUP_SIZE = 64; \ - __VA_ARGS__ \ - } else { \ - PD_THROW("not support the group_size: ", group_size); \ - return cudaErrorNotSupported; \ +#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) 
\ + if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else if (group_size == 64) { \ + constexpr size_t GROUP_SIZE = 64; \ + __VA_ARGS__ \ + } else if (group_size == 128) { \ + constexpr size_t GROUP_SIZE = 128; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size: ", group_size); \ + return cudaErrorNotSupported; \ } -template -__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1) -MLAWithKVCacheKernel(CUTE_GRID_CONSTANT - typename CollectiveMainloop::Params const mainloop_params, - CUTE_GRID_CONSTANT - typename CollectiveEpilogue::Params const epilogue_params) { - +template +__global__ void __launch_bounds__( + Ktraits::NUM_WARPS *cutlass::NumThreadsPerWarp, 1) + MLAWithKVCacheKernel( + CUTE_GRID_CONSTANT + typename CollectiveMainloop::Params const mainloop_params, + CUTE_GRID_CONSTANT + typename CollectiveEpilogue::Params const epilogue_params) { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename Ktraits::DTypeKV; using DTypeO = typename Ktraits::DTypeO; @@ -147,7 +162,8 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT using PipelineStateQ = typename MainloopPipelineQ::PipelineState; extern __shared__ char shared_memory[]; - auto& shared_storage = *reinterpret_cast(shared_memory); + auto &shared_storage = + *reinterpret_cast(shared_memory); int const lane_predicate = cute::elect_one_sync(); int const warp_idx = cutlass::canonical_warp_idx_sync(); @@ -158,12 +174,14 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT } // Obtain warp index - int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + int const warp_group_thread_idx = + threadIdx.x % cutlass::NumThreadsPerWarpGroup; PipelineParams pipeline_params; int warp_group_idx = cutlass::canonical_warp_group_idx(); - pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer - : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params.role = warp_group_idx == 0 + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; if constexpr (use_tma_load_kv) { pipeline_params.is_leader = warp_group_thread_idx == 0; pipeline_params.num_consumers = NUM_MMA_THREADS; @@ -173,17 +191,20 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT } PipelineParamsQ pipeline_params_q; - pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer - : MainloopPipelineQ::ThreadCategory::Consumer; + pipeline_params_q.role = warp_group_idx == 0 + ? 
MainloopPipelineQ::ThreadCategory::Producer + : MainloopPipelineQ::ThreadCategory::Consumer; pipeline_params_q.producer_arv_count = NUM_COPY_THREADS; - pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk - + pipeline_params_q.consumer_arv_count = + cutlass::NumThreadsPerWarpGroup; // just one wg qk MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q); MainloopPipeline pipeline_kv = [&] { if constexpr (use_tma_load_kv) { - pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV; - return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params, + pipeline_params.transaction_bytes = + CollectiveMainloop::TmaTransactionBytesKV; + return MainloopPipeline(shared_storage.pipeline_kv, + pipeline_params, /*cluster_shape=*/Shape<_1, _1, _1>{}); } else { return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params); @@ -196,191 +217,217 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT if (warp_group_idx == 0) { // producer - if constexpr(USE_REG_EALLOC) { + if constexpr (USE_REG_EALLOC) { cutlass::arch::warpgroup_reg_dealloc<72>(); } - const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0); + const uint32_t warp_idx_in_warpgroup = + __shfl_sync(0xffffffff, warp_idx % 4, 0); - PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state(); - PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state(); + PipelineStateQ smem_pipe_write_q = + cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_kv = + cutlass::make_producer_start_state(); for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) { const int bid = mainloop_params.batch_ids[i]; const int tile_id = mainloop_params.tile_ids_per_batch[i]; const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; - const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; + const int seq_len_decoder_now = + mainloop_params.seq_lens_decoder[bid] + seq_len_now; const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; - cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, - /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); + cutlass::arch::NamedBarrier::sync( + Ktraits::NUM_THREADS, + /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); // load Q - collective_mainloop.load_q( - mainloop_params, - pipeline_q, - smem_pipe_write_q, - shared_storage, - threadIdx.x, - bid); + collective_mainloop.load_q(mainloop_params, + pipeline_q, + smem_pipe_write_q, + shared_storage, + threadIdx.x, + bid); if constexpr (!use_tma_load_kv) { // load kv - collective_mainloop.load_kv( - mainloop_params, - pipeline_kv, - smem_pipe_write_kv, - shared_storage, - bid, - seq_len_decoder_now, - tile_id - ); + collective_mainloop.load_kv(mainloop_params, + pipeline_kv, + smem_pipe_write_kv, + shared_storage, + bid, + seq_len_decoder_now, + tile_id); } else { if (warp_idx_in_warpgroup == 0) { // load kv tma - collective_mainloop.load_kv_tma( - mainloop_params, - pipeline_kv, - smem_pipe_write_kv, - shared_storage, - bid, - seq_len_decoder_now, - tile_id - ); + collective_mainloop.load_kv_tma(mainloop_params, + pipeline_kv, + smem_pipe_write_kv, + shared_storage, + bid, + seq_len_decoder_now, + tile_id); } } } } else { // consumer - if constexpr(USE_REG_EALLOC) { + if constexpr (USE_REG_EALLOC) { cutlass::arch::warpgroup_reg_alloc<216>(); } PipelineStateQ smem_pipe_read_q; PipelineState smem_pipe_read_kv; typename Ktraits::TiledMmaPVSS tiled_mma_pv; - Tensor tOrO = 
partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{})); + Tensor tOrO = + partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{})); - auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale); + auto attention_updater = + OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>( + mainloop_params.sm_scale); for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) { clear(tOrO); clear(attention_updater.scores_scale); const int bid = mainloop_params.batch_ids[i]; const int tile_id = mainloop_params.tile_ids_per_batch[i]; const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; - const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; + const int seq_len_decoder_now = + mainloop_params.seq_lens_decoder[bid] + seq_len_now; const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; - cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, - /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); + cutlass::arch::NamedBarrier::sync( + Ktraits::NUM_THREADS, + /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); if constexpr (BLOCK_SHAPE_KV == 64) { - mma_f16( - mainloop_params, - pipeline_q, - smem_pipe_read_q, - pipeline_kv, - smem_pipe_read_kv, - tOrO, - attention_updater, - threadIdx.x - NUM_COPY_THREADS, - bid, - seq_len_decoder_now, - seq_len_now, - tile_id, - shared_storage); + mma_f16(mainloop_params, + pipeline_q, + smem_pipe_read_q, + pipeline_kv, + smem_pipe_read_kv, + tOrO, + attention_updater, + threadIdx.x - NUM_COPY_THREADS, + bid, + seq_len_decoder_now, + seq_len_now, + tile_id, + shared_storage); } else if (BLOCK_SHAPE_KV == 32) { - mma_f16_two_stages( - mainloop_params, - pipeline_q, - smem_pipe_read_q, - pipeline_kv, - smem_pipe_read_kv, - tOrO, - attention_updater, - threadIdx.x - NUM_COPY_THREADS, - bid, - seq_len_decoder_now, - seq_len_now, - tile_id, - shared_storage); + mma_f16_two_stages(mainloop_params, + pipeline_q, + smem_pipe_read_q, + pipeline_kv, + smem_pipe_read_kv, + tOrO, + attention_updater, + threadIdx.x - NUM_COPY_THREADS, + bid, + seq_len_decoder_now, + seq_len_now, + tile_id, + shared_storage); } - collective_epilogue.store( - epilogue_params, - tOrO, - attention_updater.get_lse(), - shared_storage, - tiled_mma_pv, - threadIdx.x - NUM_COPY_THREADS, - bid, - mainloop_params.bsz, - seq_len_now, - start_token_idx, - tile_id, - seq_len_decoder_now, - chunk_size, - mainloop_params.max_draft_token_num, - mainloop_params.o_stride_bsz); - } + collective_epilogue.store(epilogue_params, + tOrO, + attention_updater.get_lse(), + shared_storage, + tiled_mma_pv, + threadIdx.x - NUM_COPY_THREADS, + bid, + mainloop_params.bsz, + seq_len_now, + start_token_idx, + tile_id, + seq_len_decoder_now, + chunk_size, + mainloop_params.max_draft_token_num, + mainloop_params.o_stride_bsz); + } } } - -template -cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, - cudaStream_t stream) { +template +cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched( + Params ¶ms, cudaStream_t stream) { using DTypeQ = typename KernelTraits::DTypeQ; using DTypeKV = typename KernelTraits::DTypeKV; using DTypeO = typename KernelTraits::DTypeO; using IdType = typename KernelTraits::IdType; using NV_TYPE = typename KernelTraits::NV_TYPE; - using CollectiveMainloop = - CollectiveMainloop; + using CollectiveMainloop = CollectiveMainloop; using CollectiveEpilogue = CollectiveEpilogue; - typename CollectiveMainloop::Params mainloop_params = 
CollectiveMainloop::to_underlying_arguments({ - make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q - make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)), - make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})), - params.Q, - params.KV, - params.m, - params.d, - params.block_tables, - params.seq_lens_this_time, - params.seq_lens_decoder, - params.cumsum_q_seqlens, - params.batch_ids, - params.tile_ids_per_batch, - params.num_blocks_x, - params.chunk_size_device, - params.sm_scale, - params.bsz, - params.max_block_num, - params.max_block_num_per_seq, - params.q_stride_bsz, - params.q_stride_head_num, - params.kv_stride_block_num, - params.kv_stride_block_size, - params.o_stride_bsz, - params.o_stride_head_num, - params.chunk_num, - params.max_draft_token_num - }); - typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({ - params.O, - make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O - params.O_tmp, - make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp - }); + typename CollectiveMainloop::Params mainloop_params = + CollectiveMainloop::to_underlying_arguments( + {make_layout( + make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), + make_stride(params.qk_head_dim, _1{})), // layout q + make_layout( + make_shape( + params.block_size, params.qk_head_dim, params.max_block_num), + make_stride(params.qk_head_dim, + _1{}, + params.block_size * params.qk_head_dim)), + make_layout(make_shape(params.chunk_num, + params.bsz * params.max_draft_token_num * + params.q_num_head), + make_stride(params.bsz * params.max_draft_token_num * + params.q_num_head, + _1{})), + params.Q, + params.KV, + params.m, + params.d, + params.block_tables, + params.seq_lens_this_time, + params.seq_lens_decoder, + params.cumsum_q_seqlens, + params.batch_ids, + params.tile_ids_per_batch, + params.num_blocks_x, + params.chunk_size_device, + params.sm_scale, + params.bsz, + params.max_block_num, + params.max_block_num_per_seq, + params.q_stride_bsz, + params.q_stride_head_num, + params.kv_stride_block_num, + params.kv_stride_block_size, + params.o_stride_bsz, + params.o_stride_head_num, + params.chunk_num, + params.max_draft_token_num}); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments_ntma({ + params.O, + make_layout( + make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), + make_stride(params.vo_head_dim, _1{})), // layout O + params.O_tmp, + make_layout( + make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), + make_stride(params.vo_head_dim, _1{})) // layout O_tmp + }); // Get the ptr to kernel function. 
- auto kernel = - MLAWithKVCacheKernel; + auto kernel = MLAWithKVCacheKernel; int smem_size = sizeof(typename KernelTraits::SharedStorage); - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); int device; cudaGetDevice(&device); int multiprocessor_count; - cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device); + cudaDeviceGetAttribute( + &multiprocessor_count, cudaDevAttrMultiProcessorCount, device); int act_blocks_per_sm; cudaOccupancyMaxActiveBlocksPerMultiprocessor( &act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size); @@ -390,15 +437,15 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, dim3 grid_dims = {multiprocessor_count, 1, 1}; static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; dim3 block_dims(ctaSize, 1, 1); - kernel<<>>( - mainloop_params, epilogue_params - ); + kernel<<>>(mainloop_params, + epilogue_params); if (params.chunk_num > 1) { constexpr int vec_size = 16 / sizeof(DTypeO); constexpr int merge_block_size = 256; constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size; constexpr int blocky = (merge_block_size + blockx - 1) / blockx; - dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large + dim3 grids_merge(multiprocessor_count, + params.q_num_head); // 128k is too large dim3 blocks_merge(blockx, blocky); merge_multi_chunks_kernel -cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) { +template +cudaError_t BatchMLAWithPagedKVCacheDispatched(Params &params, + cudaStream_t stream) { constexpr bool CAUSAL = true; if constexpr (HEAD_DIM_QK == 576) { - DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE, - BatchMLAWithPagedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - CAUSAL, - Params, - USE_REG_EALLOC, - USE_FIXED_BLOCK>(params, stream);) + DISPATCH_GROUP_SIZE(params.q_num_head, + GROUP_SIZE, + BatchMLAWithPagedKVCacheKernelTraitsDispatched< + AttentionKernelTraits, + CAUSAL, + Params, + USE_REG_EALLOC, + USE_FIXED_BLOCK>(params, stream);) } else { return cudaErrorNotSupported; } diff --git a/examples/splitwise/stop.sh b/examples/splitwise/stop.sh index 943efa12c58..19bee0f360b 100644 --- a/examples/splitwise/stop.sh +++ b/examples/splitwise/stop.sh @@ -1,6 +1,7 @@ pkill -9 -f python pkill -9 -f fastdeploy pkill -9 -f gunicorn -pkill -9 -f redis-server +# Kill redis-server if needed.
+#pkill -9 -f redis-server sleep 1 diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 23b6b72f3e3..9c1c86b2532 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -159,16 +159,22 @@ def __init__( cache_v = [] self.messager = {} for layer_idx in range(self.num_layers): + # value cache + val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}" + if val_cache_key in self.gpu_cache_kvs: + val_cache = self.gpu_cache_kvs[val_cache_key] + cache_v.append(val_cache) + if paddle.is_compiled_with_xpu(): + cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr())) + else: + cache_v_ptr_list.append(val_cache.data_ptr()) + # key cache key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] - val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] cache_k.append(key_cache) - cache_v.append(val_cache) if paddle.is_compiled_with_xpu(): cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr())) - cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr())) else: cache_k_ptr_list.append(key_cache.data_ptr()) - cache_v_ptr_list.append(val_cache.data_ptr()) cache_k_ptr_list = np.array(cache_k_ptr_list) cache_v_ptr_list = np.array(cache_v_ptr_list) @@ -198,7 +204,6 @@ def __init__( elif protocol == "rdma": logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}") - self.messager[protocol] = RDMACommManager( splitwise_role, rank, @@ -460,16 +465,22 @@ def __init__( cache_v = [] self.messager = {} for layer_idx in range(self.num_layers): + # value cache + val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}" + if val_cache_key in self.gpu_cache_kvs: + val_cache = self.gpu_cache_kvs[val_cache_key] + cache_v.append(val_cache) + if paddle.is_compiled_with_xpu(): + cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr())) + else: + cache_v_ptr_list.append(val_cache.data_ptr()) + # key cache key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] - val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] cache_k.append(key_cache) - cache_v.append(val_cache) if paddle.is_compiled_with_xpu(): cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr())) - cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr())) else: cache_k_ptr_list.append(key_cache.data_ptr()) - cache_v_ptr_list.append(val_cache.data_ptr()) cache_k_ptr_list = np.array(cache_k_ptr_list) cache_v_ptr_list = np.array(cache_v_ptr_list) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index a91e1061cde..060dd1bc76c 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -245,6 +245,15 @@ def launch_cache_manager( log_dir = envs.FD_LOG_DIR cache_manager_processes = [] visible_devices = get_all_visible_devices() + + val_cache_arg_str = "" + if val_cache_shape: + if isinstance(val_cache_shape, list): + val_shape_str = ",".join(map(str, val_cache_shape)) + else: + val_shape_str = str(val_cache_shape) + val_cache_arg_str = f" --value_cache_shape {val_shape_str}" + for i in range(tensor_parallel_size): launch_cmd = ( "FLAGS_allocator_strategy=auto_growth " @@ -259,7 +268,7 @@ def launch_cache_manager( + f" --mp_num {tensor_parallel_size}" + f" --cache_dtype {cache_config.cache_dtype}" + 
f" --key_cache_shape {key_cache_shape}" - + f" --value_cache_shape {val_cache_shape}" + + val_cache_arg_str + f" --cache_queue_port {cache_config.cache_queue_port}" + f" --enable_splitwise {int(self.enable_splitwise)}" + f" --pod_ip {pod_ip}" @@ -332,6 +341,15 @@ def launch_cache_messager( log_dir = envs.FD_LOG_DIR cache_messager_processes = [] visible_devices = get_all_visible_devices() + + val_cache_arg_str = "" + if value_cache_shape: + if isinstance(value_cache_shape, list): + val_shape_str = ",".join(map(str, value_cache_shape)) + else: + val_shape_str = str(value_cache_shape) + val_cache_arg_str = f" --value_cache_shape {val_shape_str}" + for i in range(tensor_parallel_size): launch_cmd = ( "FLAGS_allocator_strategy=auto_growth " @@ -345,7 +363,7 @@ def launch_cache_messager( + f" --mp_num {tensor_parallel_size}" + f" --cache_dtype {cache_config.cache_dtype}" + f" --key_cache_shape {key_cache_shape}" - + f" --value_cache_shape {value_cache_shape}" + + val_cache_arg_str + f" --pod_ip {pod_ip}" + f" --cache_queue_port {cache_config.cache_queue_port}" + f" --engine_worker_queue_port {engine_worker_queue_port}" diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h index d9b442a0a5a..2a7f498add9 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h @@ -198,8 +198,8 @@ int get_port_info(struct ibv_context* Context, int parse_port_ib_info(); // Memory region exchange -bool client_exchange_mr(struct RdmaContext* ctx); -bool server_exchange_mr(struct RdmaContext* ctx); +bool client_exchange_mr(struct RdmaContext* ctx, bool has_value_cache); +bool server_exchange_mr(struct RdmaContext* ctx, bool has_value_cache); bool server_send_memory_region(struct RdmaContext* ctx, void* local_mr, int byte_num); diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h index 3a5b2dc7883..35de5fb450b 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h @@ -149,6 +149,7 @@ class RDMACommunicator { struct ibv_pd* g_pd = NULL; // fd int RDMACommunicator_status; // Communicator status flag bool start_client_listener = false; // Client listener flag + bool has_value_cache_; // MLA does not have value cache. }; #endif // KVCACHE_RDMA_H diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp index 8e9ec468e35..0bdc40fae47 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp @@ -712,8 +712,8 @@ bool exchange_mr_vector(struct RdmaContext *ctx, * @param ctx The RDMA context * @return true on success, false on failure */ -bool client_exchange_mr(struct RdmaContext *ctx) { - LOGD("verb client exchange mr: start"); +bool client_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) { + LOGD("verb client exchange mr: start. 
has_value_cache=%d", has_value_cache); if (ctx->conn.layer_number <= 0) { ERR("Invalid layer number: %d", ctx->conn.layer_number); @@ -723,19 +723,27 @@ bool client_exchange_mr(struct RdmaContext *ctx) { auto layer_num = ctx->conn.layer_number; std::vector key_ptrs(layer_num); std::vector key_rkeys(layer_num); - std::vector val_ptrs(layer_num); - std::vector val_rkeys(layer_num); + std::vector val_ptrs; + std::vector val_rkeys; + if (has_value_cache) { + val_ptrs.resize(layer_num); + val_rkeys.resize(layer_num); + } if (!exchange_mr_vector(ctx, key_ptrs, true)) return false; if (!exchange_mr_vector(ctx, key_rkeys, true)) return false; - if (!exchange_mr_vector(ctx, val_ptrs, true)) return false; - if (!exchange_mr_vector(ctx, val_rkeys, true)) return false; + if (has_value_cache) { + if (!exchange_mr_vector(ctx, val_ptrs, true)) return false; + if (!exchange_mr_vector(ctx, val_rkeys, true)) return false; + } for (int i = 0; i < layer_num; ++i) { ctx->conn.write_cache_key_remote_ptr_list.push_back(key_ptrs[i]); ctx->conn.write_cache_key_remote_rkey_list.push_back(key_rkeys[i]); - ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]); - ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]); + if (has_value_cache) { + ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]); + ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]); + } } return true; } @@ -746,8 +754,8 @@ bool client_exchange_mr(struct RdmaContext *ctx) { * @param ctx The RDMA context * @return true on success, false on failure */ -bool server_exchange_mr(struct RdmaContext *ctx) { - LOGD("verbs server exchange mr: start"); +bool server_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) { + LOGD("verbs server exchange mr: start. 
has_value_cache=%d", has_value_cache); if (ctx->conn.layer_number <= 0) { ERR("Invalid layer number: %d", ctx->conn.layer_number); @@ -759,8 +767,16 @@ bool server_exchange_mr(struct RdmaContext *ctx) { auto &val_mrs = ctx->conn.write_cache_value_server_mr_list; // Verify that server memory regions are properly initialized - if (key_mrs.size() != layer_num || val_mrs.size() != layer_num) { - ERR("server write cache memory region size error"); + if (key_mrs.size() != layer_num) { + ERR("server write cache KEY memory region size error: %zu vs %d", + key_mrs.size(), + layer_num); + return false; + } + if (has_value_cache && val_mrs.size() != layer_num) { + ERR("server write cache VALUE memory region size error: %zu vs %d", + val_mrs.size(), + layer_num); return false; } @@ -772,22 +788,27 @@ bool server_exchange_mr(struct RdmaContext *ctx) { send_key_ptrs.reserve(layer_num); send_key_rkeys.reserve(layer_num); - send_val_ptrs.reserve(layer_num); - send_val_rkeys.reserve(layer_num); + if (has_value_cache) { + send_val_ptrs.reserve(layer_num); + send_val_rkeys.reserve(layer_num); + } // Collect memory region information from local MRs for (int i = 0; i < layer_num; ++i) { send_key_ptrs.push_back(reinterpret_cast(key_mrs[i]->addr)); send_key_rkeys.push_back(key_mrs[i]->rkey); - send_val_ptrs.push_back(reinterpret_cast(val_mrs[i]->addr)); - send_val_rkeys.push_back(val_mrs[i]->rkey); + if (has_value_cache) { + send_val_ptrs.push_back(reinterpret_cast(val_mrs[i]->addr)); + send_val_rkeys.push_back(val_mrs[i]->rkey); + } } - // Send all vectors to client if (!exchange_mr_vector(ctx, send_key_ptrs, false)) return false; if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false; - if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false; - if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false; + if (has_value_cache) { + if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false; + if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false; + } return true; } diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp index 60f06bf06db..dd623d23b53 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp @@ -78,6 +78,18 @@ RDMACommunicator::RDMACommunicator(std::string& role, throw std::runtime_error("Invalid layer number"); } + if (local_cache_value_ptr_layer_head_.empty()) { + has_value_cache_ = false; + WARN( + "Value Cache is empty (Maybe MLA Model). 
RDMA will run in Key-Only " + "mode."); + } else { + has_value_cache_ = true; + if (local_cache_value_ptr_layer_head_.size() != layer_number) { + throw std::runtime_error("Key and Value cache layer number mismatch!"); + } + } + // Step 2: Setup cache vectors and pointers resize_vectors(); assign_pointers(); @@ -100,7 +112,6 @@ RDMACommunicator::RDMACommunicator(std::string& role, }); server_thread.detach(); } - RDMACommunicator_status = 1; INFO("RDMA communicator initialized successfully"); } catch (const std::exception& e) { @@ -119,7 +130,9 @@ void RDMACommunicator::resize_vectors() { } local_cache_key_ptr_per_layer.resize(layer_number); - local_cache_value_ptr_per_layer.resize(layer_number); + if (has_value_cache_) { + local_cache_value_ptr_per_layer.resize(layer_number); + } } void RDMACommunicator::assign_pointers() { @@ -131,15 +144,19 @@ void RDMACommunicator::assign_pointers() { // Assign pointers for each layer and block for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) { // Validate layer head pointers - if (local_cache_key_ptr_layer_head_[layer_idx] == 0 || - local_cache_value_ptr_layer_head_[layer_idx] == 0) { + if (local_cache_key_ptr_layer_head_[layer_idx] == 0) { throw std::runtime_error("Invalid cache pointer for layer " + std::to_string(layer_idx)); } - - // Resize block vectors for current layer local_cache_key_ptr_per_layer[layer_idx].resize(block_number); - local_cache_value_ptr_per_layer[layer_idx].resize(block_number); + + if (has_value_cache_) { + if (local_cache_value_ptr_layer_head_[layer_idx] == 0) { + throw std::runtime_error("Invalid VALUE cache pointer for layer " + + std::to_string(layer_idx)); + } + local_cache_value_ptr_per_layer[layer_idx].resize(block_number); + } // Calculate and assign block pointers for (int block_idx = 0; block_idx < block_number; ++block_idx) { @@ -147,9 +164,12 @@ void RDMACommunicator::assign_pointers() { reinterpret_cast(local_cache_key_ptr_layer_head_[layer_idx] + block_idx * block_size_byte); - local_cache_value_ptr_per_layer[layer_idx][block_idx] = - reinterpret_cast(local_cache_value_ptr_layer_head_[layer_idx] + - block_idx * block_size_byte); + if (has_value_cache_) { + local_cache_value_ptr_per_layer[layer_idx][block_idx] = + reinterpret_cast( + local_cache_value_ptr_layer_head_[layer_idx] + + block_idx * block_size_byte); + } } } } @@ -347,7 +367,7 @@ int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) { continue; } - server_exchange_mr(ctx); + server_exchange_mr(ctx, has_value_cache_); } else { auto ctx_iter = connectionContexts.find(event_fd); if (ctx_iter == connectionContexts.end()) { @@ -435,18 +455,33 @@ bool RDMACommunicator::deregister_memory_regions(struct RdmaContext* ctx) { return false; } - for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) { - if (!write_mr_key_list.empty() && !write_mr_value_list.empty()) { - if (ibv_dereg_mr(write_mr_key_list[layer_idx])) { - ERR("Failed to deregister memory region: write_mr_key_list, layer %d", - layer_idx); + if (!write_mr_key_list.empty()) { + for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) { + if (write_mr_key_list[layer_idx]) { + if (ibv_dereg_mr(write_mr_key_list[layer_idx])) { + ERR("Failed to deregister memory region: write_mr_key_list, layer %d", + layer_idx); + } + write_mr_key_list[layer_idx] = nullptr; } - if (ibv_dereg_mr(write_mr_value_list[layer_idx])) { - ERR("Failed to deregister memory region: write_mr_value_list, layer %d", - layer_idx); + } + write_mr_key_list.clear(); + } + + if 
(!write_mr_value_list.empty()) { + for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) { + if (write_mr_value_list[layer_idx]) { + if (ibv_dereg_mr(write_mr_value_list[layer_idx])) { + ERR("Failed to deregister memory region: write_mr_value_list, layer " + "%d", + layer_idx); + } + write_mr_value_list[layer_idx] = nullptr; } } + write_mr_value_list.clear(); } + return true; } @@ -548,7 +583,7 @@ int RDMACommunicator::connect(const std::string& dst_ip, ERR("Couldn't getexchange port infodestinations"); return static_cast(ConnStatus::kError); } else { - client_exchange_mr(ctx); + client_exchange_mr(ctx, has_value_cache_); } // Allocate RDMA read and register read buffers @@ -735,15 +770,17 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) { } std::lock_guard lock(mutex_); - - if (!write_mr_key_list.empty() || !write_mr_value_list.empty()) { + if (!write_mr_key_list.empty()) { WARN("Memory regions already registered"); return true; } const size_t list_size = layer_number; write_mr_key_list.resize(list_size, nullptr); - write_mr_value_list.resize(list_size, nullptr); + + if (has_value_cache_) { + write_mr_value_list.resize(list_size, nullptr); + } const uint32_t access_flags = IBV_ACCESS_LOCAL_WRITE | @@ -753,8 +790,6 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) { for (int i = 0; i < static_cast(list_size); ++i) { void* key_ptr = reinterpret_cast(local_cache_key_ptr_layer_head_[i]); - void* val_ptr = - reinterpret_cast(local_cache_value_ptr_layer_head_[i]); size_t size = static_cast(block_size_byte) * block_number; write_mr_key_list[i] = @@ -765,13 +800,18 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) { access_flags); if (!write_mr_key_list[i]) goto fail; - write_mr_value_list[i] = - register_memory_region(ctx->pd, - val_ptr, - size, - "client_value_" + std::to_string(i), - access_flags); - if (!write_mr_value_list[i]) goto fail; + if (has_value_cache_) { + void* val_ptr = + reinterpret_cast(local_cache_value_ptr_layer_head_[i]); + + write_mr_value_list[i] = + register_memory_region(ctx->pd, + val_ptr, + size, + "client_value_" + std::to_string(i), + access_flags); + if (!write_mr_value_list[i]) goto fail; + } } return true; @@ -812,8 +852,6 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) { for (int i = 0; i < layer_number; ++i) { void* key_ptr = reinterpret_cast(local_cache_key_ptr_layer_head_[i]); - void* val_ptr = - reinterpret_cast(local_cache_value_ptr_layer_head_[i]); size_t size = static_cast(block_size_byte) * block_number; struct ibv_mr* key_mr = register_memory_region( @@ -822,21 +860,25 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) { ERR("Failed to register key MR at layer %d", i); goto fail; } + write_cache_key_server_mr_list.push_back(key_mr); - struct ibv_mr* value_mr = register_memory_region( - ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags); - if (!value_mr) { - ERR("Failed to register value MR at layer %d", i); - ibv_dereg_mr(key_mr); - goto fail; + if (has_value_cache_) { + void* val_ptr = + reinterpret_cast(local_cache_value_ptr_layer_head_[i]); + struct ibv_mr* value_mr = register_memory_region( + ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags); + if (!value_mr) { + ERR("Failed to register value MR at layer %d", i); + ibv_dereg_mr(key_mr); + goto fail; + } + write_cache_value_server_mr_list.push_back(value_mr); } - - write_cache_key_server_mr_list.push_back(key_mr); - 
write_cache_value_server_mr_list.push_back(value_mr); } ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list; ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list; + return true; fail: @@ -899,8 +941,12 @@ int RDMACommunicator::write_cache(const std::string& ip, uint32_t cache_key_rkey = ctx->conn.write_cache_key_remote_rkey_list[layer_idx]; - uint32_t cache_value_rkey = - ctx->conn.write_cache_value_remote_rkey_list[layer_idx]; + + uint32_t cache_value_rkey = 0; + if (has_value_cache_) { + cache_value_rkey = ctx->conn.write_cache_value_remote_rkey_list[layer_idx]; + } + uint32_t crc_cache_key_rkey, crc_cache_value_rkey; bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size; uint64_t offset_in_block = @@ -914,15 +960,19 @@ int RDMACommunicator::write_cache(const std::string& ip, cache_key_remote_addr[block_index] = (uint64_t( char_ptr + remote_block_ids[block_index] * total_block_size_byte + offset_in_block)); - char_ptr = static_cast( - ctx->conn.write_cache_value_remote_ptr_list[layer_idx]); - cache_value_remote_addr[block_index] = (uint64_t( - char_ptr + remote_block_ids[block_index] * total_block_size_byte + - offset_in_block)); + + if (has_value_cache_) { + char_ptr = static_cast( + ctx->conn.write_cache_value_remote_ptr_list[layer_idx]); + cache_value_remote_addr[block_index] = (uint64_t( + char_ptr + remote_block_ids[block_index] * total_block_size_byte + + offset_in_block)); + } } ctx->conn.wc_target_count = 0; - for (int i = 0; i < 2; ++i) { + int loop_count = has_value_cache_ ? 2 : 1; + for (int i = 0; i < loop_count; ++i) { bool is_key = (i == 0); uint32_t rkey = (is_key ? cache_key_rkey : cache_value_rkey); std::vector& remote_addr = @@ -1038,6 +1088,10 @@ void RDMACommunicator::prepare_write_requests( bool is_key, std::vector& remote_addr, uint32_t rkey) { + if (!is_key) { + assert(!write_mr_value_list.empty() && + "Trying to process Value Cache but it is empty!"); + } auto block_num = local_block_ids.size(); for (size_t i = 0; i < block_num; ++i) { diff --git a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py index 0548e8f84ca..f90a5d23234 100644 --- a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py +++ b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py @@ -40,11 +40,10 @@ def __init__( try: import rdma_comm except: - logger.error( + raise RuntimeError( "The installation of the RDMA library failed." "Confirm whether your network card supports RDMA transmission." 
) - return self.messager = rdma_comm.RDMACommunicator( splitwise_role, gpu_id, diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 66e950ad9fc..ec3328a6a90 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -755,8 +755,9 @@ def launch_components(self): local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) ) + ctx = multiprocessing.get_context("spawn") self.dp_processed.append( - multiprocessing.Process( + ctx.Process( target=start_data_parallel_service, args=( self.cfg, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index cda5684e604..c3437ea270f 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -205,7 +205,6 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.group_size, self.block_size, ) - # MLA metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1] metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2] @@ -279,6 +278,7 @@ def forward_extend( forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, metadata.block_tables, + metadata.kv_signal_data_list[layer.layer_id], "none", getattr(forward_meta, "max_input_length", -1), ) @@ -422,10 +422,10 @@ def forward_mixed( forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, metadata.block_tables, + metadata.kv_signal_data_list[layer.layer_id], "none", self.max_seq_len, ) - # FA fmha_out = self.flash_attn_func( q, diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py index 9d4913425af..87d3c254343 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py @@ -307,6 +307,7 @@ def forward_extend( forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, metadata.block_tables, + metadata.kv_signal_data_list[layer.layer_id], "none", getattr(forward_meta, "max_input_length", -1), ) diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 2de8cd567ae..13d0f88319d 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -258,10 +258,10 @@ def weight_loader( else: SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0} - if not param._is_initialized(): - param.initialize() if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts): return + if not param._is_initialized(): + param.initialize() weight_need_transpose = getattr(param, "weight_need_transpose", False) if shard_id is None: # 1.gate up fused in disk diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 04fa0abd09b..dd19b50f71b 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -341,6 +341,7 @@ def forward( # NOTE: (changwenbin) qkv_a_proj horizontal fusion qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states) + query, compressed_kv, key_pe = qkv_a_out.split( [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1 ) @@ -399,6 +400,7 @@ def forward( self.num_attention_heads_tp * (self.kv_lora_rank + 
self.qk_rope_head_dim), ] ) + fmha_out_decode = self.mla_attn( q=q_input, k=None, @@ -418,6 +420,7 @@ def forward( .transpose([1, 0, 2]) .reshape([-1, self.num_attention_heads_tp * self.v_head_dim]) ) + if fmha_out is None: fmha_out = fmha_out_decode else: @@ -515,6 +518,7 @@ def forward( hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) + return hidden_states, residual @@ -674,7 +678,6 @@ def load_weights(self, weights_iterator) -> None: process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config) for loaded_weight_name, loaded_weight in weights_iterator: loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model") - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in loaded_weight_name: continue @@ -741,6 +744,20 @@ def pre_process(self, forward_meta): ) return position_ids, mask_encoder_batch + def empty_input_forward(self): + """ + empty_input_forward + """ + fake_hidden_states = paddle.empty( + shape=[1, self.fd_config.model_config.hidden_size], + dtype=paddle.get_default_dtype(), + ) + for i in range( + self.fd_config.model_config.first_k_dense_replace, + self.fd_config.model_config.num_hidden_layers, + ): + self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate) + def forward( self, ids_remove_padding: paddle.Tensor, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index e4ed0de57e8..7965fcbb8e7 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2328,7 +2328,6 @@ class at the server level, which is too granular for ModelRunner. self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, group=self.parallel_config.tp_group, ) - # 5. Post Process model_output_data = ModelOutputData( next_tokens=self.share_inputs["next_tokens"],