From ae15022f485c51e0e86b54e4bc8c0a0435584397 Mon Sep 17 00:00:00 2001 From: DragonFive <1690302963@qq.com> Date: Fri, 31 Oct 2025 17:56:29 +0800 Subject: [PATCH] feat: add rec framework. --- CMakeLists.txt | 4 +- xllm/api_service/api_service.cpp | 113 +- xllm/api_service/api_service.h | 1 + xllm/api_service/completion_service_impl.cpp | 219 +- xllm/api_service/completion_service_impl.h | 16 + xllm/core/common/global_flags.cpp | 6 + xllm/core/common/global_flags.h | 2 + xllm/core/common/metrics.cpp | 14 + xllm/core/common/metrics.h | 8 + xllm/core/common/types.h | 5 + xllm/core/framework/batch/CMakeLists.txt | 2 + xllm/core/framework/batch/batch.cpp | 73 +- xllm/core/framework/batch/batch.h | 18 +- xllm/core/framework/batch/batch_factory.cpp | 75 + xllm/core/framework/batch/batch_factory.h | 14 + .../framework/batch/batch_input_builder.h | 1 + .../batch/rec_batch_input_builder.cpp | 996 +++++++++ .../framework/batch/rec_batch_input_builder.h | 134 ++ xllm/core/framework/model/model_args.h | 13 +- .../core/framework/model/model_input_params.h | 147 +- .../framework/prefix_cache/prefix_cache.cpp | 29 +- .../framework/prefix_cache/prefix_cache.h | 4 - .../prefix_cache/prefix_cache_test.cpp | 1 + xllm/core/framework/request/request.cpp | 2 + xllm/core/framework/request/request_output.h | 3 + xllm/core/framework/request/request_state.h | 6 + xllm/core/framework/request/sequence.cpp | 141 +- xllm/core/framework/request/sequence.h | 34 + .../framework/request/sequences_group.cpp | 10 +- xllm/core/framework/request/sequences_group.h | 7 + xllm/core/framework/sampling/CMakeLists.txt | 5 + .../framework/sampling/sampling_params.cpp | 10 +- .../framework/sampling/valid_path_filter.cpp | 269 +++ .../framework/sampling/valid_path_filter.h | 65 + .../sampling/valid_path_filter_test.cpp | 167 ++ xllm/core/layers/npu/CMakeLists.txt | 2 + .../npu/npu_onerec_block_layer_impl.cpp | 1880 +++++++++++++++++ .../layers/npu/npu_onerec_block_layer_impl.h | 167 ++ xllm/core/layers/onerec_block_layer.h | 42 + xllm/core/runtime/CMakeLists.txt | 6 + xllm/core/runtime/forward_params.h | 5 + xllm/core/runtime/llm_worker_impl.cpp | 16 - xllm/core/runtime/master.cpp | 37 + xllm/core/runtime/rec_engine.cpp | 341 +++ xllm/core/runtime/rec_engine.h | 80 + xllm/core/runtime/rec_master.cpp | 268 +++ xllm/core/runtime/rec_master.h | 71 + xllm/core/runtime/rec_worker_impl.cpp | 363 ++++ xllm/core/runtime/rec_worker_impl.h | 76 + xllm/core/runtime/worker.cpp | 3 + xllm/core/scheduler/CMakeLists.txt | 2 + xllm/core/scheduler/fixsteps_scheduler.cpp | 309 +++ xllm/core/scheduler/fixsteps_scheduler.h | 62 + xllm/core/scheduler/scheduler_factory.cpp | 7 + xllm/core/scheduler/scheduler_factory.h | 5 + xllm/core/util/CMakeLists.txt | 8 +- xllm/core/util/env_var.cpp | 3 + xllm/core/util/env_var.h | 3 + xllm/core/util/hash_util.cpp | 55 + xllm/core/util/hash_util.h | 6 + xllm/core/util/tensor_helper.h | 52 + xllm/core/util/utils.cpp | 98 + xllm/core/util/utils.h | 4 + xllm/models/model_registry.cpp | 1 + xllm/models/rec/onerec.h | 1054 +++++++++ xllm/proto/CMakeLists.txt | 1 + xllm/proto/completion.proto | 7 + xllm/proto/rec.proto | 119 ++ 68 files changed, 7579 insertions(+), 188 deletions(-) create mode 100644 xllm/core/framework/batch/rec_batch_input_builder.cpp create mode 100644 xllm/core/framework/batch/rec_batch_input_builder.h create mode 100644 xllm/core/framework/sampling/valid_path_filter.cpp create mode 100644 xllm/core/framework/sampling/valid_path_filter.h create mode 100644 
xllm/core/framework/sampling/valid_path_filter_test.cpp create mode 100644 xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp create mode 100644 xllm/core/layers/npu/npu_onerec_block_layer_impl.h create mode 100644 xllm/core/layers/onerec_block_layer.h create mode 100644 xllm/core/runtime/rec_engine.cpp create mode 100644 xllm/core/runtime/rec_engine.h create mode 100644 xllm/core/runtime/rec_master.cpp create mode 100644 xllm/core/runtime/rec_master.h create mode 100644 xllm/core/runtime/rec_worker_impl.cpp create mode 100644 xllm/core/runtime/rec_worker_impl.h create mode 100644 xllm/core/scheduler/fixsteps_scheduler.cpp create mode 100644 xllm/core/scheduler/fixsteps_scheduler.h create mode 100644 xllm/core/util/hash_util.cpp create mode 100644 xllm/models/rec/onerec.h create mode 100644 xllm/proto/rec.proto diff --git a/CMakeLists.txt b/CMakeLists.txt index 645ce0a2..80c3c900 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,8 +27,8 @@ if(USE_NPU) if(INSTALL_XLLM_KERNELS) if(DEVICE_TYPE STREQUAL "USE_A3") message("downloading a3 arm xllm kernels") - file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.2-Linux.a3.arm.rpm" + file(DOWNLOAD + "https://9n-online-service.s3-internal.cn-north-1.jdcloud-oss.com/9n-xllm-atb/xllm_kernels-1.3.2-Linux.rpm" "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" ) else() diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index 204b9f90..ed84221a 100755 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -27,6 +27,7 @@ limitations under the License. #include "core/common/metrics.h" #include "core/runtime/dit_master.h" #include "core/runtime/llm_master.h" +#include "core/runtime/rec_master.h" #include "core/runtime/vlm_master.h" #include "core/util/closure_guard.h" #include "embedding.pb.h" @@ -68,6 +69,9 @@ APIService::APIService(Master* master, image_generation_service_impl_ = std::make_unique( dynamic_cast(master), model_names); + } else if (FLAGS_backend == "rec") { + rec_completion_service_impl_ = std::make_unique( + dynamic_cast(master), model_names); } models_service_impl_ = ServiceImplFactory::create_service_impl( @@ -78,13 +82,6 @@ void APIService::Completions(::google::protobuf::RpcController* controller, const proto::CompletionRequest* request, proto::CompletionResponse* response, ::google::protobuf::Closure* done) { - // TODO with xllm-service -} - -void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, - const proto::HttpRequest* request, - proto::HttpResponse* response, - ::google::protobuf::Closure* done) { xllm::ClosureGuard done_guard( done, std::bind(request_in_metric, nullptr), @@ -93,47 +90,38 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, LOG(ERROR) << "brpc request | respose | controller is null"; return; } - - auto arena = response->GetArena(); - auto req_pb = - google::protobuf::Arena::CreateMessage(arena); - auto resp_pb = - google::protobuf::Arena::CreateMessage(arena); - auto ctrl = reinterpret_cast(controller); - std::string error; - json2pb::Json2PbOptions options; - butil::IOBuf& buf = ctrl->request_attachment(); - butil::IOBufAsZeroCopyInputStream iobuf_stream(buf); - auto st = json2pb::JsonToProtoMessage(&iobuf_stream, req_pb, options, &error); - if (!st) { - ctrl->SetFailed(error); - LOG(ERROR) << "parse json to proto failed: " << error; - return; - } - std::shared_ptr call = std::make_shared( - ctrl, done_guard.release(), req_pb, resp_pb); - 
completion_service_impl_->process_async(call); -} - -void APIService::ChatCompletions(::google::protobuf::RpcController* controller, - const proto::ChatRequest* request, - proto::ChatResponse* response, - ::google::protobuf::Closure* done) { - // TODO with xllm-service + if (FLAGS_backend == "llm") { + CHECK(completion_service_impl_) << " completion service is invalid."; + std::shared_ptr call = std::make_shared( + ctrl, + done_guard.release(), + const_cast(request), + response); + completion_service_impl_->process_async(call); + } else if (FLAGS_backend == "rec") { + CHECK(rec_completion_service_impl_) + << " rec completion service is invalid."; + std::shared_ptr call = std::make_shared( + ctrl, + done_guard.release(), + const_cast(request), + response); + rec_completion_service_impl_->process_async(call); + } } namespace { -template -void ChatCompletionsImpl(std::unique_ptr& service, - xllm::ClosureGuard& guard, - ::google::protobuf::Arena* arena, - brpc::Controller* ctrl) { +template +void CommonCompletionsImpl(std::unique_ptr& service, + xllm::ClosureGuard& guard, + ::google::protobuf::Arena* arena, + brpc::Controller* ctrl) { auto req_pb = - google::protobuf::Arena::CreateMessage(arena); + google::protobuf::Arena::CreateMessage(arena); auto resp_pb = - google::protobuf::Arena::CreateMessage(arena); + google::protobuf::Arena::CreateMessage(arena); std::string error; json2pb::Json2PbOptions options; @@ -146,12 +134,46 @@ void ChatCompletionsImpl(std::unique_ptr& service, return; } - auto call = std::make_shared( - ctrl, guard.release(), req_pb, resp_pb, arena != nullptr /*use_arena*/); + auto call = std::make_shared(ctrl, guard.release(), req_pb, resp_pb); service->process_async(call); } } // namespace +void APIService::CompletionsHttp(::google::protobuf::RpcController* controller, + const proto::HttpRequest* request, + proto::HttpResponse* response, + ::google::protobuf::Closure* done) { + xllm::ClosureGuard done_guard( + done, + std::bind(request_in_metric, nullptr), + std::bind(request_out_metric, (void*)controller)); + if (!request || !response || !controller) { + LOG(ERROR) << "brpc request | respose | controller is null"; + return; + } + + auto arena = response->GetArena(); + auto ctrl = reinterpret_cast(controller); + + if (FLAGS_backend == "llm") { + CHECK(completion_service_impl_) << " completion service is invalid."; + CommonCompletionsImpl( + completion_service_impl_, done_guard, arena, ctrl); + } else if (FLAGS_backend == "rec") { + CHECK(rec_completion_service_impl_) + << " rec completion service is invalid."; + CommonCompletionsImpl( + rec_completion_service_impl_, done_guard, arena, ctrl); + } +} + +void APIService::ChatCompletions(::google::protobuf::RpcController* controller, + const proto::ChatRequest* request, + proto::ChatResponse* response, + ::google::protobuf::Closure* done) { + // TODO with xllm-service +} + void APIService::ChatCompletionsHttp( ::google::protobuf::RpcController* controller, const proto::HttpRequest* request, @@ -171,12 +193,11 @@ void APIService::ChatCompletionsHttp( if (FLAGS_backend == "llm") { auto arena = response->GetArena(); CHECK(chat_service_impl_) << " chat service is invalid."; - ChatCompletionsImpl( + CommonCompletionsImpl( chat_service_impl_, done_guard, arena, ctrl); } else if (FLAGS_backend == "vlm") { CHECK(mm_chat_service_impl_) << " mm chat service is invalid."; - // TODO: fix me - temporarily using heap allocation instead of arena - ChatCompletionsImpl( + CommonCompletionsImpl( mm_chat_service_impl_, done_guard, nullptr, ctrl); 
} } diff --git a/xllm/api_service/api_service.h b/xllm/api_service/api_service.h index 72c0e451..71ef1a87 100644 --- a/xllm/api_service/api_service.h +++ b/xllm/api_service/api_service.h @@ -123,6 +123,7 @@ class APIService : public proto::XllmAPIService { std::unique_ptr models_service_impl_; std::unique_ptr image_generation_service_impl_; std::unique_ptr rerank_service_impl_; + std::unique_ptr rec_completion_service_impl_; }; } // namespace xllm diff --git a/xllm/api_service/completion_service_impl.cpp b/xllm/api_service/completion_service_impl.cpp index d364074b..fc84b642 100644 --- a/xllm/api_service/completion_service_impl.cpp +++ b/xllm/api_service/completion_service_impl.cpp @@ -26,8 +26,10 @@ limitations under the License. #include "common/instance_name.h" #include "completion.pb.h" +#include "core/framework/request/mm_data.h" #include "core/framework/request/request_output.h" #include "core/runtime/llm_master.h" +#include "core/runtime/rec_master.h" #include "core/util/utils.h" #define likely(x) __builtin_expect(!!(x), 1) @@ -126,6 +128,7 @@ bool send_result_to_client_brpc(std::shared_ptr call, response.set_created(created_time); response.set_model(model); + // add choices into response response.mutable_choices()->Reserve(req_output.outputs.size()); for (const auto& output : req_output.outputs) { auto* choice = response.add_choices(); @@ -137,6 +140,7 @@ bool send_result_to_client_brpc(std::shared_ptr call, } } + // add usage statistics if (req_output.usage.has_value()) { const auto& usage = req_output.usage.value(); auto* proto_usage = response.mutable_usage(); @@ -147,35 +151,68 @@ bool send_result_to_client_brpc(std::shared_ptr call, proto_usage->set_total_tokens(static_cast(usage.num_total_tokens)); } - return call->write_and_finish(response); -} + if (FLAGS_backend == "rec") { + auto output_tensor = response.mutable_output_tensors()->Add(); + output_tensor->set_name("omnirec_result"); + // TODO: replace true with flags after converter merge + if (FLAGS_enable_constrained_decoding) { + output_tensor->set_datatype(proto::DataType::INT64); + output_tensor->mutable_shape()->Add(req_output.outputs.size()); + output_tensor->mutable_shape()->Add(1); // Single item per output -} // namespace + auto context = output_tensor->mutable_contents(); + for (int i = 0; i < req_output.outputs.size(); ++i) { + if (req_output.outputs[i].item_ids.has_value()) { + context->mutable_int64_contents()->Add( + req_output.outputs[i].item_ids.value()); + } + } + } else { + output_tensor->set_datatype(proto::DataType::INT32); -CompletionServiceImpl::CompletionServiceImpl( - LLMMaster* master, - const std::vector& models) - : APIServiceImpl(models), master_(master) { - CHECK(master_ != nullptr); + output_tensor->mutable_shape()->Add(req_output.outputs.size()); + output_tensor->mutable_shape()->Add( + req_output.outputs[0].token_ids.size()); + + auto context = output_tensor->mutable_contents(); + for (int i = 0; i < req_output.outputs.size(); ++i) { + // LOG(INFO) << req_output.outputs[i].token_ids; + context->mutable_int_contents()->Add( + req_output.outputs[i].token_ids.begin(), + req_output.outputs[i].token_ids.end()); + } + } + } + + return call->write_and_finish(response); } -// complete_async for brpc -void CompletionServiceImpl::process_async_impl( - std::shared_ptr call) { +// Type alias for the return type of process_completion_request_params +using ProcessCompletionResult = + std::optional>, + bool, + std::string>>; +// Common function to process request parameters and validation 
+ProcessCompletionResult process_completion_request_params( + std::shared_ptr call, + const absl::flat_hash_set& models, + xllm::RateLimiter* rate_limiter) { const auto& rpc_request = call->request(); + // check if model is supported const auto& model = rpc_request.model(); - if (unlikely(!models_.contains(model))) { + if (unlikely(!models.contains(model))) { call->finish_with_error(StatusCode::UNKNOWN, "Model not supported"); - return; + return std::nullopt; } // Check if the request is being rate-limited. - if (unlikely(master_->get_rate_limiter()->is_limited())) { + if (unlikely(rate_limiter->is_limited())) { call->finish_with_error( StatusCode::RESOURCE_EXHAUSTED, "The number of concurrent requests has reached the limit."); - return; + return std::nullopt; } RequestParams request_params( @@ -196,44 +233,126 @@ void CompletionServiceImpl::process_async_impl( request_params.decode_address = rpc_request.routing().decode_name(); } + return std::make_tuple(std::move(request_params), + std::move(prompt_tokens), + include_usage, + model); +} + +// Common callback function for handling request output +auto request_callback(std::shared_ptr call, + const std::string& model, + Master* master, + bool stream, + bool include_usage, + const std::string& request_id, + int64_t created_time) { + return [call, model, master, stream, include_usage, request_id, created_time]( + const RequestOutput& req_output) -> bool { + if (req_output.status.has_value()) { + const auto& status = req_output.status.value(); + if (!status.ok()) { + // Reduce the number of concurrent requests when a request is + // finished with error. + master->get_rate_limiter()->decrease_one_request(); + + return call->finish_with_error(status.code(), status.message()); + } + } + + // Reduce the number of concurrent requests when a request is finished + // or canceled. + if (req_output.finished || req_output.cancelled) { + master->get_rate_limiter()->decrease_one_request(); + } + + if (stream) { + return send_delta_to_client_brpc( + call, include_usage, request_id, created_time, model, req_output); + } + return send_result_to_client_brpc( + call, request_id, created_time, model, req_output); + }; +} + +} // namespace + +CompletionServiceImpl::CompletionServiceImpl( + LLMMaster* master, + const std::vector& models) + : APIServiceImpl(models), master_(master) { + CHECK(master_ != nullptr); +} + +// complete_async for brpc +void CompletionServiceImpl::process_async_impl( + std::shared_ptr call) { + auto result = process_completion_request_params( + call, models_, master_->get_rate_limiter()); + if (!result.has_value()) { + return; // Error already handled in process_completion_request_params + } + + auto [request_params, prompt_tokens, include_usage, model] = + std::move(result.value()); // schedule the request - master_->handle_request( - std::move(rpc_request.prompt()), - std::move(prompt_tokens), - std::move(request_params), - call.get(), - [call, - model, - master = master_, - stream = request_params.streaming, - include_usage = include_usage, - request_id = request_params.request_id, - created_time = absl::ToUnixSeconds(absl::Now())]( - const RequestOutput& req_output) -> bool { - if (req_output.status.has_value()) { - const auto& status = req_output.status.value(); - if (!status.ok()) { - // Reduce the number of concurrent requests when a request is - // finished with error. 
- master->get_rate_limiter()->decrease_one_request(); - - return call->finish_with_error(status.code(), status.message()); - } - } + master_->handle_request(std::move(call->request().prompt()), + std::move(prompt_tokens), + std::move(request_params), + call.get(), + request_callback(call, + model, + master_, + request_params.streaming, + include_usage, + request_params.request_id, + absl::ToUnixSeconds(absl::Now()))); +} - // Reduce the number of concurrent requests when a request is finished - // or canceled. - if (req_output.finished || req_output.cancelled) { - master->get_rate_limiter()->decrease_one_request(); - } +RecCompletionServiceImpl::RecCompletionServiceImpl( + RecMaster* master, + const std::vector& models) + : APIServiceImpl(models), master_(master) { + CHECK(master_ != nullptr); +} - if (stream) { - return send_delta_to_client_brpc( - call, include_usage, request_id, created_time, model, req_output); - } - return send_result_to_client_brpc( - call, request_id, created_time, model, req_output); - }); +void RecCompletionServiceImpl::process_async_impl( + std::shared_ptr call) { + auto result = process_completion_request_params( + call, models_, master_->get_rate_limiter()); + if (!result.has_value()) { + return; // Error already handled in process_completion_request_params + } + + auto [request_params, prompt_tokens, include_usage, model] = + std::move(result.value()); + const auto& rpc_request = call->request(); + std::optional mm_data = std::nullopt; + if (rpc_request.input_tensors_size()) { + // HISTOGRAM_OBSERVE(rec_input_first_dim, + // rpc_request.input_tensors(0).shape(0)); + + MMDict mm_dict; + for (int i = 0; i < rpc_request.input_tensors_size(); ++i) { + const auto& tensor = rpc_request.input_tensors(i); + mm_dict[tensor.name()] = + xllm::util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16); + } + mm_data = std::move(MMData(MMType::EMBEDDING, mm_dict)); + } + + // schedule the request + master_->handle_request(std::move(rpc_request.prompt()), + std::move(prompt_tokens), + std::move(mm_data), + std::move(request_params), + request_callback(call, + model, + master_, + request_params.streaming, + include_usage, + request_params.request_id, + absl::ToUnixSeconds(absl::Now()))); } } // namespace xllm diff --git a/xllm/api_service/completion_service_impl.h b/xllm/api_service/completion_service_impl.h index 5fdc74f5..12b823a9 100644 --- a/xllm/api_service/completion_service_impl.h +++ b/xllm/api_service/completion_service_impl.h @@ -20,6 +20,7 @@ limitations under the License. #include "api_service_impl.h" #include "completion.pb.h" +#include "rec.pb.h" #include "stream_call.h" namespace xllm { @@ -41,4 +42,19 @@ class CompletionServiceImpl final : public APIServiceImpl { LLMMaster* master_ = nullptr; }; +class RecMaster; +// a class to handle completion requests +class RecCompletionServiceImpl final : public APIServiceImpl { + public: + RecCompletionServiceImpl(RecMaster* master, + const std::vector& models); + + // brpc call_data needs to use shared_ptr + void process_async_impl(std::shared_ptr call); + + private: + DISALLOW_COPY_AND_ASSIGN(RecCompletionServiceImpl); + RecMaster* master_ = nullptr; +}; + } // namespace xllm diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index d0ef1e0e..48afdd58 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -396,3 +396,9 @@ DEFINE_bool( "Whether to enable prefetch weight,only applicable to Qwen3-dense model." 
"The default prefetching ratio for gateup weight is 40%." "If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5"); + +// rec prefill-only mode +DEFINE_bool(enable_rec_prefill_only, + false, + "Enable rec prefill-only mode (no decoder self-attention blocks " + "allocation)"); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 7fc36442..e7287e0c 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -204,3 +204,5 @@ DECLARE_string(reasoning_parser); DECLARE_bool(enable_shm); DECLARE_bool(enable_prefetch_weight); + +DECLARE_bool(enable_rec_prefill_only); diff --git a/xllm/core/common/metrics.cpp b/xllm/core/common/metrics.cpp index 5f792f2b..3af87f54 100644 --- a/xllm/core/common/metrics.cpp +++ b/xllm/core/common/metrics.cpp @@ -180,6 +180,20 @@ DEFINE_COUNTER(proto_latency_seconds_o2proto, // engine metrics DEFINE_COUNTER(prepare_input_latency_seconds, "Latency of preparing input in seconds"); +DEFINE_COUNTER(prepare_input_latency_microseconds, + "Latency of preparing input in microseconds"); + +// rec engine metrics +DEFINE_COUNTER(rec_first_token_latency_microseconds, + "Latency of rec first token generation in microseconds"); +DEFINE_COUNTER(rec_second_token_latency_microseconds, + "Latency of rec second token generation in microseconds"); +DEFINE_COUNTER(rec_third_token_latency_microseconds, + "Latency of rec third token generation in microseconds"); +DEFINE_COUNTER(rec_sampling_latency_microseconds, + "Latency of rec sampling in microseconds"); +DEFINE_HISTOGRAM(expand_beam_latency_microseconds, + "Histogram of expand beam latency in microseconds"); // multi node metrics DEFINE_COUNTER(worker_service_latency_seconds, diff --git a/xllm/core/common/metrics.h b/xllm/core/common/metrics.h index 48663341..82c9f231 100644 --- a/xllm/core/common/metrics.h +++ b/xllm/core/common/metrics.h @@ -205,6 +205,14 @@ DECLARE_COUNTER(proto_latency_seconds_o2proto); // engine metrics DECLARE_COUNTER(prepare_input_latency_seconds); +// rec engine metrics +DECLARE_COUNTER(prepare_input_latency_microseconds); +DECLARE_COUNTER(rec_first_token_latency_microseconds); +DECLARE_COUNTER(rec_second_token_latency_microseconds); +DECLARE_COUNTER(rec_third_token_latency_microseconds); +DECLARE_COUNTER(rec_sampling_latency_microseconds); +DECLARE_HISTOGRAM(expand_beam_latency_microseconds); + // multi node metrics DECLARE_COUNTER(worker_service_latency_seconds); DECLARE_COUNTER(engine_latency_seconds); diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h index da969dbf..4a98c30a 100644 --- a/xllm/core/common/types.h +++ b/xllm/core/common/types.h @@ -31,6 +31,7 @@ class EngineType { SSM = 1, VLM = 2, DIT = 3, + REC = 4, INVALID = -1, }; @@ -44,6 +45,8 @@ class EngineType { value_ = VLM; } else if (str == "DIT") { value_ = DIT; + } else if (str == "REC") { + value_ = REC; } else { value_ = INVALID; } @@ -68,6 +71,8 @@ class EngineType { return "VLM"; } else if (this->value_ == DIT) { return "DIT"; + } else if (this->value_ == REC) { + return "REC"; } else { return "INVALID"; } diff --git a/xllm/core/framework/batch/CMakeLists.txt b/xllm/core/framework/batch/CMakeLists.txt index 94d20240..9676e906 100644 --- a/xllm/core/framework/batch/CMakeLists.txt +++ b/xllm/core/framework/batch/CMakeLists.txt @@ -10,12 +10,14 @@ cc_library( batch.h batch_factory.h batch_input_builder.h + rec_batch_input_builder.h mposition.h SRCS dit_batch.cpp batch.cpp batch_factory.cpp batch_input_builder.cpp + rec_batch_input_builder.cpp mposition.cpp 
beam_search.h DEPS diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index d2de8049..d5b6ca8e 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -29,6 +29,7 @@ limitations under the License. #include "framework/model/model_input_params.h" #include "framework/request/sequence.h" #include "framework/sampling/sampling_params.h" +#include "rec_batch_input_builder.h" #include "runtime/params_utils.h" #include "util/slice.h" #include "util/tensor_helper.h" @@ -68,7 +69,12 @@ void Batch::add(const std::vector& sequences) { ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size, - const ModelArgs& args) { + const ModelArgs& args, + ThreadPool* thread_pool) { + if (FLAGS_backend == "rec") { + return prepare_rec_forward_input( + num_decoding_tokens, min_decoding_batch_size, args, thread_pool); + } BatchInputBuilder builder(sequences_, allowed_max_tokens_, input_embeddings_vec_, @@ -81,6 +87,58 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens, min_decoding_batch_size); } +ForwardInput Batch::prepare_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size, + const ModelArgs& args, + ThreadPool* thread_pool) { + // Convert SequencesGroup* to std::unique_ptr for + // compatibility + std::vector> sequence_groups_ptrs; + for (auto* group : sequence_groups_) { + // Note: This is a temporary workaround. In production, we should avoid this + // conversion and modify the interface to work with raw pointers directly. + sequence_groups_ptrs.emplace_back(std::unique_ptr(group)); + } + + RecBatchInputBuilder builder( + sequence_groups_ptrs, + allowed_max_tokens_, + input_embeddings_vec_, + mm_data_vec_, + copy_in_cache_block_infos_, + copy_out_cache_block_infos_, + swap_cache_block_infos_, + &args, + thread_pool); // Temporarily not using thread pool + + auto result = builder.build_rec_forward_input(num_decoding_tokens, + min_decoding_batch_size); + + // Release the unique_ptrs without deleting the objects + for (auto& ptr : sequence_groups_ptrs) { + ptr.release(); + } + + return result; +} + +std::vector Batch::get_sequences() const { + // If sequences_ is not empty, return it directly + if (!sequences_.empty()) { + return sequences_; + } + + // Otherwise, extract sequences from sequence_groups_ + std::vector result; + for (const auto* seq_group : sequence_groups_) { + const auto& sequences = seq_group->sequences(); + for (const auto& seq_ptr : sequences) { + result.push_back(seq_ptr.get()); + } + } + return result; +} + RawForwardInput Batch::prepare_forward_input(uint32_t start_idx, uint32_t end_idx, ThreadPool* thread_pool) { @@ -149,7 +207,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output, if (sample_output.embeddings.defined()) { const int64_t num_seqs = sample_output.embeddings.size(0); int64_t output_idx = 0; - for (auto* seq : sequences_) { + const auto& sequences = get_sequences(); + for (auto* seq : sequences) { CHECK_LT(output_idx, num_seqs); auto cur_seq_embed = safe_to(sample_output.embeddings[output_idx++], torch::kFloat32); @@ -162,7 +221,8 @@ void Batch::process_sample_output(const SampleOutput& sample_output, // this means all sequences are in prefill stage status. 
const int64_t num_seqs = sample_output.next_tokens.size(0); int64_t output_idx = 0; - for (auto* seq : sequences_) { + const auto& sequences = get_sequences(); + for (auto* seq : sequences) { if (seq->finished()) { output_idx++; continue; @@ -338,4 +398,11 @@ void Batch::process_beam_search_output(const RawForwardOutput& raw_output, update_for_sequence_group(sequence_group_id); } } + +void Batch::finish() { + // Finish all sequence groups + for (auto* sequence_group : sequence_groups_) { + sequence_group->finish(); + } +} } // namespace xllm diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h index f862b305..c3833bcc 100644 --- a/xllm/core/framework/batch/batch.h +++ b/xllm/core/framework/batch/batch.h @@ -65,14 +65,15 @@ class Batch { // get the number of sequences in the batch size_t size() const { return sequences_.size(); } - bool empty() const { return sequences_.empty(); } + bool empty() const { return sequences_.empty() && sequence_groups_.empty(); } Sequence* operator[](size_t i) { return sequences_[i]; } // prepare forward inputs ForwardInput prepare_forward_input(uint32_t num_decoding_tokens, uint32_t min_decoding_bach_size, - const ModelArgs& args); + const ModelArgs& args, + ThreadPool* thread_pool = nullptr); // Convert Batch to pb type, which will be pass to remote worker. RawForwardInput prepare_forward_input(uint32_t start_idx, @@ -111,6 +112,19 @@ class Batch { bool get_batch_prefill_status() const { return all_seqs_in_prefill_; } + void finish(); + + // prepare forward inputs for Rec model + ForwardInput prepare_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size, + const ModelArgs& args, + ThreadPool* thread_pool = nullptr); + + protected: + // Get sequences for iteration - returns sequences_ if not empty, + // otherwise extracts sequences from sequence_groups_ + std::vector get_sequences() const; + private: bool update_sequence_state(Sequence* seq, bool replace_fake_token); diff --git a/xllm/core/framework/batch/batch_factory.cpp b/xllm/core/framework/batch/batch_factory.cpp index 5dd9d428..97ba94de 100644 --- a/xllm/core/framework/batch/batch_factory.cpp +++ b/xllm/core/framework/batch/batch_factory.cpp @@ -106,4 +106,79 @@ std::vector BatchFactory::create_batches( return batches; } +std::vector BatchFactory::create_rec_batches( + const std::vector>& running_requests, + const std::vector& running_sequences, + const std::vector& running_sequences_budgets, + std::vector>* copy_in_cache_block_infos, + std::vector>* copy_out_cache_block_infos, + std::vector>* swap_cache_block_infos) { + size_t num_prompt_tokens = 0; + size_t num_generated_tokens = 0; + std::vector batches(dp_size_); + for (size_t i = 0; i < running_sequences.size(); ++i) { + auto* sequence = running_sequences[i]; + const size_t token_budget = running_sequences_budgets[i]; + + const size_t remaining_prompt_tokens = + sequence->num_prompt_tokens() > + sequence->kv_state().kv_cache_tokens_num() + ? 
sequence->num_prompt_tokens() - + sequence->kv_state().kv_cache_tokens_num() + : 0; + const size_t prompt_tokens = + std::min(remaining_prompt_tokens, token_budget); + const size_t generated_tokens = token_budget - prompt_tokens; + num_prompt_tokens += prompt_tokens; + num_generated_tokens += generated_tokens; + + // if dp enabled, each sequence is required to + // dispatch to the same rank in the whole lifetime + // batches[sequence->dp_rank()].add(sequence, token_budget); + if (!((sequence->stage() == SequenceStage::DECODE) && + (sequence->kv_state().kv_cache_tokens_num() > 0))) { + batches[sequence->dp_rank()].set_batch_prefill_status(true); + } + } + // for rec, only use seq_group to prepare_input. + for (const auto& request : running_requests) { + auto seq_group = request->sequence_group(); + int32_t dp_rank = seq_group->dp_rank(); + batches[dp_rank].add(seq_group); + } + + for (int i = 0; i < dp_size_; i++) { + if (!batches[i].empty()) { + if (copy_in_cache_block_infos != nullptr && + copy_in_cache_block_infos->size() == dp_size_) { + batches[i].set_copy_in_cache_block_infos( + &(copy_in_cache_block_infos->at(i))); + } + if (copy_out_cache_block_infos != nullptr && + copy_out_cache_block_infos->size() == dp_size_) { + batches[i].set_copy_out_cache_block_infos( + &(copy_out_cache_block_infos->at(i))); + } + if (swap_cache_block_infos != nullptr && + swap_cache_block_infos->size() == dp_size_) { + batches[i].set_swap_cache_block_infos(&(swap_cache_block_infos->at(i))); + } + } + } + + COUNTER_ADD(num_processing_tokens_total_prompt, num_prompt_tokens); + COUNTER_ADD(num_processing_tokens_total_generated, num_generated_tokens); + + if (running_sequences.size() > 0) { + HISTOGRAM_OBSERVE( + num_prompt_tokens_per_request, + static_cast(num_prompt_tokens / running_sequences.size())); + HISTOGRAM_OBSERVE( + num_generated_tokens_per_request, + static_cast(num_generated_tokens / running_sequences.size())); + } + + return batches; +} + } // namespace xllm diff --git a/xllm/core/framework/batch/batch_factory.h b/xllm/core/framework/batch/batch_factory.h index 44106771..10627400 100644 --- a/xllm/core/framework/batch/batch_factory.h +++ b/xllm/core/framework/batch/batch_factory.h @@ -41,6 +41,20 @@ class BatchFactory { std::vector>* swap_cache_block_infos = nullptr); + std::vector create_rec_batches( + const std::vector>& running_requests, + const std::vector& running_sequences, + const std::vector& running_sequences_budgets, + // for global kv cache copy block from host to device + std::vector>* copy_in_cache_block_infos = + nullptr, + // for global kv cache copy block from device to host + std::vector>* copy_out_cache_block_infos = + nullptr, + // for beam-search + std::vector>* swap_cache_block_infos = + nullptr); + private: BatchFactory(int32_t dp_size) : dp_size_(dp_size) {} ~BatchFactory() = default; diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h index 9b76bfb1..c8171731 100644 --- a/xllm/core/framework/batch/batch_input_builder.h +++ b/xllm/core/framework/batch/batch_input_builder.h @@ -137,6 +137,7 @@ class BatchInputBuilder { uint32_t q_seq_len, BuilderState* state_ptr = nullptr); + protected: // Input data const std::vector& sequences_; const std::vector& allowed_max_tokens_; diff --git a/xllm/core/framework/batch/rec_batch_input_builder.cpp b/xllm/core/framework/batch/rec_batch_input_builder.cpp new file mode 100644 index 00000000..23d8a963 --- /dev/null +++ b/xllm/core/framework/batch/rec_batch_input_builder.cpp @@ -0,0 
+1,996 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rec_batch_input_builder.h" + +#include +#include +#include +#include +#include +#include + +#include "framework/model/model_args.h" +#include "framework/model/model_input_params.h" +#include "framework/request/sequence.h" +#include "framework/sampling/sampling_params.h" +#include "util/tensor_helper.h" +#include "util/threadpool.h" +#include "util/utils.h" + +namespace xllm { + +// Static member definition +RecBatchInputBuilder::HighPerformanceCache RecBatchInputBuilder::perf_cache_; + +RecBatchInputBuilder::RecBatchInputBuilder( + const std::vector>& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + const std::vector* copy_in_cache_block_infos, + const std::vector* copy_out_cache_block_infos, + std::vector* swap_cache_block_infos, + const ModelArgs* args, + ThreadPool* thread_pool) + : BatchInputBuilder( // extract_sequences_from_groups(sequence_groups), + {}, + allowed_max_tokens, + input_embeddings_vec, + mm_data_vec, + copy_in_cache_block_infos, + copy_out_cache_block_infos, + swap_cache_block_infos, + args, + thread_pool), + sequence_groups_(sequence_groups) { + // Reset high performance cache + perf_cache_.memory_pool.reset(); +} + +std::vector RecBatchInputBuilder::extract_sequences_from_groups( + const std::vector>& sequence_groups) { + std::vector sequences; + for (const auto& group : sequence_groups) { + for (const auto& seq : group->sequences()) { + sequences.push_back(seq.get()); + } + } + return sequences; +} + +ForwardInput RecBatchInputBuilder::build_rec_forward_input( + uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size) { + // ========== Global constant cache ========== + static const std::vector FIXED_POSITIONS = {0}; + static const torch::Tensor FIXED_ENCODER_POSITIONS = + torch::tensor({0}, torch::kInt); + + // ========== Fast sequence information extraction ========== + const int32_t num_sequences = + !sequence_groups_.empty() + ? 
std::accumulate(sequence_groups_.begin(), + sequence_groups_.end(), + 0, + [](int sum, const auto& group) { + return sum + group->sequences().size(); + }) + : 0; + const int32_t THREADPOOL_THRESHOLD = 16; + if (UNLIKELY(num_sequences == 0)) { + return ForwardInput{}; + } + + // Get basic information of first sequence - optimize pointer access + Sequence* first_sequence = nullptr; + if (!sequence_groups_.empty() && !sequence_groups_[0]->sequences().empty()) { + first_sequence = sequence_groups_[0]->sequences()[0].get(); + } + + if (!first_sequence) { + return ForwardInput{}; + } + + const uint32_t seq_len = first_sequence->num_tokens(); + const uint32_t num_decoder_embeddings = + first_sequence->num_decoder_embeddings(); + const uint32_t n_prompt_tokens = first_sequence->num_prompt_tokens(); + const bool is_first_prefill = (first_sequence->num_generated_tokens() == 0); + // const uint64_t model_version = first_sequence->get_model_version(); + + // ========== High-performance encoder tokens construction ========== + auto buildEncoderTokensOptimized = [&]() -> const std::vector& { + auto& cache_data = perf_cache_.cache_data; + + // encoder doesn't use cache key, because encoder doesn't use encoder_tokens + // in non-first prefill scenarios, only uses encoder_seq_len + if (!is_first_prefill) { + return cache_data.encoder_tokens; + } + + // Optimization: Use SIMD-friendly memory access patterns + cache_data.encoder_tokens.clear(); + cache_data.encoder_seq_lens.clear(); + + // Optimization for scenarios where sequences have different lengths across + // sequence groups Pre-calculate total token count to avoid multiple memory + // reallocations + int32_t total_tokens = 0; + for (const auto& group_ptr : sequence_groups_) { + if (!group_ptr->sequences().empty()) { + // Sequences within group have same length, only need to get first + // sequence's length + const int32_t group_encoder_seq_len = + group_ptr->sequences()[0]->encoder_tokens().size(); + total_tokens += group_encoder_seq_len * group_ptr->sequences().size(); + } + } + + cache_data.encoder_tokens.reserve(total_tokens); + cache_data.encoder_seq_lens.resize(num_sequences); + cache_data.encoder_sparse_embeddings.clear(); + cache_data.encoder_sparse_embeddings.reserve(num_sequences); + cache_data.decoder_context_embeddings.clear(); + cache_data.decoder_context_embeddings.reserve(num_sequences); + + // Process by groups in batch + int32_t global_seq_idx = 0; + for (const auto& group_ptr : sequence_groups_) { + const auto& group = *group_ptr; + const int32_t group_size = group.sequences().size(); + + if (group_size == 0) continue; + + const int32_t group_encoder_seq_len = + group.sequences()[0]->encoder_seq_len(); + + // Batch set same values + std::fill_n(&cache_data.encoder_seq_lens[global_seq_idx], + group_size, + group_encoder_seq_len); + + // Batch copy tokens by sequence and collect sparse_embedding + for (const auto& sequence : group.sequences()) { + const auto& encoder_tokens = sequence->encoder_tokens(); + const int32_t* src_ptr = encoder_tokens.data(); + const int32_t group_encoder_seq_len = encoder_tokens.size(); + + // Use efficient batch insertion + if (group_encoder_seq_len > 0) { + cache_data.encoder_tokens.insert(cache_data.encoder_tokens.end(), + src_ptr, + src_ptr + group_encoder_seq_len); + } + // Collect sparse_embedding + auto mm_data = sequence->get_mm_data(); + auto sparse_embedding_optional = + mm_data.get(Sequence::ENCODER_SPARSE_EMBEDDING_NAME); + if (sparse_embedding_optional.has_value()) { + 
cache_data.encoder_sparse_embeddings.push_back( + sparse_embedding_optional.value()); + } + + auto decoder_context_embedding_optional = mm_data.get( + Sequence::DECODER_CONTEXT_EMBEDDING_NAME); + if (decoder_context_embedding_optional.has_value()) { + cache_data.decoder_context_embeddings.push_back( + decoder_context_embedding_optional.value()); + } + } + + global_seq_idx += group_size; + } + + return cache_data.encoder_tokens; + }; + + // ========== High-performance decoder data construction ========== + auto buildDecoderDataOptimized = [&]() { + // Pre-allocate all containers to avoid dynamic expansion + const size_t total_tokens = num_sequences * seq_len; + std::vector flatten_tokens_vec; + flatten_tokens_vec.reserve(total_tokens); + std::vector sampling_params; + sampling_params.reserve(num_sequences); + std::vector selected_token_idxes; + selected_token_idxes.reserve(num_sequences); + std::vector sample_idxes; + sample_idxes.reserve(num_sequences); + std::vector> generated_tokens; + generated_tokens.reserve(num_sequences); + + // Multi-threading optimization: Use parallel processing when sequence count + // exceeds threshold and thread pool is available + ThreadPool* threadpool = thread_pool_; + if (num_sequences >= THREADPOOL_THRESHOLD && threadpool != nullptr) { + // Thread-safe result containers + std::vector> thread_flatten_tokens(num_sequences); + std::vector thread_sampling_params( + num_sequences); + std::vector thread_selected_token_idxes(num_sequences); + std::vector thread_sample_idxes(num_sequences); + std::vector> thread_generated_tokens(num_sequences); + + // Calculate thread allocation + const size_t num_threads = + std::min(static_cast(num_sequences), static_cast(16)); + const size_t sequences_per_thread = + (num_sequences + num_threads - 1) / num_threads; + + std::vector> futures; + std::vector>> promises; + futures.reserve(num_threads); + promises.reserve(num_threads); + + // Parallel processing function + auto process_sequences_range = [&](size_t start_idx, size_t end_idx) { + for (size_t i = start_idx; + i < end_idx && i < static_cast(num_sequences); + ++i) { + const Sequence* sequence = nullptr; + // Get sequence from sequence_groups + size_t seq_idx = 0; + for (const auto& group : sequence_groups_) { + if (seq_idx + group->sequences().size() > i) { + sequence = group->sequences()[i - seq_idx].get(); + break; + } + seq_idx += group->sequences().size(); + } + + if (!sequence) continue; + + const auto& token_ids = sequence->tokens(); + + // Build generated tokens + auto& cur_generated_tokens = thread_generated_tokens[i]; + cur_generated_tokens.reserve(seq_len - n_prompt_tokens); + for (uint32_t j = n_prompt_tokens; j < seq_len; ++j) { + cur_generated_tokens.push_back(token_ids[j]); + } + + // Build flatten tokens + auto& cur_flatten_tokens = thread_flatten_tokens[i]; + cur_flatten_tokens.reserve(seq_len); + cur_flatten_tokens.insert(cur_flatten_tokens.end(), + token_ids.begin(), + token_ids.begin() + seq_len); + + // Set sampling parameters + thread_sampling_params[i] = sequence->sampling_param(); + thread_sample_idxes[i] = static_cast(i); + } + }; + + // Launch parallel tasks + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + size_t start_idx = thread_idx * sequences_per_thread; + size_t end_idx = std::min(start_idx + sequences_per_thread, + static_cast(num_sequences)); + + if (start_idx >= static_cast(num_sequences)) break; + + auto promise = std::make_shared>(); + futures.push_back(promise->get_future()); + promises.push_back(promise); + + 
threadpool->schedule( + [process_sequences_range, start_idx, end_idx, promise]() mutable { + try { + process_sequences_range(start_idx, end_idx); + promise->set_value(); + } catch (...) { + promise->set_exception(std::current_exception()); + } + }); + } + + // Wait for all tasks to complete + for (auto& future : futures) { + future.get(); + } + + // Merge results + size_t start_idx = 0; + size_t total_tokens = seq_len + num_decoder_embeddings; + for (int32_t i = 0; i < num_sequences; ++i) { + flatten_tokens_vec.insert(flatten_tokens_vec.end(), + thread_flatten_tokens[i].begin(), + thread_flatten_tokens[i].end()); + selected_token_idxes.push_back( + static_cast(start_idx + total_tokens - 1)); + start_idx += total_tokens; + sampling_params.push_back(thread_sampling_params[i]); + sample_idxes.push_back(thread_sample_idxes[i]); + generated_tokens.push_back(std::move(thread_generated_tokens[i])); + } + } else { + // Original single-thread processing logic + size_t start_idx = 0; + size_t total_tokens = seq_len + num_decoder_embeddings; + size_t seq_idx = 0; + for (const auto& group : sequence_groups_) { + for (const auto& sequence : group->sequences()) { + const auto& token_ids = sequence->tokens(); + + // Optimize generated tokens construction + auto& cur_generated_tokens = generated_tokens.emplace_back(); + cur_generated_tokens.reserve(seq_len - n_prompt_tokens); + for (uint32_t j = n_prompt_tokens; j < seq_len; ++j) { + cur_generated_tokens.push_back(token_ids[j]); + } + // Optimize token processing - batch operations + flatten_tokens_vec.insert(flatten_tokens_vec.end(), + token_ids.begin(), + token_ids.begin() + seq_len); + + // Simplify sampling parameter processing + selected_token_idxes.push_back( + static_cast(start_idx + total_tokens - 1)); + start_idx += total_tokens; + sampling_params.push_back(sequence->sampling_param()); + sample_idxes.push_back(seq_idx); + seq_idx++; + } + } + } + + return std::make_tuple(std::move(flatten_tokens_vec), + std::move(sampling_params), + std::move(selected_token_idxes), + std::move(sample_idxes), + std::move(generated_tokens)); + }; + + // ========== Comprehensive parallel execution of optimized data construction + // ========== Use thread pool to execute all independent data construction + // tasks in parallel + std::future&> encoder_future; + std::future, + std::vector, + std::vector, + std::vector, + std::vector>>> + decoder_future; + + // Declare variables to store results + const std::vector* encoder_tokens_ptr = nullptr; + std::vector flatten_tokens_vec; + std::vector sampling_params; + std::vector selected_token_idxes; + std::vector sample_idxes; + std::vector> generated_tokens; + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Use ThreadPool's schedule method to execute independent tasks in parallel + // buildDecoderDataOptimized handles multi-threading internally, no external + // parallel calls + + // Task 1: buildEncoderTokensOptimized + std::promise*> encoder_promise; + auto encoder_future = encoder_promise.get_future(); + thread_pool_->schedule([&, promise = std::move(encoder_promise)]() mutable { + const auto& result = buildEncoderTokensOptimized(); + promise.set_value(&result); + }); + // Wait for encoder to complete + encoder_tokens_ptr = encoder_future.get(); + // Task 2: buildDecoderDataOptimized executes directly, handles + // multi-threading internally + std::tie(flatten_tokens_vec, + sampling_params, + selected_token_idxes, + sample_idxes, + generated_tokens) = buildDecoderDataOptimized(); + } else { + // 
Single-thread execution (original logic) + encoder_tokens_ptr = &buildEncoderTokensOptimized(); + std::tie(flatten_tokens_vec, + sampling_params, + selected_token_idxes, + sample_idxes, + generated_tokens) = buildDecoderDataOptimized(); + } + + const auto& encoder_tokens = *encoder_tokens_ptr; + + // ========== High-performance ForwardInput construction ========== + ForwardInput forward_input; + auto& input_params = forward_input.input_params; + auto& cache_data = perf_cache_.cache_data; + + // Initialize key fields for asynchronous tasks + const int64_t bs = sequence_groups_.size(); + const int64_t group_width = + sequence_groups_.empty() ? 1 : sequence_groups_[0]->sequences().size(); + + std::vector> decoder_embedding_futures; + torch::Tensor result_embedding; + + // ========== Parallel tensor construction tasks ========== + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Only use parallelization for time-consuming tasks (token_ids and + // encoder_token_ids) + std::promise token_ids_promise; + std::promise encoder_token_ids_promise; + + auto token_ids_future = token_ids_promise.get_future(); + // auto encoder_token_ids_future = encoder_token_ids_promise.get_future(); + + // Task 1: Build token_ids tensor - + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + thread_pool_->schedule([&flatten_tokens_vec, + promise = std::move(token_ids_promise)]() mutable { + try { + // Optimization: Pre-allocate memory and use std::memcpy to avoid clone + // operations + auto tensor = + torch::empty({static_cast(flatten_tokens_vec.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(tensor.data_ptr(), + flatten_tokens_vec.data(), + flatten_tokens_vec.size() * sizeof(int)); + promise.set_value(std::move(tensor)); + } catch (...) { + promise.set_exception(std::current_exception()); + } + }); + + // Task 2: Build encoder_token_ids tensor (if needed) - + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + /* + thread_pool_->schedule( + [&encoder_tokens, + promise = std::move(encoder_token_ids_promise)]() mutable { + try { + torch::Tensor tensor; + if (!encoder_tokens.empty()) { + // Optimization: Pre-allocate memory and use std::memcpy to avoid + // clone operations + tensor = + torch::empty({static_cast(encoder_tokens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(tensor.data_ptr(), + encoder_tokens.data(), + encoder_tokens.size() * sizeof(int)); + } + promise.set_value(std::move(tensor)); + } catch (...) 
{ + promise.set_exception(std::current_exception()); + } + }); + */ + if (!perf_cache_.cache_data.decoder_context_embeddings.empty()) { + // Task 3: Synchronously process decoder_embedding, inner group dimension + // parallelization optimization + + // Optimization: Directly get shape information from first embedding to + // avoid torch::cat + auto first_embedding = + perf_cache_.cache_data.decoder_context_embeddings[0]; + auto original_shape = first_embedding.sizes(); + int64_t context_len = original_shape[0]; + int64_t hidden_size = original_shape[1]; + + // Create tensor on pinned memory + auto options = torch::TensorOptions() + .dtype(first_embedding.dtype()) + .device(first_embedding.device()) + .pinned_memory(true) + .memory_format(torch::MemoryFormat::Contiguous); + + // Calculate total sequence length, pre-allocate context_len + seq_len + int64_t total_seq_len = context_len + seq_len; + + auto combined_embedding = + torch::empty({bs, group_width, total_seq_len, hidden_size}, options); + + // High-performance optimization: group dimension segmented + // parallelization + void* dst_data = combined_embedding.data_ptr(); + + // Get element size (supports float, bfloat16 and other types) + const size_t element_size = first_embedding.element_size(); + const size_t context_size = context_len * hidden_size * element_size; + const size_t group_stride = total_seq_len * hidden_size * element_size; + const size_t batch_stride = + group_width * total_seq_len * hidden_size * element_size; + + // Parallelization strategy: segment by group dimension, consistent with + // thread calculations elsewhere + const size_t num_threads = + std::min(static_cast(group_width), static_cast(16)); + const size_t groups_per_thread = + (group_width + num_threads - 1) / num_threads; + + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + size_t start_group = thread_idx * groups_per_thread; + size_t end_group = std::min(start_group + groups_per_thread, + static_cast(group_width)); + + if (start_group >= static_cast(group_width)) break; + + std::promise promise; + decoder_embedding_futures.push_back(promise.get_future()); + + thread_pool_->schedule( + [start_group, + end_group, + bs, + dst_data, + context_len, + hidden_size, + element_size, + batch_stride, + group_stride, + context_size, + embeddings = perf_cache_.cache_data.decoder_context_embeddings, + dst_tensor = combined_embedding, + promise = std::move(promise)]() mutable { + // Copy context_embedding for specified group range of each batch + for (int64_t b = 0; b < bs; ++b) { + // Optimization: Access corresponding batch embedding directly + // through index + const void* batch_src = embeddings[b].data_ptr(); + auto* batch_dst = + static_cast(dst_data) + b * batch_stride; + + for (size_t g = start_group; g < end_group; ++g) { + std::memcpy( + batch_dst + g * group_stride, batch_src, context_size); + } + } + promise.set_value(); + }); + } + + result_embedding = combined_embedding; + } + + // Task 4: Build sequence length vector - changed to serial execution (very + // time-consuming, ~0.001785ms) + std::vector cu_seq_lens, q_cu_seq_lens; +#if defined(USE_NPU) + // use all prefill; + cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); + q_cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); +#else + cu_seq_lens.reserve(num_sequences + 1); + q_cu_seq_lens.reserve(num_sequences + 1); + cu_seq_lens.push_back(0); + q_cu_seq_lens.push_back(0); + + for (int32_t i = 0; i < num_sequences; ++i) { + 
cu_seq_lens.push_back(cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + q_cu_seq_lens.push_back(q_cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + } +#endif + + // Task 5: Build encoder_seq_lens_tensor - changed to serial execution (less + // time-consuming) + torch::Tensor encoder_seq_lens_tensor; + if (!cache_data.encoder_seq_lens.empty()) { + // Optimization: Pre-allocate memory and use std::memcpy to avoid clone + // operations + encoder_seq_lens_tensor = torch::empty( + {static_cast(cache_data.encoder_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(encoder_seq_lens_tensor.data_ptr(), + cache_data.encoder_seq_lens.data(), + cache_data.encoder_seq_lens.size() * sizeof(int)); + } + + // Set basic parameters simultaneously (not dependent on asynchronous tasks) + input_params.num_sequences = num_sequences; + input_params.empty_kv_cache = true; + input_params.global_empty_kv_cache = true; + input_params.kv_max_seq_len = seq_len + num_decoder_embeddings; + input_params.q_max_seq_len = seq_len + num_decoder_embeddings; + forward_input.positions = perf_cache_.fixed_positions_tensor; + + // Wait and collect results + forward_input.token_ids = token_ids_future.get(); + // auto encoder_token_ids = encoder_token_ids_future.get(); + + // seq_lens has been changed to serial execution, use the constructed + // variable directly + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.kv_seq_lens = + torch::empty({static_cast(cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.kv_seq_lens.data_ptr(), + cu_seq_lens.data(), + cu_seq_lens.size() * sizeof(int)); + + input_params.q_seq_lens = + torch::empty({static_cast(q_cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.q_seq_lens.data_ptr(), + q_cu_seq_lens.data(), + q_cu_seq_lens.size() * sizeof(int)); + input_params.kv_seq_lens_vec = std::move(cu_seq_lens); + input_params.q_seq_lens_vec = std::move(q_cu_seq_lens); + + // encoder_seq_lens_tensor has been changed to serial execution, use the + // constructed variable directly + if (encoder_seq_lens_tensor.defined()) { + // Set RecModelInputParams encoder data + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + input_params.rec_params->encoder_seq_lens_tensor = + std::move(encoder_seq_lens_tensor); + input_params.rec_params->encoder_seq_lens = cache_data.encoder_seq_lens; + } + input_params.rec_params->encoder_positions = + perf_cache_.fixed_encoder_positions_tensor; + } else { + // Single-threaded execution (original logic) + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + forward_input.token_ids = + torch::empty({static_cast(flatten_tokens_vec.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(forward_input.token_ids.data_ptr(), + flatten_tokens_vec.data(), + flatten_tokens_vec.size() * sizeof(int)); + forward_input.positions = perf_cache_.fixed_positions_tensor; + + if (!encoder_tokens.empty()) { + // Set RecModelInputParams encoder data + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + + // Optimization: Use torch::empty+std::memcpy instead of + // 
torch::from_blob().clone() + input_params.rec_params->encoder_token_ids = + torch::empty({static_cast(encoder_tokens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.rec_params->encoder_token_ids.data_ptr(), + encoder_tokens.data(), + encoder_tokens.size() * sizeof(int)); + } + input_params.rec_params->encoder_positions = + perf_cache_.fixed_encoder_positions_tensor; + // Pre-allocate and batch fill + std::vector cu_seq_lens, q_cu_seq_lens; +#if defined(USE_NPU) + // use all prefill; + cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); + q_cu_seq_lens.assign(num_sequences, seq_len + num_decoder_embeddings); +#else + cu_seq_lens.reserve(num_sequences + 1); + q_cu_seq_lens.reserve(num_sequences + 1); + cu_seq_lens.push_back(0); + q_cu_seq_lens.push_back(0); + + for (int32_t i = 0; i < num_sequences; ++i) { + cu_seq_lens.push_back(cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + q_cu_seq_lens.push_back(q_cu_seq_lens.back() + seq_len + + num_decoder_embeddings); + } +#endif + + input_params.num_sequences = num_sequences; + input_params.empty_kv_cache = true; + input_params.global_empty_kv_cache = true; + input_params.kv_max_seq_len = seq_len + num_decoder_embeddings; + input_params.q_max_seq_len = seq_len + num_decoder_embeddings; + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.kv_seq_lens = + torch::empty({static_cast(cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.kv_seq_lens.data_ptr(), + cu_seq_lens.data(), + cu_seq_lens.size() * sizeof(int)); + + input_params.q_seq_lens = + torch::empty({static_cast(q_cu_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.q_seq_lens.data_ptr(), + q_cu_seq_lens.data(), + q_cu_seq_lens.size() * sizeof(int)); + + input_params.kv_seq_lens_vec = std::move(cu_seq_lens); + input_params.q_seq_lens_vec = std::move(q_cu_seq_lens); + + if (!cache_data.encoder_seq_lens.empty()) { + // Set RecModelInputParams encoder data + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + + input_params.rec_params->encoder_seq_lens = cache_data.encoder_seq_lens; + + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.rec_params->encoder_seq_lens_tensor = torch::empty( + {static_cast(cache_data.encoder_seq_lens.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy( + input_params.rec_params->encoder_seq_lens_tensor.data_ptr(), + cache_data.encoder_seq_lens.data(), + cache_data.encoder_seq_lens.size() * sizeof(int)); + } + } + + // ========== Parallel processing of independent code blocks ========== + if (thread_pool_ && num_sequences >= THREADPOOL_THRESHOLD) { + // Define promise/future for parallel tasks + std::promise block_tables_promise; + auto block_tables_future = block_tables_promise.get_future(); + + // Task 1: Empty block tables processing - use thread pool (relatively + // time-consuming) + thread_pool_->schedule([&input_params, + num_sequences, + &perf_cache_, + &block_tables_promise]() mutable { + try { + std::vector> empty_block_tables(num_sequences); + util::pad_2d_vector(empty_block_tables, 0); + // Optimization: Use create_2d_tensor_optimized, has 
special + // optimization for all-zero matrices + input_params.block_tables = + create_2d_tensor(empty_block_tables, torch::kInt); + + std::vector paged_kv_indptr(num_sequences + 1, 0); + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.new_cache_slots = + torch::empty({static_cast(paged_kv_indptr.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.new_cache_slots.data_ptr(), + paged_kv_indptr.data(), + paged_kv_indptr.size() * sizeof(int)); + + block_tables_promise.set_value(); + } catch (...) { + block_tables_promise.set_exception(std::current_exception()); + } + }); + + // Optimization: Merge small tasks into sequential execution to reduce + // thread switching overhead Cross-attention parameter construction - use + // placeholder + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + input_params.rec_params->cross_attn_kv_cu_seq_lens = + torch::zeros({1}, torch::kInt); + input_params.rec_params->cross_attn_kv_cu_seq_lens_vec = {0}; + input_params.rec_params->cross_attn_block_tables = + torch::zeros({1, 1}, torch::kInt); + + // Sampling parameter processing + if (!selected_token_idxes.empty()) { + forward_input.sampling_params.init(sampling_params, + selected_token_idxes, + sample_idxes, + std::vector>{}, + std::vector>{}, + std::vector{}); + } + + // First prefill processing - use placeholder + if (is_first_prefill) { + // Use placeholder instead of complex cross_attn_new_cache_slots + // construction + input_params.rec_params->cross_attn_new_cache_slots = + torch::zeros({1}, torch::kInt); + } + + // Wait for parallel tasks to complete (only block_tables uses thread pool) + block_tables_future.wait(); + } else { + // ========== Non-parallel case: sequential processing ========== + // Optimize empty block tables processing + std::vector> empty_block_tables(num_sequences); + util::pad_2d_vector(empty_block_tables, 0); + // Optimization: Use create_2d_tensor_optimized, has special optimization + // for all-zero matrices + input_params.block_tables = + create_2d_tensor(empty_block_tables, torch::kInt); + + std::vector paged_kv_indptr(num_sequences + 1, 0); + // Optimization: Use torch::empty+std::memcpy instead of + // torch::from_blob().clone() + input_params.new_cache_slots = + torch::empty({static_cast(paged_kv_indptr.size())}, + torch::TensorOptions() + .dtype(torch::kInt) + .device(torch::kCPU) + .pinned_memory(true)); + std::memcpy(input_params.new_cache_slots.data_ptr(), + paged_kv_indptr.data(), + paged_kv_indptr.size() * sizeof(int)); + + // ========== Cross-attention parameter construction (using placeholder) + // ========== Use placeholder tensor instead of actual data + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + input_params.rec_params->cross_attn_kv_cu_seq_lens = + torch::zeros({1}, torch::kInt); + input_params.rec_params->cross_attn_kv_cu_seq_lens_vec = {0}; + + // Use placeholder tensor instead of actual data + input_params.rec_params->cross_attn_block_tables = + torch::zeros({1, 1}, torch::kInt); + + // ========== Optimize sampling parameter processing ========== + if (!selected_token_idxes.empty()) { + forward_input.sampling_params.init(sampling_params, + selected_token_idxes, + sample_idxes, + std::vector>{}, + std::vector>{}, + std::vector{}); + } + + // ========== First prefill processing (using placeholder) ========== + if 
(is_first_prefill) { + // Use placeholder tensor instead of actual data + input_params.rec_params->cross_attn_new_cache_slots = + torch::zeros({1}, torch::kInt); + } + } + + // ========== Common parameter settings ========== + // Batch set other parameters + input_params.embedding_ids.assign(num_sequences, 0); + +#if defined(USE_NPU) + auto prefill_indices = util::find_ones_indices(input_params.q_seq_lens_vec); + input_params.decode_seq_range = + std::make_pair(0, static_cast(flatten_tokens_vec.size())); +#else + input_params.decode_seq_range = { + 0, static_cast(flatten_tokens_vec.size())}; +#endif + + // Rec model parameters + if (!input_params.rec_params.has_value()) { + input_params.rec_params = RecModelInputParams{}; + } + + input_params.rec_params->rec_stage = RecModelInputParams::RecStage::PREFILL; + input_params.rec_params->is_hybrid_mode = false; + input_params.rec_params->has_encoder_output = true; + input_params.rec_params->is_first_prefill = is_first_prefill; + input_params.rec_params->bs = bs; + input_params.rec_params->group_width = group_width; + input_params.rec_params->seq_len = seq_len; + input_params.rec_params->encoder_max_seq_len = + cache_data.encoder_seq_lens.empty() + ? 0 + : *std::max_element(cache_data.encoder_seq_lens.begin(), + cache_data.encoder_seq_lens.end()); + + input_params.rec_params->generated_tokens = std::move(generated_tokens); + + // Process sparse_embedding: Efficiently concatenate from cache_data + if (!perf_cache_.cache_data.encoder_sparse_embeddings.empty()) { + // Use torch::cat for efficient concatenation, concatenate along dim=0 + input_params.rec_params->encoder_sparse_embedding = + torch::cat(perf_cache_.cache_data.encoder_sparse_embeddings, /*dim=*/0); + } + + if (!perf_cache_.cache_data.decoder_context_embeddings.empty()) { + // Get group_width + int64_t group_width = input_params.rec_params->group_width; + if (group_width == 1 && seq_len == 0) { + // Optimization: When bs==1, directly use the first embedding to avoid + // unnecessary torch::cat + if (bs == 1) { + input_params.rec_params->decoder_context_embedding = + perf_cache_.cache_data.decoder_context_embeddings[0]; + } else { + // Use torch::cat for efficient concatenation, concatenate along dim=0 + auto original_context_embedding = torch::cat( + perf_cache_.cache_data.decoder_context_embeddings, /*dim=*/0); + input_params.rec_params->decoder_context_embedding = + original_context_embedding; + } + } else if (group_width == 1 && seq_len > 0) { + // Handle the scenario where group_width==1 and seq_len>0 + // Get information from the first embedding + const auto& first_embedding = + perf_cache_.cache_data.decoder_context_embeddings[0]; + auto original_shape = first_embedding.sizes(); + int64_t context_len = original_shape[0]; + int64_t hidden_size = original_shape[1]; + int64_t total_seq_len = context_len + seq_len; + + // Allocate a tensor of shape {bs, 1, total_seq_len, hidden_size}, + // optimized with pinned memory + auto options = torch::TensorOptions() + .dtype(first_embedding.dtype()) + .device(first_embedding.device()) + .pinned_memory(true) + .memory_format(torch::MemoryFormat::Contiguous); + auto combined_embedding = + torch::empty({bs, 1, total_seq_len, hidden_size}, options); + + // Single-threaded copy of context_len portion of data + void* dst_data = combined_embedding.data_ptr(); + const size_t element_size = first_embedding.element_size(); + const size_t context_size = context_len * hidden_size * element_size; + const size_t batch_stride = total_seq_len * hidden_size * 
element_size; + + // Copy context_embedding for each batch + for (int64_t b = 0; b < bs; ++b) { + const void* batch_src = + perf_cache_.cache_data.decoder_context_embeddings[b].data_ptr(); + auto* batch_dst = static_cast(dst_data) + b * batch_stride; + std::memcpy(batch_dst, batch_src, context_size); + } + input_params.rec_params->decoder_context_embedding = combined_embedding; + } else { + for (auto& future : decoder_embedding_futures) { + future.get(); + } + input_params.rec_params->decoder_context_embedding = + std::move(result_embedding); + } + } + + return forward_input; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/batch/rec_batch_input_builder.h b/xllm/core/framework/batch/rec_batch_input_builder.h new file mode 100644 index 00000000..5bcf1347 --- /dev/null +++ b/xllm/core/framework/batch/rec_batch_input_builder.h @@ -0,0 +1,134 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include +#include + +#include "batch_input_builder.h" +#include "framework/model/model_args.h" +#include "framework/model/model_input_params.h" +#include "framework/request/mm_data.h" +#include "framework/request/sequence.h" +#include "framework/request/sequences_group.h" +#include "runtime/forward_params.h" +#include "util/threadpool.h" + +namespace xllm { + +class RecBatchInputBuilder : public BatchInputBuilder { + public: + explicit RecBatchInputBuilder( + const std::vector>& sequence_groups, + const std::vector& allowed_max_tokens, + const std::vector& input_embeddings_vec, + const std::vector& mm_data_vec, + const std::vector* copy_in_cache_block_infos, + const std::vector* copy_out_cache_block_infos, + std::vector* swap_cache_block_infos, + const ModelArgs* args, + ThreadPool* thread_pool = nullptr); + + protected: + // Provide protected access methods for subclasses - modified to access + // parent's protected members + const std::vector>& get_sequence_groups() + const { + return sequence_groups_; + } + const std::vector& get_allowed_max_tokens() const { + return allowed_max_tokens_; + } + const std::vector& get_input_embeddings_vec() const { + return input_embeddings_vec_; + } + const std::vector& get_mm_data_vec() const { return mm_data_vec_; } + const std::vector* get_copy_in_cache_block_infos() const { + return copy_in_cache_block_infos_; + } + const std::vector* get_copy_out_cache_block_infos() const { + return copy_out_cache_block_infos_; + } + std::vector* get_swap_cache_block_infos() const { + return swap_cache_block_infos_; + } + const ModelArgs* get_args() const { return args_; } + ThreadPool* get_thread_pool() const { return thread_pool_; } + + public: + // Main public interface + ForwardInput build_rec_forward_input(uint32_t num_decoding_tokens, + uint32_t min_decoding_batch_size); + + private: + // Helper method to extract sequences from groups + static std::vector 
extract_sequences_from_groups( + const std::vector>& sequence_groups); + + // Member variables - only keep sequence_groups_, others inherited from parent + // class + const std::vector>& sequence_groups_; + + // High performance cache system + struct HighPerformanceCache { + // Memory pool - avoid frequent allocation/deallocation + struct MemoryPool { + std::vector> int32_pools; + size_t pool_index = 0; + + std::vector& getInt32Vector(size_t reserve_size = 0) { + if (pool_index >= int32_pools.size()) { + int32_pools.emplace_back(); + } + auto& vec = int32_pools[pool_index++]; + vec.clear(); + if (reserve_size > 0) vec.reserve(reserve_size); + return vec; + } + + void reset() { pool_index = 0; } + }; + + // Cache data structure + struct CacheData { + std::vector encoder_tokens; + std::vector encoder_seq_lens; + std::vector encoder_sparse_embeddings; + std::vector decoder_context_embeddings; + }; + + // Pre-created constant tensors + torch::Tensor fixed_positions_tensor; + torch::Tensor fixed_encoder_positions_tensor; + torch::Tensor empty_tensor; + + MemoryPool memory_pool; + CacheData cache_data; + + HighPerformanceCache() { + // Pre-create commonly used tensors to avoid repeated creation + fixed_positions_tensor = torch::tensor({0}, torch::kInt); + fixed_encoder_positions_tensor = torch::tensor({0}, torch::kInt); + empty_tensor = torch::tensor(std::vector{}, torch::kInt); + } + }; + + static HighPerformanceCache perf_cache_; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/model/model_args.h b/xllm/core/framework/model/model_args.h index ff67b22b..79de69ee 100755 --- a/xllm/core/framework/model/model_args.h +++ b/xllm/core/framework/model/model_args.h @@ -346,7 +346,7 @@ struct ModelArgs { PROPERTY(std::vector, axes_dims_rope) = {}; PROPERTY(int64_t, num_single_layers) = 0; - // t5 related args + // rec related args PROPERTY(int64_t, d_model) = 0; PROPERTY(int64_t, num_layers) = 0; PROPERTY(int64_t, d_kv) = 0; @@ -356,6 +356,17 @@ struct ModelArgs { PROPERTY(int64_t, relative_attention_num_buckets) = 0; PROPERTY(int64_t, relative_attention_max_distance) = 0; + PROPERTY(int64_t, n_encoder_layers) = 0; + PROPERTY(int64_t, decoder_head_dim) = 0; + PROPERTY(int64_t, decoder_n_heads) = 0; + PROPERTY(std::optional, decoder_n_kv_heads); + PROPERTY(bool, use_absolute_position_embedding) = false; + + PROPERTY(bool, use_moe) = false; + PROPERTY(std::string, moe_score_func); + PROPERTY(float, moe_route_scale) = 1.0f; + PROPERTY(bool, moe_use_shared_experts) = false; + // scheduler related args PROPERTY(int64_t, num_train_timesteps) = 0; PROPERTY(int64_t, shift) = 0; diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 6669baaa..8376fd4c 100755 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -26,6 +26,137 @@ limitations under the License. 
#include "util/tensor_helper.h" namespace xllm { + +// Rec model specific input parameters +struct RecModelInputParams { + // Rec model specific parameters + + enum class RecStage { + PREFILL, // Prefill stage + DECODE // Decode stage + }; + + RecStage rec_stage = RecStage::PREFILL; + bool is_hybrid_mode = false; + // Flag to distinguish encoder vs decoder forward calls + bool is_encoder_forward = false; + // For Rec decoder cross-attention + bool has_encoder_output = false; + // Length of encoder output sequence for each sequence + std::vector encoder_seq_lens; + // Pre-constructed tensor for encoder_seq_lens + torch::Tensor encoder_seq_lens_tensor; + // max encoder seq len + int32_t encoder_max_seq_len = 0; + + // Additional parameters needed by rec_batch_input_builder + bool is_first_prefill = true; + int32_t bs = 0; // batch size + int32_t group_width = 0; + int32_t seq_len = 0; + std::vector> generated_tokens; + torch::Tensor encoder_sparse_embedding; + torch::Tensor decoder_context_embedding; + + // Separate KV cache parameters for different attention types + // For Rec decoder: self_attn uses growing cache, cross_attn uses fixed cache + torch::Tensor cross_attn_kv_cu_seq_lens; // KV lengths for cross-attention + torch::Tensor cross_attn_new_cache_slots; // Cache slots for cross-attention + torch::Tensor cross_attn_block_tables; // Block tables for cross-attention + std::vector cross_attn_kv_cu_seq_lens_vec; + + torch::Tensor encoder_token_ids; + // Rec encoder positions + torch::Tensor encoder_positions; + + RecModelInputParams to(const c10::Device& device) const { + RecModelInputParams result = *this; + + // Move tensors to the specified device + if (encoder_seq_lens_tensor.defined()) { + result.encoder_seq_lens_tensor = encoder_seq_lens_tensor.to(device); + } + + if (encoder_sparse_embedding.defined()) { + result.encoder_sparse_embedding = encoder_sparse_embedding.to(device); + } + + if (decoder_context_embedding.defined()) { + result.decoder_context_embedding = decoder_context_embedding.to(device); + } + + if (cross_attn_kv_cu_seq_lens.defined()) { + result.cross_attn_kv_cu_seq_lens = cross_attn_kv_cu_seq_lens.to(device); + } + + if (cross_attn_new_cache_slots.defined()) { + result.cross_attn_new_cache_slots = cross_attn_new_cache_slots.to(device); + } + + if (cross_attn_block_tables.defined()) { + result.cross_attn_block_tables = cross_attn_block_tables.to(device); + } + + if (encoder_token_ids.defined()) { + result.encoder_token_ids = encoder_token_ids.to(device); + } + + if (encoder_positions.defined()) { + result.encoder_positions = encoder_positions.to(device); + } + + return result; + } + + void print() const { + LOG(INFO) << "RecModelInputParams:" + << " rec_stage: " + << (rec_stage == RecStage::PREFILL ? 
"PREFILL" : "DECODE") + << " is_hybrid_mode: " << is_hybrid_mode + << " is_encoder_forward: " << is_encoder_forward + << " has_encoder_output: " << has_encoder_output + << " encoder_max_seq_len: " << encoder_max_seq_len + << " is_first_prefill: " << is_first_prefill << " bs: " << bs + << " group_width: " << group_width << " seq_len: " << seq_len + << " encoder_seq_lens size: " << encoder_seq_lens.size() + << " cross_attn_kv_cu_seq_lens_vec size: " + << cross_attn_kv_cu_seq_lens_vec.size() + << " generated_tokens size: " << generated_tokens.size(); + + // Print tensor shapes if defined + if (encoder_seq_lens_tensor.defined()) { + LOG(INFO) << " encoder_seq_lens_tensor shape: " + << encoder_seq_lens_tensor.sizes(); + } + if (encoder_sparse_embedding.defined()) { + LOG(INFO) << " encoder_sparse_embedding shape: " + << encoder_sparse_embedding.sizes(); + } + if (decoder_context_embedding.defined()) { + LOG(INFO) << " decoder_context_embedding shape: " + << decoder_context_embedding.sizes(); + } + if (cross_attn_kv_cu_seq_lens.defined()) { + LOG(INFO) << " cross_attn_kv_cu_seq_lens shape: " + << cross_attn_kv_cu_seq_lens.sizes(); + } + if (cross_attn_new_cache_slots.defined()) { + LOG(INFO) << " cross_attn_new_cache_slots shape: " + << cross_attn_new_cache_slots.sizes(); + } + if (cross_attn_block_tables.defined()) { + LOG(INFO) << " cross_attn_block_tables shape: " + << cross_attn_block_tables.sizes(); + } + if (encoder_token_ids.defined()) { + LOG(INFO) << " encoder_token_ids shape: " << encoder_token_ids.sizes(); + } + if (encoder_positions.defined()) { + LOG(INFO) << " encoder_positions shape: " << encoder_positions.sizes(); + } + } +}; + struct CacheBlockInfo { int32_t device_block_id = 0; int32_t host_block_id = 0; @@ -97,10 +228,14 @@ struct ModelInputParams { // Copy graph_buffer to device params.graph_buffer = safe_to(graph_buffer, device, true); + // Copy optional Rec parameters if present + if (rec_params.has_value()) { + params.rec_params = rec_params->to(device); + } return params; } - void print() const { + virtual void print() const { LOG(INFO) << "ModelInputParams: empty_kv_cache is " << empty_kv_cache << " , global_empty_kv_cache is " << global_empty_kv_cache << " , num_sequences is " << num_sequences @@ -116,6 +251,10 @@ struct ModelInputParams { print_tensor(block_tables, "ModelInputParams: block_tables", 4); LOG(INFO) << "ModelInputParams: dp_global_token_nums is " << dp_global_token_nums; + if (rec_params.has_value()) { + LOG(INFO) << "ModelInputParams: has rec_params"; + rec_params->print(); + } } // whether the kv-cache is empty for all sequences. bool empty_kv_cache = true; @@ -201,6 +340,12 @@ struct ModelInputParams { // Graph execution buffer for temporary tensor storage // Used by ACL Graph Executor to avoid repeated memory allocation torch::Tensor graph_buffer; + + // Optional Rec model specific parameters + std::optional rec_params; + + // Helper function to check if this is a Rec model + bool is_rec_model() const { return rec_params.has_value(); } }; } // namespace xllm diff --git a/xllm/core/framework/prefix_cache/prefix_cache.cpp b/xllm/core/framework/prefix_cache/prefix_cache.cpp index fac8ccfb..3e73c5d4 100644 --- a/xllm/core/framework/prefix_cache/prefix_cache.cpp +++ b/xllm/core/framework/prefix_cache/prefix_cache.cpp @@ -15,7 +15,6 @@ limitations under the License. #include "prefix_cache.h" -#include #include #include #include @@ -25,36 +24,10 @@ limitations under the License. 
#include "common/global_flags.h" #include "common/metrics.h" +#include "util/hash_util.h" namespace xllm { -void murmur_hash3(const uint8_t* pre_hash_value, - const Slice& token_ids, - uint8_t* hash_value) { - if (pre_hash_value == nullptr) { - MurmurHash3_x64_128(reinterpret_cast(token_ids.data()), - sizeof(int32_t) * token_ids.size(), - FLAGS_murmur_hash3_seed, - hash_value); - } else { - uint8_t key[1024]; - - int32_t data_len = - sizeof(int32_t) * token_ids.size() + MURMUR_HASH3_VALUE_LEN; - CHECK_GT(sizeof(key), data_len) << "key size is too small"; - - memcpy(key, pre_hash_value, MURMUR_HASH3_VALUE_LEN); - memcpy(key + MURMUR_HASH3_VALUE_LEN, - reinterpret_cast(token_ids.data()), - sizeof(int32_t) * token_ids.size()); - - MurmurHash3_x64_128(reinterpret_cast(key), - data_len, - FLAGS_murmur_hash3_seed, - hash_value); - } -} - std::vector PrefixCache::match( const Slice& token_ids, const Slice& existed_shared_blocks) { diff --git a/xllm/core/framework/prefix_cache/prefix_cache.h b/xllm/core/framework/prefix_cache/prefix_cache.h index fc778419..9db26b83 100644 --- a/xllm/core/framework/prefix_cache/prefix_cache.h +++ b/xllm/core/framework/prefix_cache/prefix_cache.h @@ -39,10 +39,6 @@ inline size_t round_down(size_t n, size_t multiple) { return (n / multiple) * multiple; } -void murmur_hash3(const uint8_t* pre_hash_value, - const Slice& token_ids, - uint8_t* hash_value); - class PrefixCache { public: PrefixCache(const PrefixCache&) = delete; diff --git a/xllm/core/framework/prefix_cache/prefix_cache_test.cpp b/xllm/core/framework/prefix_cache/prefix_cache_test.cpp index d0b0ca7e..a9fa1466 100644 --- a/xllm/core/framework/prefix_cache/prefix_cache_test.cpp +++ b/xllm/core/framework/prefix_cache/prefix_cache_test.cpp @@ -7,6 +7,7 @@ #include #include "framework/block/block_manager_impl.h" +#include "util/hash_util.h" namespace xllm { diff --git a/xllm/core/framework/request/request.cpp b/xllm/core/framework/request/request.cpp index 84f5a71e..836d4ce7 100644 --- a/xllm/core/framework/request/request.cpp +++ b/xllm/core/framework/request/request.cpp @@ -56,6 +56,8 @@ void Request::create_sequences_group() { sequence_params.best_of = state_.best_of; sequence_params.streaming = state_.stream; sequence_params.enable_schedule_overlap = state_.enable_schedule_overlap; + sequence_params.is_rec_model = state_.is_rec_model; + sequence_params.bos_token_id = state_.bos_token_id; sequence_params.sampling_param = &(state_.sampling_param); sequence_params.stopping_checker = &(state_.stopping_checker); sequences_group_ = std::make_unique(state_.prompt, diff --git a/xllm/core/framework/request/request_output.h b/xllm/core/framework/request/request_output.h index c4781ac3..2527bc88 100644 --- a/xllm/core/framework/request/request_output.h +++ b/xllm/core/framework/request/request_output.h @@ -66,6 +66,9 @@ struct SequenceOutput { // the token ids of the generated text. std::vector token_ids; + // item_id for rec. + std::optional item_ids; + // the reason the sequence finished. std::optional finish_reason; diff --git a/xllm/core/framework/request/request_state.h b/xllm/core/framework/request/request_state.h index 5ff04322..2dc52032 100644 --- a/xllm/core/framework/request/request_state.h +++ b/xllm/core/framework/request/request_state.h @@ -137,6 +137,12 @@ struct RequestState final { bool enable_schedule_overlap = false; + // rec model specific flag + bool is_rec_model = false; + + // The bos token id of the model. 
+ int32_t bos_token_id = 0; + // The thread id of the thread pool in the response handler to ensure that // stream responses for the same request are executed sequentially during // multi-threaded stream processing. diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp index 9a7d2ce9..a26e2d97 100644 --- a/xllm/core/framework/request/sequence.cpp +++ b/xllm/core/framework/request/sequence.cpp @@ -34,6 +34,15 @@ limitations under the License. namespace xllm { +// Number of decoder BOS tokens to add for rec models +static constexpr size_t kDecoderBosTokenCount = 1; +static constexpr size_t kDecoderMaxTokenCount = 4; + +// rec model specific: static constants for embedding names +const std::string Sequence::ENCODER_SPARSE_EMBEDDING_NAME = "sparse_embedding"; +const std::string Sequence::DECODER_CONTEXT_EMBEDDING_NAME = + "decoder_context_embedding"; + Sequence::Sequence(size_t index, const std::vector& prompt_token_ids, torch::Tensor input_embedding, @@ -44,25 +53,72 @@ Sequence::Sequence(size_t index, mm_data_(mm_data), latest_generate_time_(absl::Now()), sequence_params_(seq_params), - decoder_(std::move(decoder)) { - CHECK(!prompt_token_ids.empty()) << "empty prompt token ids"; - auto capacity = sequence_params_.seq_capacity; - CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small"; + decoder_(std::move(decoder)), + is_rec_model_(seq_params.is_rec_model) { + // rec model specific: handle encoder tokens and decoder embeddings + if (is_rec_model_) { + // For rec model, treat prompt_token_ids as encoder_tokens + if (prompt_token_ids.size() > 0) { + encoder_tokens_.resize(prompt_token_ids.size()); + for (size_t i = 0; i < prompt_token_ids.size(); ++i) { + encoder_tokens_[i] = prompt_token_ids[i]; + } + num_encoder_tokens_ = prompt_token_ids.size(); + } else { + // If no prompt tokens, check for encoder sparse embedding in mm_data + auto encoder_sparse_embedding = + mm_data_.get(ENCODER_SPARSE_EMBEDDING_NAME); + CHECK(encoder_sparse_embedding.has_value()) + << "encoder sparse embedding not found in mm_data"; + num_encoder_tokens_ = encoder_sparse_embedding.value().size(0); + } - num_prompt_tokens_ = prompt_token_ids.size(); - volatile_num_prompt_tokens_ = num_prompt_tokens_; - tokens_.resize(capacity); + // Check if decoder context embedding exists in mm_data + auto decoder_context_embedding = + mm_data_.get(DECODER_CONTEXT_EMBEDDING_NAME); + auto capacity = kDecoderMaxTokenCount; + if (decoder_context_embedding.has_value()) { + // Use context embedding replacing bos + prompt + num_prompt_tokens_ = 0; + num_decoder_embeddings_ = decoder_context_embedding.value().size(0); + capacity = num_decoder_embeddings_ + capacity - kDecoderBosTokenCount; + } else { + // Only BOS token for decoder + num_prompt_tokens_ = kDecoderBosTokenCount; // kDecoderBosTokenCount + } + tokens_.resize(capacity); + for (size_t i = 0; i < num_prompt_tokens_; ++i) { + tokens_[num_tokens_++] = sequence_params_.bos_token_id; + token_to_count_map_[sequence_params_.bos_token_id]++; + } - // init logprob state - logprob_state_ = std::make_unique(num_prompt_tokens_, capacity); + volatile_num_prompt_tokens_ = num_prompt_tokens_; + input_embedding_ = input_embedding; + cur_generated_token_idx_ = num_prompt_tokens_; + // init logprob state + logprob_state_ = + std::make_unique(num_prompt_tokens_, capacity); + // rec only use rank 0 now. 
+ set_dp_rank(0); + } else { + CHECK(!prompt_token_ids.empty()) << "empty prompt token ids"; + auto capacity = sequence_params_.seq_capacity; + CHECK_GT(capacity, prompt_token_ids.size()) << "capacity too small"; + num_prompt_tokens_ = prompt_token_ids.size(); + tokens_.resize(capacity); + // add the prompt tokens + for (const auto token_id : prompt_token_ids) { + tokens_[num_tokens_++] = token_id; + token_to_count_map_[token_id]++; + } - // add the prompt tokens - for (const auto token_id : prompt_token_ids) { - tokens_[num_tokens_++] = token_id; - token_to_count_map_[token_id]++; + volatile_num_prompt_tokens_ = num_prompt_tokens_; + input_embedding_ = input_embedding; + cur_generated_token_idx_ = num_prompt_tokens_; + // init logprob state + logprob_state_ = + std::make_unique(num_prompt_tokens_, capacity); } - input_embedding_ = input_embedding; - cur_generated_token_idx_ = num_prompt_tokens_; } Sequence::Sequence(const Sequence& other) @@ -84,6 +140,10 @@ Sequence::Sequence(const Sequence& other) num_tokens_(other.num_tokens_), token_to_count_map_(other.token_to_count_map_), num_prompt_tokens_(other.num_prompt_tokens_), + num_encoder_tokens_(other.num_encoder_tokens_), + num_decoder_embeddings_(other.num_decoder_embeddings_), + encoder_tokens_(other.encoder_tokens_), + is_rec_model_(other.is_rec_model_), volatile_num_prompt_tokens_(other.volatile_num_prompt_tokens_), embedding_id_(other.embedding_id_), finished_(other.finished_), @@ -101,8 +161,10 @@ void Sequence::append_token(const Token& token) { CHECK_LT(num_tokens_, tokens_.size()) << "exceed the token capacity of the sequence"; CHECK(!finished_) << "cannot append token to a finished sequence"; - CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_prefill_stage()) - << "cannot append token to a prefill sequence"; + if (!is_rec_model()) { + CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_prefill_stage()) + << "cannot append token to a prefill sequence"; + } if (!sequence_params_.enable_schedule_overlap) { // check if the token is the first token after the prompt @@ -249,6 +311,15 @@ std::optional Sequence::generate_streaming_output( AUTO_COUNTER(detokenization_latency_seconds_stream); const auto ids = Slice(tokens_, size); + // For rec model, return token_ids directly without decode + if (is_rec_model()) { + const size_t start = num_prompt_tokens_; + SequenceOutput output; + output.index = index_; + output.token_ids = ids.slice(start, size); + return output; + } + // record the start index of token ids const size_t start = decoder_.output_offset(); auto delta = decoder_.decode(ids, tokenizer); @@ -332,6 +403,21 @@ SequenceOutput Sequence::generate_output(const Tokenizer& tokenizer) { } } + SequenceOutput output; + output.index = index_; + + if (output_embedding_.defined()) { + output.embedding = output_embedding_; + } + + if (finish_reason_ != FinishReason::NONE) { + output.finish_reason = finish_reason_.to_string(); + } + + if (is_rec_model()) { + output.token_ids = ids.slice(num_prompt_tokens_, size); + return output; + } // record the start index of token ids const size_t start = decoder_.output_offset(); @@ -348,16 +434,7 @@ SequenceOutput Sequence::generate_output(const Tokenizer& tokenizer) { ss << decoder_.decode(ids.slice(0, end), tokenizer); } - SequenceOutput output; - output.index = index_; output.text = ss.str(); - if (output_embedding_.defined()) { - output.embedding = output_embedding_; - } - - if (finish_reason_ != FinishReason::NONE) { - output.finish_reason = finish_reason_.to_string(); - } const size_t end = 
decoder_.output_offset(); output.token_ids = ids.slice(start, end); @@ -399,6 +476,10 @@ bool Sequence::finished() const { return finished_; } + if (is_rec_model() && num_tokens_ == num_prompt_tokens_) { + return false; + } + // Embedding sequence never be finished until it updates its embeddings if (finish_status_invalidated_ && sequence_params_.sampling_param->is_embeddings) { @@ -456,4 +537,12 @@ Slice Sequence::get_generated_tokens() const { return {tokens_.data(), 0}; } +void Sequence::finish() { + finished_ = true; + finish_status_invalidated_ = false; + if (finish_reason_ == FinishReason::NONE) { + finish_reason_ = FinishReason::STOP; + } +} + } // namespace xllm diff --git a/xllm/core/framework/request/sequence.h b/xllm/core/framework/request/sequence.h index 846c037b..753f8358 100644 --- a/xllm/core/framework/request/sequence.h +++ b/xllm/core/framework/request/sequence.h @@ -65,6 +65,9 @@ struct SequenceParams { // enable_schedule_overlap or not. default = false. bool enable_schedule_overlap = false; + // whether this is a rec model. default = false. + bool is_rec_model = false; + int32_t bos_token_id = 0; // sampling params // reference from request RequestSamplingParam* sampling_param; // not owned @@ -192,6 +195,9 @@ class Sequence final { void close() { closed_ = true; } bool is_closed() const { return closed_; } + // finish the sequence by setting finished status and reason + void finish(); + // time between two tokens int64_t tbt(const absl::Time& now); // set sequence ttft @@ -265,6 +271,22 @@ class Sequence final { // get sequence id int32_t seq_id() const { return seq_id_; } + // rec model specific: get encoder tokens + const std::vector& encoder_tokens() const { return encoder_tokens_; } + + // rec model specific: get encoder sequence length + size_t encoder_seq_len() const { return num_encoder_tokens_; } + + // rec model specific: get number of decoder embeddings + size_t num_decoder_embeddings() const { return num_decoder_embeddings_; } + + // rec model specific: check if this is a rec model + bool is_rec_model() const { return is_rec_model_; } + + // rec model specific: static constants for embedding names + static const std::string ENCODER_SPARSE_EMBEDDING_NAME; + static const std::string DECODER_CONTEXT_EMBEDDING_NAME; + private: // the index of the sequence in the request size_t index_ = 0; @@ -312,6 +334,18 @@ class Sequence final { // the length of the prompt tokens size_t num_prompt_tokens_ = 0; + // rec model specific: number of encoder tokens + size_t num_encoder_tokens_ = 0; + + // rec model specific: number of decoder embeddings + size_t num_decoder_embeddings_ = 0; + + // rec model specific: encoder tokens storage + std::vector encoder_tokens_; + + // rec model specific: whether this is a rec model + bool is_rec_model_ = false; + // NOTE: MUST FIXME Later // record all tokens num in last turn when the request is // interrupted due to the lack of kv cache capacity. 
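Taken together, the sequence.cpp/sequence.h changes above mean that a rec sequence returns its generated semantic-ID tokens as raw ids and never runs the incremental detokenizer. Below is a minimal consumer-side sketch of that contract; `consume_semantic_id` is a hypothetical placeholder for the caller's item lookup, and `sequence`/`tokenizer` are assumed to already be in scope:

    // For rec models, generate_output() skips decoding and fills token_ids
    // with the ids produced after the BOS/prompt part of the sequence.
    SequenceOutput out = sequence.generate_output(tokenizer);
    if (sequence.is_rec_model()) {
      // out.text is not set; out.token_ids is the generated semantic-id path.
      for (const auto id : out.token_ids) {
        consume_semantic_id(id);  // hypothetical: map the id back to an item
      }
    }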
diff --git a/xllm/core/framework/request/sequences_group.cpp b/xllm/core/framework/request/sequences_group.cpp index 7bbce9af..d7c232f5 100644 --- a/xllm/core/framework/request/sequences_group.cpp +++ b/xllm/core/framework/request/sequences_group.cpp @@ -174,7 +174,7 @@ void SequencesGroup::process_beam_search() { if (!check_beam_search()) { return; } - + Timer timer; size_t beam_width = sequence_params_.sampling_param->beam_width; size_t seq_size = sequences_.size(); size_t topk = sequence_params_.sampling_param->top_logprobs; @@ -290,6 +290,14 @@ void SequencesGroup::process_beam_search() { CHECK_EQ(sequences_.size(), beam_width); update_for_sequence(0, beam_width); + HISTOGRAM_OBSERVE(expand_beam_latency_microseconds, + timer.elapsed_microseconds()); +} + +void SequencesGroup::finish() { + for (auto& sequence : sequences_) { + sequence->finish(); + } } } // namespace xllm diff --git a/xllm/core/framework/request/sequences_group.h b/xllm/core/framework/request/sequences_group.h index 1ed5ceca..5d0c9174 100644 --- a/xllm/core/framework/request/sequences_group.h +++ b/xllm/core/framework/request/sequences_group.h @@ -24,6 +24,7 @@ limitations under the License. #include #include "common.pb.h" +#include "common/metrics.h" #include "core/framework/sampling/sampling_params.h" #include "mm_data.h" #include "sequence.h" @@ -55,11 +56,17 @@ class SequencesGroup { } std::vector>& sequences() { return sequences_; } + const std::vector>& sequences() const { + return sequences_; + } int32_t dp_rank() { return sequences_[0]->dp_rank(); } bool is_prefill_stage() const { return sequences_[0]->is_prefill_stage(); } + // finish all sequences in the group + void finish(); + private: void add(); diff --git a/xllm/core/framework/sampling/CMakeLists.txt b/xllm/core/framework/sampling/CMakeLists.txt index 764157d0..6d1a07fc 100644 --- a/xllm/core/framework/sampling/CMakeLists.txt +++ b/xllm/core/framework/sampling/CMakeLists.txt @@ -10,17 +10,20 @@ cc_library( rejection_sampler.h sampler.h beam_searcher.h + valid_path_filter.h SRCS sampling_params.cpp logits_utils.cpp rejection_sampler.cpp sampler.cpp beam_searcher.cpp + valid_path_filter.cpp DEPS :common glog::glog torch :kernels + :util $<$:xllm_ops> ) @@ -31,12 +34,14 @@ cc_test( rejection_sampler_test.cpp rejection_sampler.cpp sampling_params_test.cpp + valid_path_filter_test.cpp DEPS absl::strings GTest::gtest_main :flags :sampler glog::glog + torch ) target_link_libraries(sampler_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto leveldb::leveldb ZLIB::ZLIB protobuf::libprotobuf) target_link_libraries(sampler_test diff --git a/xllm/core/framework/sampling/sampling_params.cpp b/xllm/core/framework/sampling/sampling_params.cpp index 4ed8cb77..1165bf5c 100644 --- a/xllm/core/framework/sampling/sampling_params.cpp +++ b/xllm/core/framework/sampling/sampling_params.cpp @@ -24,6 +24,8 @@ limitations under the License. 
#include #include +#include "common/global_flags.h" + namespace xllm { void SamplingParameters::init( @@ -35,9 +37,11 @@ void SamplingParameters::init( const std::vector& unique_token_lens_vec) { CHECK_EQ(req_sampling_params.size(), selected_token_idxes.size()); CHECK_GE(req_sampling_params.size(), sample_idxes.size()); - CHECK_EQ(req_sampling_params.size(), unique_token_ids_vec.size()); - CHECK_EQ(req_sampling_params.size(), unique_token_counts_vec.size()); - CHECK_EQ(req_sampling_params.size(), unique_token_lens_vec.size()); + if (FLAGS_backend != "rec") { + CHECK_EQ(req_sampling_params.size(), unique_token_ids_vec.size()); + CHECK_EQ(req_sampling_params.size(), unique_token_counts_vec.size()); + CHECK_EQ(req_sampling_params.size(), unique_token_lens_vec.size()); + } std::vector frequency_penalties; std::vector presence_penalties; diff --git a/xllm/core/framework/sampling/valid_path_filter.cpp b/xllm/core/framework/sampling/valid_path_filter.cpp new file mode 100644 index 00000000..7784ea0e --- /dev/null +++ b/xllm/core/framework/sampling/valid_path_filter.cpp @@ -0,0 +1,269 @@ +#include "valid_path_filter.h" + +#include + +#include +#include +#include +#include +#include + +#include "util/env_var.h" +#include "util/hash_util.h" +#include "util/slice.h" +#include "util/tensor_helper.h" +#include "util/timer.h" + +namespace xllm { + +namespace { + +void parse_valid_path_filter_file( + std::vector>& tokens_list, + const std::string& valid_path_filter_file) { + if (valid_path_filter_file.empty()) { + LOG(WARNING) << "Get empty vaild path filter file: " + << valid_path_filter_file; + return; + } + if (!std::filesystem::exists(valid_path_filter_file)) { + LOG(ERROR) << "Failed to find vaild path filter file: " + << valid_path_filter_file; + return; + } + std::ifstream ifs(valid_path_filter_file, std::ios::binary | std::ios::ate); + if (!ifs.is_open()) { + LOG(ERROR) << "Failed to load vaild path filter file: " + << valid_path_filter_file; + return; + } + + const size_t file_size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + + const int elements_per_line = 3; + const size_t elements_size = elements_per_line * sizeof(int32_t); + const size_t line_size = sizeof(int64_t) + elements_size; + const size_t estimated_lines = (file_size + line_size - 1) / line_size; + + tokens_list.reserve(estimated_lines); + + int64_t item_id; + std::vector buffer(elements_per_line); + while (ifs.read(reinterpret_cast(&item_id), sizeof(int64_t)) && + ifs.read(reinterpret_cast(buffer.data()), elements_size)) { + tokens_list.emplace_back(buffer.begin(), buffer.end()); + } + LOG(INFO) << "ValidPathFilter parse tokens list size:" << tokens_list.size(); + + if (ifs.gcount() != 0 && ifs.gcount() != line_size) { + LOG(ERROR) << "Possibly containing incomplete lines : " + << valid_path_filter_file; + return; + } +} +} // namespace + +float ValidPathFilter::pre_mask_factor_ = -10000.0f; + +ValidPathFilter::ValidPathFilter(const std::string valid_path_filter_file, + const int32_t vocab_size, + torch::ScalarType dtype, + torch::Device device) + : vocab_size_(vocab_size), dtype_(dtype), device_(device) { + std::vector> tokens_list; + Timer timer; + parse_valid_path_filter_file(tokens_list, valid_path_filter_file); + init_cached_mask(tokens_list, vocab_size); + LOG(INFO) << " ValidPathFilter generate " << cached_sparse_mask_.size() + << " key for " << tokens_list.size() << " items which took " + << timer.elapsed_seconds() << " secs."; +} + +ValidPathFilter::ValidPathFilter( + const std::vector>& tokens_list, + const int32_t 
vocab_size, + torch::ScalarType dtype, + torch::Device device) + : vocab_size_(vocab_size), dtype_(dtype), device_(device) { + init_cached_mask(tokens_list, vocab_size); +} + +void ValidPathFilter::init_cached_mask( + const std::vector>& tokens_list, + const int32_t vocab_size) { + size_t total_num = tokens_list.size(); + if (total_num > 0) { + init_cached_tokens_ = true; + } + + // init extra thread pool + thread_num_ = util::get_int_env(util::EXTRA_THREAD_NUM, 16); + extra_threadpool_ = std::make_unique(thread_num_); + + // generate mask + torch::TensorOptions options = torch::dtype(dtype_).device(device_); + first_token_mask_ = torch::full({vocab_size}, pre_mask_factor_, dtype_); + empty_place_holder_ = torch::full({vocab_size}, 0.0f, options); + + cached_sparse_mask_.reserve(total_num); + for (size_t t_idx = 0; t_idx < total_num; t_idx++) { + Slice tokens_slice(tokens_list[t_idx]); + CHECK_EQ(tokens_slice.size(), 3); + + // handle first token + first_token_mask_[tokens_slice[0]] = 0; + + // handle extra token + for (int i = 1; i < tokens_slice.size(); i++) { + Murmur3Key murmur3_key; + Slice sub_slice(tokens_slice.data(), i); + murmur_hash3(nullptr, sub_slice, murmur3_key.data); + auto iter = cached_sparse_mask_.find(murmur3_key); + if (iter != cached_sparse_mask_.end()) { + iter->second.push_back(tokens_slice[i]); + } else { + std::vector false_indices = {tokens_slice[i]}; + cached_sparse_mask_.emplace(std::make_pair(murmur3_key, false_indices)); + } + } + } + + // Remove duplicates and sort for better performance + // Sort false indices in sparse masks for better performance + for (auto& pair : cached_sparse_mask_) { + std::sort(pair.second.begin(), pair.second.end()); + pair.second.erase(std::unique(pair.second.begin(), pair.second.end()), + pair.second.end()); + } + // first_token_mask_ = safe_to(first_token_mask_, device_, true); + LOG(INFO) << " ValidPathFilter third sparse storage: " + << cached_sparse_mask_.size(); +} + +torch::Tensor ValidPathFilter::forward( + const std::vector>& tokens_list) { + if (!init_cached_tokens_ || tokens_list.size() == 0) { + return torch::Tensor(); + } + + size_t token_size = tokens_list[0].size(); + + // prepare mask for first token + if (token_size == 0) { + size_t total_nums = tokens_list.size(); + auto mask = first_token_mask_.unsqueeze(0); + return mask.repeat({total_nums, 1}); + } + return forward_sparse_mask(tokens_list); +} + +torch::Tensor ValidPathFilter::forward_sparse_mask( + const std::vector>& tokens_list) { + Timer timer; + size_t total_nums = tokens_list.size(); + torch::TensorOptions options = torch::dtype(dtype_).device(device_); + auto mask = torch::full({total_nums, vocab_size_}, pre_mask_factor_, options); + + // Global batch collection for sparse storage optimization + std::vector global_batch_token_indices; + std::vector global_batch_vocab_indices; + std::mutex batch_mutex; // Protect global batch vectors in multi-threading + + // Pre-allocate space: assume max 8192 false indices per token + global_batch_token_indices.reserve(8192 * total_nums); + global_batch_vocab_indices.reserve(8192 * total_nums); + + auto update_mask = [&](size_t start_idx, size_t end_idx) { + // Local collection for this thread + std::vector local_token_indices; + std::vector local_vocab_indices; + local_token_indices.reserve(8192 * (end_idx - start_idx)); + local_vocab_indices.reserve(8192 * (end_idx - start_idx)); + + for (size_t token_idx = start_idx; token_idx < end_idx; ++token_idx) { + auto& tokens = tokens_list[token_idx]; + if (tokens.size() == 
0) { + mask[token_idx] = first_token_mask_.to(device_); + } else { + Slice tokens_slice(tokens); + Murmur3Key murmur3_key; + murmur_hash3(nullptr, tokens_slice, murmur3_key.data); + + auto iter = cached_sparse_mask_.find(murmur3_key); + if (iter != cached_sparse_mask_.end()) { + // Collect indices locally first + for (int32_t vocab_idx : iter->second) { + local_token_indices.push_back(static_cast(token_idx)); + local_vocab_indices.push_back(static_cast(vocab_idx)); + } + } else { + mask[token_idx] = empty_place_holder_; + LOG(ERROR) << "Failed to generate mask for " << tokens; + } + } + } + + // Merge local results to global batch (thread-safe) + if (!local_token_indices.empty()) { + std::lock_guard lock(batch_mutex); + global_batch_token_indices.insert(global_batch_token_indices.end(), + local_token_indices.begin(), + local_token_indices.end()); + global_batch_vocab_indices.insert(global_batch_vocab_indices.end(), + local_vocab_indices.begin(), + local_vocab_indices.end()); + } + }; + + if (use_threadpool_for_beam_expansion_) { + // 分段处理优化:每个线程处理多个mask + const size_t batch_size = + std::max(1UL, (total_nums + thread_num_ - 1) / thread_num_); + const size_t num_batches = (total_nums + batch_size - 1) / batch_size; + + std::vector>> promises; + std::vector> futures; + promises.reserve(num_batches); + futures.reserve(num_batches); + + for (size_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) { + auto promise = std::make_shared>(); + futures.push_back(promise->get_future()); + promises.push_back(promise); + + size_t start_idx = batch_idx * batch_size; + size_t end_idx = std::min(start_idx + batch_size, total_nums); + + extra_threadpool_->schedule( + [update_mask, start_idx, end_idx, promise]() mutable { + update_mask(start_idx, end_idx); + promise->set_value(); + }); + } + + for (auto& future : futures) { + future.get(); + } + } else { + update_mask(0, total_nums); + } + + // Global batch tensor operation after all threads complete + if (!global_batch_token_indices.empty()) { + auto token_indices = + torch::tensor(global_batch_token_indices, torch::kInt64); + auto vocab_indices = + torch::tensor(global_batch_vocab_indices, torch::kInt64); + torch::TensorOptions device_options = + torch::dtype(torch::kInt64).device(device_); + token_indices = safe_to(token_indices, device_options, true); + vocab_indices = safe_to(vocab_indices, device_options, true); + mask.index_put_({token_indices, vocab_indices}, 0.0f); + // auto indices = torch::stack({token_indices, vocab_indices}, 1); + // return indices; + } + + return mask; +} +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/sampling/valid_path_filter.h b/xllm/core/framework/sampling/valid_path_filter.h new file mode 100644 index 00000000..c3eda7e1 --- /dev/null +++ b/xllm/core/framework/sampling/valid_path_filter.h @@ -0,0 +1,65 @@ +#pragma once +#include +#include +#include + +#include "util/hash_util.h" +#include "util/threadpool.h" + +namespace xllm { + +class ValidPathFilter final { + public: + ValidPathFilter(const std::string valid_path_filter_file, + const int32_t vocab_size, + torch::ScalarType dtype, + torch::Device device); + ValidPathFilter(const std::vector>& tokens_list, + const int32_t vocab_size, + torch::ScalarType dtype, + torch::Device device); + + // operator() allows us to use the module as a function. + template + auto operator()(Args&&... 
args) const {
+    return this->forward(::std::forward<Args>(args)...);
+  }
+
+  // output: [num_tokens, vocab_size]
+  torch::Tensor forward(const std::vector<std::vector<int32_t>>& tokens_list);
+
+ private:
+  void init_cached_mask(const std::vector<std::vector<int32_t>>& tokens_list,
+                        const int32_t vocab_size);
+
+  // prepare mask using cached sparse mask
+  torch::Tensor forward_sparse_mask(
+      const std::vector<std::vector<int32_t>>& tokens_list);
+
+  // Sparse storage: map from key to indices of candidate tokens.
+  std::unordered_map<Murmur3Key,
+                     std::vector<int32_t>,
+                     FixedStringKeyHash,
+                     FixedStringKeyEqual>
+      cached_sparse_mask_;
+
+  torch::Tensor empty_place_holder_;
+  torch::Tensor first_token_mask_;
+
+  bool init_cached_tokens_ = false;
+
+  static float pre_mask_factor_;
+
+  int32_t vocab_size_;
+
+  torch::ScalarType dtype_ = torch::ScalarType::Undefined;
+
+  torch::Device device_;
+
+  int32_t thread_num_;
+  std::unique_ptr<ThreadPool> extra_threadpool_;
+  // Whether to use the thread pool for beam expansion
+  bool use_threadpool_for_beam_expansion_ = true;
+};
+
+}  // namespace xllm
diff --git a/xllm/core/framework/sampling/valid_path_filter_test.cpp b/xllm/core/framework/sampling/valid_path_filter_test.cpp
new file mode 100644
index 00000000..cb2e4f2a
--- /dev/null
+++ b/xllm/core/framework/sampling/valid_path_filter_test.cpp
@@ -0,0 +1,167 @@
+#include "valid_path_filter.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace xllm {
+
+TEST(ValidPathFilterTest, Vector) {
+  // Test data modeled on the real usage scenario.
+  // tokens_list holds the valid token-sequence paths; each path has length 3.
+  std::vector<std::vector<int32_t>> tokens_list = {
+      {1, 2, 3},  // path 1: 1->2->3
+      {1, 2, 4},  // path 2: 1->2->4
+      {1, 3, 5},  // path 3: 1->3->5
+      {2, 4, 6},  // path 4: 2->4->6
+      {3, 5, 7}   // path 5: 3->5->7
+  };
+
+  torch::ScalarType dtype(torch::kFloat32);
+  torch::Device device(torch::kCPU);
+  int32_t vocab_size = 8;  // vocabulary size is 8 (tokens 0-7)
+
+  ValidPathFilter filter =
+      ValidPathFilter(tokens_list, vocab_size, dtype, device);
+
+  // Candidate token prefixes to test
+  std::vector<std::vector<int32_t>> candidate_tokens = {
+      {1, 2},  // prefix [1,2]: tokens 3 and 4 should be allowed
+      {1},     // prefix [1]: tokens 2 and 3 should be allowed
+      {},      // empty prefix: first tokens 1, 2, 3 should be allowed
+      {2, 4},  // prefix [2,4]: token 6 should be allowed
+      {9, 9}   // invalid prefix: everything should be masked
+  };
+
+  const auto options = torch::dtype(dtype).device(device);
+  torch::Tensor mask = filter.forward(candidate_tokens);
+
+  // Verify the output shape
+  EXPECT_EQ(mask.sizes(),
+            torch::IntArrayRef({candidate_tokens.size(), vocab_size}));
+
+  // Verify the mask values:
+  // 0 means the token is allowed, -10000 means it is forbidden.
+
+  // For prefix [1,2]: the next token can be 3 or 4
+  auto mask_1_2 = mask[0];
+  EXPECT_EQ(mask_1_2[3].item<float>(), 0.0f);       // token 3 allowed
+  EXPECT_EQ(mask_1_2[4].item<float>(), 0.0f);       // token 4 allowed
+  EXPECT_EQ(mask_1_2[0].item<float>(), -10000.0f);  // token 0 forbidden
+  EXPECT_EQ(mask_1_2[1].item<float>(), -10000.0f);  // token 1 forbidden
+  EXPECT_EQ(mask_1_2[2].item<float>(), -10000.0f);  // token 2 forbidden
+
+  // For prefix [1]: the next token can be 2 or 3
+  auto mask_1 = mask[1];
+  EXPECT_EQ(mask_1[2].item<float>(), 0.0f);       // token 2 allowed
+  EXPECT_EQ(mask_1[3].item<float>(), 0.0f);       // token 3 allowed
+  EXPECT_EQ(mask_1[0].item<float>(), -10000.0f);  // token 0 forbidden
+  EXPECT_EQ(mask_1[1].item<float>(), -10000.0f);  // token 1 forbidden
+
+  // For the empty prefix []: the first token can be 1, 2 or 3
+  auto mask_empty = mask[2];
+  EXPECT_EQ(mask_empty[1].item<float>(), 0.0f);       // token 1 allowed
+  EXPECT_EQ(mask_empty[2].item<float>(), 0.0f);       // token 2 allowed
+  EXPECT_EQ(mask_empty[3].item<float>(), 0.0f);       // token 3 allowed
+  EXPECT_EQ(mask_empty[0].item<float>(), -10000.0f);  // token 0 forbidden
+}
+
+TEST(ValidPathFilterTest, File) {
+  // Create the test data file
+  std::vector<std::vector<int32_t>> tokens_list = {
+      {1, 2, 3}, {1, 2, 4}, {1, 3, 5}, {2, 4, 6}, {3, 5, 7}};
+
+  const std::string rec_tokens_file = "./test_data.bin";
+
+  // Remove any stale file from a previous run
+  if (std::ifstream(rec_tokens_file)) {
+    std::remove(rec_tokens_file.c_str());
+  }
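+
+  // Note: parse_valid_path_filter_file() keeps only the three tokens of each
+  // record; the leading int64_t item_id is read and then discarded, so the
+  // ids written below do not affect the resulting mask.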
+
+  // Write the file in the format the implementation expects:
+  // an int64_t item_id followed by three int32_t tokens per record.
+  std::ofstream outfile(rec_tokens_file, std::ios::binary);
+  if (!outfile) {
+    LOG(ERROR) << "Failed to create test file: " << rec_tokens_file;
+    return;
+  }
+
+  int64_t item_id = 0;
+  for (const auto& row : tokens_list) {
+    outfile.write(reinterpret_cast<const char*>(&item_id), sizeof(int64_t));
+    outfile.write(reinterpret_cast<const char*>(row.data()),
+                  row.size() * sizeof(int32_t));
+    item_id++;
+  }
+  outfile.close();
+
+  torch::ScalarType dtype(torch::kFloat32);
+  torch::Device device(torch::kCPU);
+  int32_t vocab_size = 8;
+
+  ValidPathFilter filter =
+      ValidPathFilter(rec_tokens_file, vocab_size, dtype, device);
+
+  // Use the same test cases as the Vector test
+  std::vector<std::vector<int32_t>> candidate_tokens = {
+      {1, 2},  // prefix [1,2]
+      {1},     // prefix [1]
+      {}       // empty prefix
+  };
+
+  const auto options = torch::dtype(dtype).device(device);
+  torch::Tensor mask = filter.forward(candidate_tokens);
+
+  // Verify the output shape
+  EXPECT_EQ(mask.sizes(),
+            torch::IntArrayRef({candidate_tokens.size(), vocab_size}));
+
+  // Verify the same results as in the Vector test
+  // For prefix [1,2]: the next token can be 3 or 4
+  auto mask_1_2 = mask[0];
+  EXPECT_EQ(mask_1_2[3].item<float>(), 0.0f);
+  EXPECT_EQ(mask_1_2[4].item<float>(), 0.0f);
+
+  // For prefix [1]: the next token can be 2 or 3
+  auto mask_1 = mask[1];
+  EXPECT_EQ(mask_1[2].item<float>(), 0.0f);
+  EXPECT_EQ(mask_1[3].item<float>(), 0.0f);
+
+  // For the empty prefix []: the first token can be 1, 2 or 3
+  auto mask_empty = mask[2];
+  EXPECT_EQ(mask_empty[1].item<float>(), 0.0f);
+  EXPECT_EQ(mask_empty[2].item<float>(), 0.0f);
+  EXPECT_EQ(mask_empty[3].item<float>(), 0.0f);
+
+  // Clean up the test file
+  if (std::ifstream(rec_tokens_file)) {
+    std::remove(rec_tokens_file.c_str());
+  }
+}
+
+TEST(ValidPathFilterTest, EmptyInput) {
+  // Test the empty-input case
+  std::vector<std::vector<int32_t>> tokens_list = {{1, 2, 3}};
+  torch::ScalarType dtype(torch::kFloat32);
+  torch::Device device(torch::kCPU);
+  int32_t vocab_size = 5;
+
+  ValidPathFilter filter =
+      ValidPathFilter(tokens_list, vocab_size, dtype, device);
+
+  // An empty list of candidate token prefixes
+  std::vector<std::vector<int32_t>> empty_candidates = {};
+  torch::Tensor mask = filter.forward(empty_candidates);
+
+  // Should return an undefined tensor
+  EXPECT_FALSE(mask.defined());
+}
+
+}  // namespace xllm
diff --git a/xllm/core/layers/npu/CMakeLists.txt b/xllm/core/layers/npu/CMakeLists.txt
index 61f7759d..93d48935 100644
--- a/xllm/core/layers/npu/CMakeLists.txt
+++ b/xllm/core/layers/npu/CMakeLists.txt
@@ -25,6 +25,7 @@ cc_library(
   npu_qwen3_decoder_layer_impl.h
   npu_rms_norm_impl.h
   npu_siglip_encoder_layer_impl.h
+  npu_onerec_block_layer_impl.h
   SRCS
   npu_word_embedding_impl.cpp
   npu_pos_embedding_impl.cpp
@@ -45,6 +46,7 @@ cc_library(
   npu_qwen3_decoder_layer_impl.cpp
   npu_rms_norm_impl.cpp
   npu_siglip_encoder_layer_impl.cpp
+  npu_onerec_block_layer_impl.cpp
   DEPS
   "-Wl,--whole-archive"
   "-Wl,--no-whole-archive"
diff --git a/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp b/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp
new file mode 100644
index 00000000..01024706
--- /dev/null
+++ b/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp
@@ -0,0 +1,1880 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "npu_onerec_block_layer_impl.h" + +#include +#include + +#include +#include + +#include "common/global_flags.h" +#include "core/layers/attention_mask.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/core/npu/NPUException.h" + +namespace xllm { +namespace layer { +// Decoder normal mode: self-attn(29) + cross-attn(28) + layer-norm(4) + mlp(18) +// = 79 +const uint64_t ONEREC_WEIGHT_COUNT_PER_LAYER = 79; +// Decoder normal mode: self-attn(29) + cross-attn(28) + layer-norm(4) + mlp(18) +// + data(4) = 83 +const uint64_t ONEREC_MOE_WEIGHT_COUNT_PER_LAYER = + 97; // Decoder only with MoE (61-100: 40 weight tensors) +// + 2 + 2 + 2 +enum ONERECBlockLayerTensorId : int { + // Self-attention layer norm + IN_LAYER_NORM_WEIGHT = 0, + IN_LAYER_NORM_BIAS, + IN_INPUT_NORM_NEW_WEIGHT, + IN_INPUT_NORM_NEW_BIAS, + // Self-attention Q, K, V projections + IN_Q_WEIGHT, + IN_Q_BIAS, + IN_Q_DEQSCALE, + IN_Q_OFFSET, + IN_Q_SCALE, + IN_Q_COMPRESS_IDX, + + IN_K_WEIGHT, + IN_K_BIAS, + IN_K_DEQSCALE, + IN_K_OFFSET, + IN_K_SCALE, + IN_K_COMPRESS_IDX, + + IN_V_WEIGHT, + IN_V_BIAS, + IN_V_DEQSCALE, + IN_V_OFFSET, + IN_V_SCALE, + IN_V_COMPRESS_IDX, + + // Self-attention output projection + IN_SELF_ATTN_OUT_WEIGHT, + IN_SELF_ATTN_OUT_BIAS, + IN_SELF_ATTN_OUT_DEQSCALE, + IN_SELF_ATTN_OUT_OFFSET, + IN_SELF_ATTN_OUT_SCALE, + IN_SELF_ATTN_OUT_COMPRESS_IDX, + + // ONEREC relative attention bias (encoder only) + IN_RELATIVE_ATTENTION_BIAS_WEIGHT, + + // Cross-attention layer norm (decoder only) + IN_CROSS_LAYER_NORM_WEIGHT, + IN_CROSS_LAYER_NORM_BIAS, + IN_CROSS_LAYER_NORM_NEW_WEIGHT, + IN_CROSS_LAYER_NORM_NEW_BIAS, + + // Cross-attention Q, K, V projections (decoder only) + IN_CROSS_Q_WEIGHT, + IN_CROSS_Q_BIAS, + IN_CROSS_Q_DEQSCALE, + IN_CROSS_Q_OFFSET, + IN_CROSS_Q_SCALE, + IN_CROSS_Q_COMPRESS_IDX, + + IN_CROSS_K_WEIGHT, + IN_CROSS_K_BIAS, + IN_CROSS_K_DEQSCALE, + IN_CROSS_K_OFFSET, + IN_CROSS_K_SCALE, + IN_CROSS_K_COMPRESS_IDX, + + IN_CROSS_V_WEIGHT, + IN_CROSS_V_BIAS, + IN_CROSS_V_DEQSCALE, + IN_CROSS_V_OFFSET, + IN_CROSS_V_SCALE, + IN_CROSS_V_COMPRESS_IDX, + + // Cross-attention output projection (decoder only) + IN_CROSS_ATTN_OUT_WEIGHT, + IN_CROSS_ATTN_OUT_BIAS, + IN_CROSS_ATTN_OUT_DEQSCALE, + IN_CROSS_ATTN_OUT_OFFSET, + IN_CROSS_ATTN_OUT_SCALE, + IN_CROSS_ATTN_OUT_COMPRESS_IDX, + + // Final layer norm + IN_FINAL_LAYER_NORM_WEIGHT, + IN_FINAL_LAYER_NORM_BIAS, + IN_FINAL_LAYER_NORM_NEW_WEIGHT, + IN_FINAL_LAYER_NORM_NEW_BIAS, + + // Feed-forward network (gated activation) + IN_FFN_WI_0_WEIGHT = 61, // wi_0 (gate projection) + IN_FFN_WI_0_BIAS, + IN_FFN_WI_0_DEQSCALE, + IN_FFN_WI_0_OFFSET, + IN_FFN_WI_0_SCALE, + IN_FFN_WI_0_COMPRESS_IDX, + + IN_FFN_WI_1_WEIGHT, // wi_1 (up projection) + IN_FFN_WI_1_BIAS, + IN_FFN_WI_1_DEQSCALE, + IN_FFN_WI_1_OFFSET, + IN_FFN_WI_1_SCALE, + IN_FFN_WI_1_COMPRESS_IDX, + + IN_FFN_WO_WEIGHT, // wo (down projection) + IN_FFN_WO_BIAS, + IN_FFN_WO_DEQSCALE, + IN_FFN_WO_OFFSET, + IN_FFN_WO_SCALE, + IN_FFN_WO_COMPRESS_IDX, +}; + +enum ONERECMoeBlockLayerTensorId : int { + // MoE weights (only used when use_moe=true) - Updated to match kernel layer + // names + IN_BLOCK_SPARSE_MOE_GATE_WEIGHT = 61, // Gate/routing weights + IN_BLOCK_SPARSE_MOE_GATE_BIAS = 62, // Gate bias + IN_BLOCK_SPARSE_MOE_GATE_DESCALE, // Gate descale + IN_BLOCK_SPARSE_MOE_GATE_OFFSET, // Gate offset + IN_BLOCK_SPARSE_MOE_GATE_SCALE, // Gate scale + IN_BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX, // Gate 
compress index + + // Shared Expert weights + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, // Shared expert gateup weights (merged + // gate+up projection) + IN_MLP_GATEUP_BIAS_SHARED_EXPERT, // Shared expert gateup bias + IN_MLP_GATEUP_DESCALE_SHARED_EXPERT, // Shared expert gateup descale + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, // Shared expert gateup offset + IN_MLP_GATEUP_SCALE_SHARED_EXPERT, // Shared expert gateup scale + IN_MLP_GATEUP_COMPRESS_IDX_SHARED_EXPERT, // Shared expert gateup compress + // index + + IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, // Shared expert down projection weights + IN_MLP_DOWN_BIAS_SHARED_EXPERT, // Shared expert down bias + IN_MLP_DOWN_DESCALE_SHARED_EXPERT, // Shared expert down descale + IN_MLP_DOWN_OFFSET_SHARED_EXPERT, // Shared expert down offset + IN_MLP_DOWN_SCALE_SHARED_EXPERT, // Shared expert down scale + IN_MLP_DOWN_COMPRESS_IDX_SHARED_EXPERT, // Shared expert down compress index + + IN_SHARED_EXPERT_GATE_WEIGHT, // Shared expert gate weight + IN_SHARED_EXPERT_GATE_BIAS, // Shared expert gate bias + IN_SHARED_EXPERT_GATE_DESCALE, // Shared expert gate descale + IN_SHARED_EXPERT_GATE_OFFSET, // Shared expert gate offset + IN_SHARED_EXPERT_GATE_SCALE, // Shared expert gate scale + IN_SHARED_EXPERT_GATE_COMPRESS_IDX, // Shared expert gate compress index + + IN_MLP_GATEUP_WEIGHT_EXPERT, // Expert gateup weights (merged gate+up + // projection) + IN_MLP_GATEUP_BIAS_EXPERT, // Expert gateup bias + IN_MLP_GATEUP_DESCALE_EXPERT, // Expert gateup descale + IN_MLP_GATEUP_OFFSET_EXPERT, // Expert gateup offset + IN_MLP_GATEUP_SCALE_EXPERT, // Expert gateup scale + IN_MLP_GATEUP_COMPRESS_IDX_EXPERT, // Expert gateup compress index + + IN_MLP_DOWN_WEIGHT_EXPERT, // Expert down projection weights + IN_MLP_DOWN_BIAS_EXPERT, // Expert down bias + IN_MLP_DOWN_DESCALE_EXPERT, // Expert down descale + IN_MLP_DOWN_OFFSET_EXPERT, // Expert down offset + IN_MLP_DOWN_SCALE_EXPERT, // Expert down scale + IN_MLP_DOWN_COMPRESS_IDX_EXPERT = 96, // Expert down compress index + + IN_EXPERT_ARRAY = 97, // Expert array tensor + IN_EXPERT_GROUP = 98, // Expert group tensor + IN_ONE_HOT = 99, // One hot tensor + IN_ZERO_HOT = 100, // Zero hot tensor + + // Legacy aliases for backward compatibility + IN_MOE_EXPERT_W1_WEIGHT = IN_MLP_GATEUP_WEIGHT_EXPERT, + IN_MOE_EXPERT_W2_WEIGHT = IN_MLP_DOWN_WEIGHT_EXPERT, + IN_MOE_EXPERT_W3_WEIGHT = + IN_MLP_GATEUP_WEIGHT_EXPERT, // Same as W1 for gate+up merged + IN_MOE_SHARED_W1_WEIGHT = IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, + IN_MOE_SHARED_W2_WEIGHT = IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, +}; + +// ONEREC encoder weight mapping - Updated to match actual weight file format +static const std::unordered_map + ONEREC_ENCODER_WEIGHT_MAPPING = { + // Primary mappings - match actual weight file format with full paths + {"layer.0.layer_norm.weight", IN_LAYER_NORM_WEIGHT}, + {"layer.0.SelfAttention.q.weight", IN_Q_WEIGHT}, + {"layer.0.SelfAttention.k.weight", IN_K_WEIGHT}, + {"layer.0.SelfAttention.v.weight", IN_V_WEIGHT}, + {"layer.0.SelfAttention.o.weight", IN_SELF_ATTN_OUT_WEIGHT}, + {"layer.0.SelfAttention.relative_attention_bias.weight", + IN_RELATIVE_ATTENTION_BIAS_WEIGHT}, + {"layer.1.layer_norm.weight", IN_FINAL_LAYER_NORM_WEIGHT}, + {"layer.1.DenseReluDense.wi.weight", IN_FFN_WI_1_WEIGHT}, + {"layer.1.DenseReluDense.wo.weight", IN_FFN_WO_WEIGHT}, + {"layer.1.DenseReluDense.gate_proj.weight", IN_FFN_WI_0_WEIGHT}, + {"layer.1.ffn.wi.weight", IN_FFN_WI_1_WEIGHT}, + {"layer.1.ffn.wo.weight", IN_FFN_WO_WEIGHT}, + {"layer.1.ffn.gate_proj.weight", IN_FFN_WI_0_WEIGHT}, 
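+        // Keys are matched by suffix (absl::EndsWith) in load_state_dict, so
+        // both the fully qualified "layer.N..." forms above and the short
+        // "N..." forms below resolve to the same weight slot.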
+ // Alternative mappings for different weight file formats + {"0.layer_norm.weight", IN_LAYER_NORM_WEIGHT}, + {"0.SelfAttention.q.weight", IN_Q_WEIGHT}, + {"0.SelfAttention.k.weight", IN_K_WEIGHT}, + {"0.SelfAttention.v.weight", IN_V_WEIGHT}, + {"0.SelfAttention.o.weight", IN_SELF_ATTN_OUT_WEIGHT}, + {"0.SelfAttention.relative_attention_bias.weight", + IN_RELATIVE_ATTENTION_BIAS_WEIGHT}, + {"1.layer_norm.weight", IN_FINAL_LAYER_NORM_WEIGHT}, + {"1.DenseReluDense.wi.weight", IN_FFN_WI_1_WEIGHT}, + {"1.DenseReluDense.wo.weight", IN_FFN_WO_WEIGHT}, + {"1.DenseReluDense.gate_proj.weight", IN_FFN_WI_0_WEIGHT}, + {"1.ffn.wi.weight", IN_FFN_WI_1_WEIGHT}, + {"1.ffn.wo.weight", IN_FFN_WO_WEIGHT}, + {"1.ffn.gate_proj.weight", IN_FFN_WI_0_WEIGHT}, +}; + +// ONEREC decoder weight mapping - Updated to match actual weight file format +static const std::unordered_map + ONEREC_DECODER_WEIGHT_MAPPING = { + // Primary mappings - match actual weight file format with full paths + {"layer.0.layer_norm.weight", IN_LAYER_NORM_WEIGHT}, + {"layer.0.SelfAttention.q.weight", IN_Q_WEIGHT}, + {"layer.0.SelfAttention.k.weight", IN_K_WEIGHT}, + {"layer.0.SelfAttention.v.weight", IN_V_WEIGHT}, + {"layer.0.SelfAttention.o.weight", IN_SELF_ATTN_OUT_WEIGHT}, + {"layer.0.SelfAttention.relative_attention_bias.weight", + IN_RELATIVE_ATTENTION_BIAS_WEIGHT}, + {"layer.1.layer_norm.weight", IN_CROSS_LAYER_NORM_WEIGHT}, + {"layer.1.EncDecAttention.q.weight", IN_CROSS_Q_WEIGHT}, + {"layer.1.EncDecAttention.k.weight", IN_CROSS_K_WEIGHT}, + {"layer.1.EncDecAttention.v.weight", IN_CROSS_V_WEIGHT}, + {"layer.1.EncDecAttention.o.weight", IN_CROSS_ATTN_OUT_WEIGHT}, + {"layer.2.layer_norm.weight", IN_FINAL_LAYER_NORM_WEIGHT}, + {"layer.2.DenseReluDense.wi.weight", IN_FFN_WI_1_WEIGHT}, + {"layer.2.DenseReluDense.wo.weight", IN_FFN_WO_WEIGHT}, + {"layer.2.DenseReluDense.gate_proj.weight", IN_FFN_WI_0_WEIGHT}, + {"layer.2.ffn.wi.weight", IN_FFN_WI_1_WEIGHT}, + {"layer.2.ffn.wo.weight", IN_FFN_WO_WEIGHT}, + {"layer.2.ffn.gate_proj.weight", IN_FFN_WI_0_WEIGHT}, + // Alternative mappings for different weight file formats + {"0.layer_norm.weight", IN_LAYER_NORM_WEIGHT}, + {"0.SelfAttention.q.weight", IN_Q_WEIGHT}, + {"0.SelfAttention.k.weight", IN_K_WEIGHT}, + {"0.SelfAttention.v.weight", IN_V_WEIGHT}, + {"0.SelfAttention.o.weight", IN_SELF_ATTN_OUT_WEIGHT}, + {"0.SelfAttention.relative_attention_bias.weight", + IN_RELATIVE_ATTENTION_BIAS_WEIGHT}, + {"1.layer_norm.weight", IN_CROSS_LAYER_NORM_WEIGHT}, + {"1.EncDecAttention.q.weight", IN_CROSS_Q_WEIGHT}, + {"1.EncDecAttention.k.weight", IN_CROSS_K_WEIGHT}, + {"1.EncDecAttention.v.weight", IN_CROSS_V_WEIGHT}, + {"1.EncDecAttention.o.weight", IN_CROSS_ATTN_OUT_WEIGHT}, + {"2.layer_norm.weight", IN_FINAL_LAYER_NORM_WEIGHT}, + {"2.DenseReluDense.wi.weight", IN_FFN_WI_1_WEIGHT}, + {"2.DenseReluDense.wo.weight", IN_FFN_WO_WEIGHT}, + {"2.DenseReluDense.gate_proj.weight", IN_FFN_WI_0_WEIGHT}, + {"2.ffn.wi.weight", IN_FFN_WI_1_WEIGHT}, + {"2.ffn.wo.weight", IN_FFN_WO_WEIGHT}, + {"2.ffn.gate_proj.weight", IN_FFN_WI_0_WEIGHT}, +}; + +// ONEREC MoE weight mapping for decoder +// MoE weight mapping function - handles individual expert weights +static std::unordered_map +get_onerec_decoder_moe_weight_mapping() { + std::unordered_map mapping = { + // Self-attention layer norm + {"layer.0.layer_norm.weight", IN_LAYER_NORM_WEIGHT}, + {"layer.0.SelfAttention.q.weight", IN_Q_WEIGHT}, + {"layer.0.SelfAttention.k.weight", IN_K_WEIGHT}, + {"layer.0.SelfAttention.v.weight", IN_V_WEIGHT}, + 
{"layer.0.SelfAttention.o.weight", IN_SELF_ATTN_OUT_WEIGHT}, + {"layer.0.SelfAttention.relative_attention_bias.weight", + IN_RELATIVE_ATTENTION_BIAS_WEIGHT}, + // Cross-attention layer norm + {"layer.1.layer_norm.weight", IN_CROSS_LAYER_NORM_WEIGHT}, + {"layer.1.EncDecAttention.q.weight", IN_CROSS_Q_WEIGHT}, + {"layer.1.EncDecAttention.k.weight", IN_CROSS_K_WEIGHT}, + {"layer.1.EncDecAttention.v.weight", IN_CROSS_V_WEIGHT}, + {"layer.1.EncDecAttention.o.weight", IN_CROSS_ATTN_OUT_WEIGHT}, + {"layer.2.layer_norm.weight", IN_FINAL_LAYER_NORM_WEIGHT}, + + // Alternative naming patterns + {"0.layer_norm.weight", IN_LAYER_NORM_WEIGHT}, + {"0.SelfAttention.q.weight", IN_Q_WEIGHT}, + {"0.SelfAttention.k.weight", IN_K_WEIGHT}, + {"0.SelfAttention.v.weight", IN_V_WEIGHT}, + {"0.SelfAttention.o.weight", IN_SELF_ATTN_OUT_WEIGHT}, + {"0.SelfAttention.relative_attention_bias.weight", + IN_RELATIVE_ATTENTION_BIAS_WEIGHT}, + {"1.layer_norm.weight", IN_CROSS_LAYER_NORM_WEIGHT}, + {"1.EncDecAttention.q.weight", IN_CROSS_Q_WEIGHT}, + {"1.EncDecAttention.k.weight", IN_CROSS_K_WEIGHT}, + {"1.EncDecAttention.v.weight", IN_CROSS_V_WEIGHT}, + {"1.EncDecAttention.o.weight", IN_CROSS_ATTN_OUT_WEIGHT}, + {"2.layer_norm.weight", IN_FINAL_LAYER_NORM_WEIGHT}, + // MoE gate weight + {"layer.2.ffn.gate.weight", IN_BLOCK_SPARSE_MOE_GATE_WEIGHT}, + {"2.ffn.gate.weight", IN_BLOCK_SPARSE_MOE_GATE_WEIGHT}, + + // Shared Expert weight mappings (using w1/w2/w3 naming) + {"layer.2.ffn.shared_experts.w1.weight", + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"layer.2.ffn.shared_experts.w3.weight", + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"layer.2.ffn.shared_experts.w2.weight", + IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, + + // Shared Expert gate weight mappings + {"layer.2.ffn.shared_expert.gate.weight", IN_SHARED_EXPERT_GATE_WEIGHT}, + {"layer.2.ffn.shared_expert.gate.bias", IN_SHARED_EXPERT_GATE_BIAS}, + {"layer.2.ffn.shared_expert.gate.weight_scale", + IN_SHARED_EXPERT_GATE_SCALE}, + {"layer.2.ffn.shared_expert.gate.weight_offset", + IN_SHARED_EXPERT_GATE_OFFSET}, + + // Expert weight mappings (without expert index - processed by + // extract_expert_index) + // Gate projection weights (w1) - merged with up in gateup + {"w1.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, + // Up projection weights (w3) - merged with gate in gateup + {"w3.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, + // Down projection weights (w2) + {"w2.weight", IN_MLP_DOWN_WEIGHT_EXPERT}, + }; + + return mapping; +} + +static const std::unordered_map + ONEREC_DECODER_MOE_WEIGHT_MAPPING = get_onerec_decoder_moe_weight_mapping(); + +// ONEREC MoE weight mapping for encoder +// ONEREC_ENCODER_MOE_WEIGHT_MAPPING removed - use_moe only supports decoder +// mode + +static std::map ONEREC_WEIGHT_SHARD = { + {IN_Q_WEIGHT, 0}, + {IN_K_WEIGHT, 0}, + {IN_V_WEIGHT, 0}, + {IN_SELF_ATTN_OUT_WEIGHT, 1}, + {IN_CROSS_Q_WEIGHT, 0}, + {IN_CROSS_K_WEIGHT, 0}, + {IN_CROSS_V_WEIGHT, 0}, + {IN_CROSS_ATTN_OUT_WEIGHT, 1}, + {IN_FFN_WI_0_WEIGHT, 0}, + {IN_FFN_WI_1_WEIGHT, 0}, + {IN_FFN_WO_WEIGHT, 1}, + // MoE weights + {IN_BLOCK_SPARSE_MOE_GATE_WEIGHT, 0}, + {IN_MLP_GATEUP_WEIGHT_EXPERT, 0}, + {IN_MLP_DOWN_WEIGHT_EXPERT, 1}, + // Shared Expert weights + {IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, 0}, + {IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, 0}, + {IN_MLP_GATEUP_SCALE_SHARED_EXPERT, 0}, + {IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, 1}, + {IN_MLP_DOWN_OFFSET_SHARED_EXPERT, 1}, + {IN_MLP_DOWN_SCALE_SHARED_EXPERT, 1}, + {IN_SHARED_EXPERT_GATE_WEIGHT, 0}, + {IN_SHARED_EXPERT_GATE_BIAS, 0}, + {IN_SHARED_EXPERT_GATE_SCALE, 0}, + 
{IN_SHARED_EXPERT_GATE_OFFSET, 0}}; + +NpuOneRecBlockLayerImpl::NpuOneRecBlockLayerImpl(const ModelContext& context, + bool is_decoder, + int layer_id) + : NpuBaseLayer(context), is_decoder_(is_decoder), layer_id_(layer_id) { + // LOG(INFO) << "ONERECBlockLayerImpl constructor: " << layer_id_ << ":" + // << is_decoder_; + param_from_args( + prefill_param_, context.get_model_args(), parallel_args_, true); + prefill_param_.isDecoder = is_decoder; + // param_from_args(decode_param_, args, parallel_args, false); + // decode_param_.isDecoder = is_decoder; + + // Initialize decoder_prefill_only_decode_param_ if enable_onerec_prefill_only + // is true if (FLAGS_enable_rec_prefill_only && is_decoder) { + // param_from_args( + // decoder_prefill_only_decode_param_, args, parallel_args, true); + // decoder_prefill_only_decode_param_.isDecoder = is_decoder; + // decoder_prefill_only_decode_param_.emptyCrossAttn = false; + // } + // Choose correct weight count based on use_moe + int weight_count = prefill_param_.use_moe ? ONEREC_MOE_WEIGHT_COUNT_PER_LAYER + : ONEREC_WEIGHT_COUNT_PER_LAYER; + at_weight_tensors_.resize(weight_count); + atb_weight_tensors_.resize(weight_count); + // Initialize placeholder_vec_ with proper dimensions for ONEREC operations + // Some ATB operations may require specific tensor dimensions + placeholder_vec_ = {1, 1}; // 2D placeholder for better compatibility + dtype_ = c10::typeMetaToScalarType(context.get_tensor_options().dtype()); + device_id_ = context.get_tensor_options().device().index(); + + // Create placeholder tensors with proper dimensions for ONEREC operations + auto placeholder_tensor = torch::empty({1, 1}, torch::kInt32).to(device_); + placeholder = atb_speed::Utils::AtTensor2Tensor(placeholder_tensor); + at_placeholder = + torch::empty({1, context.get_model_args().hidden_size()}, dtype_) + .to(device_); + + for (int i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = + torch::zeros({1, context.get_model_args().hidden_size()}) + .to(context.get_tensor_options()); + } + + // Initialize MoE routing tensors if MoE is enabled + if (prefill_param_.use_moe) { + auto device = context.get_tensor_options().device(); + one_hot_ = torch::tensor({1}, torch::kInt32).to(device); + zero_hot_ = torch::tensor({0}, torch::kInt32).to(device); + expert_group_ = torch::tensor({1}, torch::dtype(torch::kInt32)).to(device); + } +} + +void NpuOneRecBlockLayerImpl::param_from_args( + atb_speed::onerec::BlockLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args, + bool isPrefill, + const ModelInputParams* input_params) { + LOG(INFO) << "begin param_from_args"; + param.isFA = false; // need page + param.isPrefill = isPrefill; + param.isBF16 = args.dtype() == "bfloat16"; + param.isPack = true; + param.supportSwiGLU = true; // ONEREC now uses gated activation by default + param.supportLcoc = isPrefill; + param.supportSpeculate = false; + param.enableSplitFuse = FLAGS_enable_chunked_prefill && isPrefill; + param.supportLora = false; + param.loraEnableGMM = false; + param.enableLogN = false; + param.kvQuant = false; + param.enableIntraLayerAddNorm = false; + param.enableInterLayerAddNorm = false; + // ONEREC position bias is now passed through attention_mask with ALIBI mask + // type hasPositionBias parameter is no longer needed + param.isDecoder = is_decoder_; + param.isOneRecEncoder = !is_decoder_; // ONEREC encoder uses bidirectional + // attention, no KV cache needed + param.enableOneRecPrefillOnly = FLAGS_enable_rec_prefill_only; + param.backend = "lccl"; + 
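+  // Tensor-parallel partitioning: attention heads are split evenly across
+  // ranks below (numAttentionHeadsPerRank = n_heads / world_size). For
+  // example (illustrative numbers only), 8 decoder heads on a 2-rank world
+  // give 4 heads per rank, each of size head_dim.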
param.rank = parallel_args.rank(); + param.worldSize = parallel_args.world_size(); + param.quantType = 0; + param.quantGroupSize = 64; + auto args_n_heads = is_decoder_ ? args.decoder_n_heads() : args.n_heads(); + auto args_head_dim = is_decoder_ ? args.decoder_head_dim() : args.head_dim(); + param.numAttentionHeadsPerRank = args_n_heads / param.worldSize; + param.hiddenSizePerAttentionHead = args_head_dim; + LOG(INFO) << "hiddenSizePerAttentionHead: " + << param.hiddenSizePerAttentionHead; + LOG(INFO) << "numAttentionHeadsPerRank: " << param.numAttentionHeadsPerRank; + std::optional optionalValue = + is_decoder_ ? args.decoder_n_kv_heads().value_or(args.decoder_n_heads()) + : args.n_kv_heads().value_or(args.n_heads()); + param.numKeyValueHeadsPerRank = + static_cast(optionalValue.value()) / param.worldSize; + param.rmsNormEps = args.rms_norm_eps(); + param.seqLen = {}; + param.tokenOffset = {}; + param.packQuantType = {1, 1}; + param.linearQuantType = {0, -1, -1, 0, 0, -1, 0}; + param.layerId = layer_id_; + + // Set ModelInputParams for ONEREC model support + // param.inputParams = input_params; + // param.linearTransposeType = {1, -1, -1, 1, 1, -1, 1}; + param.linearTransposeType = {1, 1, 1, 1, 1, 1, 1}; + // Initialize linearDescs to enable QKV projection + // Elements: qkv, dense, gateup, down linear descriptions + if (param.isBF16) { + param.linearDescs = { + static_cast(atb_speed::common::LinearDesc::BFLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::BFLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::BFLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::BFLOAT16_DESC)}; + } else { + param.linearDescs = { + static_cast(atb_speed::common::LinearDesc::FLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::FLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::FLOAT16_DESC), + static_cast(atb_speed::common::LinearDesc::FLOAT16_DESC)}; + } + + // Set use_moe parameter from ModelArgs + param.use_moe = args.use_moe() && is_decoder_; + if (param.use_moe) { + // Initialize MoE parallel configuration (similar to Qwen3 and DeepSeek V2) + ep_size_ = 1; + auto ep_rank = 0; + ep_local_tp_size_ = parallel_args.world_size() / ep_size_; + CHECK_EQ(parallel_args.world_size(), ep_size_ * ep_local_tp_size_); + + num_experts_per_partition_ = args.n_routed_experts() / ep_size_; + start_expert_id_ = ep_rank * num_experts_per_partition_; + end_expert_id_ = start_expert_id_ + num_experts_per_partition_ - 1; + + // Initialize experts weights storage with 2D structure + resize_experts_weights(num_experts_per_partition_); + + // Configure OneRecMoEConfig + param.moe_config = std::make_unique(); + param.moe_config->moe_topk = args.num_experts_per_tok(); + param.moe_config->moe_num_experts = args.n_routed_experts(); + param.moe_config->moe_score_func = "softmax"; + param.moe_config->moe_route_scale = args.moe_route_scale(); + param.moe_config->moe_inter_dim = args.moe_intermediate_size(); + param.moe_config->use_bf16 = param.isBF16; + param.moe_config->hasSharedExpertGate = false; + param.moe_config->moe_use_shared_experts = args.moe_use_shared_experts(); + param.moe_config->moe_num_shared_experts = args.n_shared_experts(); + + // Initialize moeLinearQuantType for MoE layers + // Four components: ROUTER_IDX, MOE_MLP_GATE_IDX, MOE_MLP_UP_IDX, + // MOE_MLP_DOWN_IDX + param.moeLinearQuantType = { + atb_speed::common::LinearType::FP, // ROUTER_IDX (0) + atb_speed::common::LinearType::FP, // MOE_MLP_GATE_IDX (1) + atb_speed::common::LinearType::INVALID, // MOE_MLP_UP_IDX (2) + 
atb_speed::common::LinearType::FP // MOE_MLP_DOWN_IDX (3) + }; + } +} + +void NpuOneRecBlockLayerImpl::verify_loaded_weights( + const std::string& prefix) const { + // Choose appropriate weight mapping based on use_moe flag + const auto& weight_mapping = + [this]() -> const std::unordered_map& { + if (prefill_param_.use_moe) { + // If MoE is enabled, apply filtering based on configuration + if (prefill_param_.moe_config) { + static std::unordered_map filtered_mapping; + filtered_mapping.clear(); + + // Copy weights from the full MoE mapping based on configuration + for (const auto& [name, index] : ONEREC_DECODER_MOE_WEIGHT_MAPPING) { + bool should_include = true; + + // Filter shared expert weights based on moe_use_shared_experts + if (!prefill_param_.moe_config->moe_use_shared_experts) { + if (name.find("shared_expert") != std::string::npos) { + should_include = false; + } + } + + // Further filter shared expert gate weights based on + // hasSharedExpertGate + if (should_include && + !prefill_param_.moe_config->hasSharedExpertGate) { + if (name.find("shared_expert_gate") != std::string::npos) { + should_include = false; + } + } + + if (should_include) { + filtered_mapping[name] = index; + } + } + return filtered_mapping; + } else { + return ONEREC_DECODER_MOE_WEIGHT_MAPPING; + } + } else { + return is_decoder_ ? ONEREC_DECODER_WEIGHT_MAPPING + : ONEREC_ENCODER_WEIGHT_MAPPING; + } + }(); + + const uint64_t expected_weight_count = [this]() -> uint64_t { + if (!prefill_param_.use_moe) { + return ONEREC_WEIGHT_COUNT_PER_LAYER; + } else if (prefill_param_.moe_config && + !prefill_param_.moe_config->moe_use_shared_experts) { + // When MoE is enabled but shared experts are disabled, subtract shared + // expert weights Shared expert weights count: 18 weights (gateup: 6, + // down: 6, gate: 6) + const uint64_t shared_expert_weight_count = 18; + return ONEREC_MOE_WEIGHT_COUNT_PER_LAYER - shared_expert_weight_count; + } else { + return ONEREC_MOE_WEIGHT_COUNT_PER_LAYER; + } + }(); + + // Define weights that are expected to be [1] after merging + std::set merged_weights = {IN_K_WEIGHT, IN_V_WEIGHT, IN_FFN_WI_1_WEIGHT}; + if (is_decoder_) { + merged_weights.insert({IN_CROSS_K_WEIGHT, IN_CROSS_V_WEIGHT}); + } + + for (const auto& [name, index] : weight_mapping) { + auto sizes = at_weight_tensors_[index].sizes(); + bool is_placeholder = (sizes.size() == 2 && sizes[0] == 1); + bool is_expected_placeholder = merged_weights.count(index) > 0; + + // Special handling for relative_attention_bias - it's optional and only + // exists in first layer + bool is_relative_bias = (index == IN_RELATIVE_ATTENTION_BIAS_WEIGHT); + + if (is_placeholder && !is_expected_placeholder && !is_relative_bias) { + CHECK(false) << "weight is not loaded for " << prefix << name; + } + + // if (is_relative_bias && is_placeholder) { + // LOG(INFO) << "[ONEREC DEBUG] Weight " << prefix << name + // << " is placeholder (expected for non-first layers)"; + // } + } +} + +void NpuOneRecBlockLayerImpl::merge_loaded_weights() { + // Debug: Print shapes before merging + /* + LOG(INFO) << "[ONEREC DEBUG] Before merging QKV weights:"; + LOG(INFO) << "[ONEREC DEBUG] Q weight shape: [" + << at_weight_tensors_[IN_Q_WEIGHT].sizes() << "]"; + LOG(INFO) << "[ONEREC DEBUG] K weight shape: [" + << at_weight_tensors_[IN_K_WEIGHT].sizes() << "]"; + LOG(INFO) << "[ONEREC DEBUG] V weight shape: [" + << at_weight_tensors_[IN_V_WEIGHT].sizes() << "]"; + */ + // Check if weights were properly loaded (not placeholders) + bool q_loaded = 
!(at_weight_tensors_[IN_Q_WEIGHT].sizes().size() == 2 && + at_weight_tensors_[IN_Q_WEIGHT].sizes()[0] == 1); + bool k_loaded = !(at_weight_tensors_[IN_K_WEIGHT].sizes().size() == 2 && + at_weight_tensors_[IN_K_WEIGHT].sizes()[0] == 1); + bool v_loaded = !(at_weight_tensors_[IN_V_WEIGHT].sizes().size() == 2 && + at_weight_tensors_[IN_V_WEIGHT].sizes()[0] == 1); + /* + LOG(INFO) << "[ONEREC DEBUG] Weight loading status: Q=" + << (q_loaded ? "loaded" : "placeholder") + << ", K=" << (k_loaded ? "loaded" : "placeholder") + << ", V=" << (v_loaded ? "loaded" : "placeholder"); + */ + if (!q_loaded || !k_loaded || !v_loaded) { + LOG(ERROR) + << "[ONEREC ERROR] QKV weights not properly loaded. This will cause " + "SplitOperation to fail."; + LOG(ERROR) + << "[ONEREC ERROR] Expected weight shapes should be [hidden_size, " + "hidden_size] but got placeholders [1, hidden_size]"; + LOG(ERROR) + << "[ONEREC ERROR] Please check if the weight names in StateDict " + "match the expected mappings."; + + // For debugging purposes, let's create dummy weights with correct + // dimensions This is a temporary workaround to prevent the SplitOperation + // error + int hidden_size = at_weight_tensors_[IN_Q_WEIGHT].sizes()[1]; + int head_dim = hidden_size / 4; // Assuming 4 heads as per default config + int expected_dim = 4 * head_dim; // num_heads * head_dim + + LOG(WARNING) << "[ONEREC WARNING] Creating dummy weights with correct " + "dimensions as workaround"; + LOG(WARNING) << "[ONEREC WARNING] hidden_size=" << hidden_size + << ", expected_dim=" << expected_dim; + + if (!q_loaded) { + at_weight_tensors_[IN_Q_WEIGHT] = + torch::randn({expected_dim, hidden_size}).to(device_).to(dtype_) * + 0.02; + } + if (!k_loaded) { + at_weight_tensors_[IN_K_WEIGHT] = + torch::randn({expected_dim, hidden_size}).to(device_).to(dtype_) * + 0.02; + } + if (!v_loaded) { + at_weight_tensors_[IN_V_WEIGHT] = + torch::randn({expected_dim, hidden_size}).to(device_).to(dtype_) * + 0.02; + } + } + + // Merge Q, K, V weights for self-attention + auto new_q_weight = torch::cat({at_weight_tensors_[IN_Q_WEIGHT], + at_weight_tensors_[IN_K_WEIGHT], + at_weight_tensors_[IN_V_WEIGHT]}, + 0); + /* + LOG(INFO) << "[ONEREC DEBUG] After merging QKV weights:"; + LOG(INFO) << "[ONEREC DEBUG] Merged Q weight shape: [" << + new_q_weight.sizes() + << "]"; + */ + + at_weight_tensors_[IN_Q_WEIGHT] = new_q_weight; + at_weight_tensors_[IN_K_WEIGHT] = + torch::zeros({1, at_weight_tensors_[IN_Q_WEIGHT].size(1)}) + .to(device_) + .to(dtype_); + at_weight_tensors_[IN_V_WEIGHT] = + torch::zeros({1, at_weight_tensors_[IN_Q_WEIGHT].size(1)}) + .to(device_) + .to(dtype_); + + // For decoder, also merge cross-attention Q, K, V weights + /*if (is_decoder_) { + auto new_cross_q_weight = + torch::cat({at_weight_tensors_[IN_CROSS_Q_WEIGHT], + at_weight_tensors_[IN_CROSS_K_WEIGHT], + at_weight_tensors_[IN_CROSS_V_WEIGHT]}, + 0); + at_weight_tensors_[IN_CROSS_Q_WEIGHT] = new_cross_q_weight; + at_weight_tensors_[IN_CROSS_K_WEIGHT] = + torch::zeros({1, at_weight_tensors_[IN_CROSS_Q_WEIGHT].size(1)}) + .to(device_) + .to(dtype_); + at_weight_tensors_[IN_CROSS_V_WEIGHT] = + torch::zeros({1, at_weight_tensors_[IN_CROSS_Q_WEIGHT].size(1)}) + .to(device_) + .to(dtype_); + }*/ + + // For MoE mode, skip traditional MLP weight merging + if (!prefill_param_.use_moe) { + // Merge wi_0 and wi_1 weights for gated activation (gate_up weight pack) + auto new_gate_up_weight = + torch::cat({at_weight_tensors_[IN_FFN_WI_0_WEIGHT], + at_weight_tensors_[IN_FFN_WI_1_WEIGHT]}, + 0); + 
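+    // new_gate_up_weight packs the wi_0 (gate) rows on top of the wi_1 (up)
+    // rows along dim 0; the separate wi_1 slot is then reduced to a
+    // [1, hidden] placeholder so only the packed gate_up tensor remains.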
at_weight_tensors_[IN_FFN_WI_0_WEIGHT] = new_gate_up_weight; + at_weight_tensors_[IN_FFN_WI_1_WEIGHT] = + torch::zeros({1, at_weight_tensors_[IN_FFN_WI_0_WEIGHT].size(1)}) + .to(device_) + .to(dtype_); + } else { + // MoE mode: Merge expert weights similar to Qwen3 and DeepseekV2 + LOG(INFO) << "[ONEREC DEBUG] MoE mode: merging expert weights"; + + // Call merge_experts_weights to process the loaded expert weights + merge_experts_weights(); + + // Merge shared expert weights if they exist + merge_shared_experts_weights(); + + if (at_weight_tensors_[IN_MOE_EXPERT_W1_WEIGHT].numel() > 1) { + LOG(INFO) << "[ONEREC DEBUG] Expert W1 weights shape: [" + << at_weight_tensors_[IN_MOE_EXPERT_W1_WEIGHT].sizes() << "]"; + } + if (at_weight_tensors_[IN_MOE_EXPERT_W2_WEIGHT].numel() > 1) { + LOG(INFO) << "[ONEREC DEBUG] Expert W2 weights shape: [" + << at_weight_tensors_[IN_MOE_EXPERT_W2_WEIGHT].sizes() << "]"; + } + if (at_weight_tensors_[IN_MOE_EXPERT_W3_WEIGHT].numel() > 1) { + LOG(INFO) << "[ONEREC DEBUG] Expert W3 weights shape: [" + << at_weight_tensors_[IN_MOE_EXPERT_W3_WEIGHT].sizes() << "]"; + } + + // Log shared expert weights if they exist + if (at_weight_tensors_[IN_MOE_SHARED_W1_WEIGHT].numel() > 1) { + LOG(INFO) << "[ONEREC DEBUG] Shared Expert W1 weights shape: [" + << at_weight_tensors_[IN_MOE_SHARED_W1_WEIGHT].sizes() << "]"; + } + if (at_weight_tensors_[IN_MOE_SHARED_W2_WEIGHT].numel() > 1) { + LOG(INFO) << "[ONEREC DEBUG] Shared Expert W2 weights shape: [" + << at_weight_tensors_[IN_MOE_SHARED_W2_WEIGHT].sizes() << "]"; + } + } + + // Ensure all placeholder tensors have valid deviceData for ATB compatibility + int fixed_placeholders = 0; + const uint64_t weight_count = prefill_param_.use_moe + ? ONEREC_MOE_WEIGHT_COUNT_PER_LAYER + : ONEREC_WEIGHT_COUNT_PER_LAYER; + for (int i = 0; i < weight_count; ++i) { + // First check if tensor is defined (not null) + if (!at_weight_tensors_[i].defined()) { + // Create a minimal placeholder tensor for undefined tensors + at_weight_tensors_[i] = torch::zeros( + {1, 1}, torch::TensorOptions().device(device_).dtype(dtype_)); + fixed_placeholders++; + continue; + } + + auto sizes = at_weight_tensors_[i].sizes(); + if (sizes.size() == 2 && sizes[0] == 1) { + // Check if tensor has valid device data + if (!at_weight_tensors_[i].is_contiguous() || + at_weight_tensors_[i].data_ptr() == nullptr) { + // Force allocation of device memory for placeholder tensors + at_weight_tensors_[i] = + torch::ones({1, sizes[1]}, + torch::TensorOptions().device(device_).dtype(dtype_)); + fixed_placeholders++; + } + } + } + // if (fixed_placeholders > 0) { + // LOG(INFO) << "[ONEREC DEBUG] Fixed " << fixed_placeholders + // << " placeholder tensors with invalid deviceData"; + // } + + c10_npu::NPUCachingAllocator::emptyCache(); + // Choose correct weight count based on use_moe + for (int i = 0; i < weight_count; ++i) { + atb_weight_tensors_[i] = + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + } + // LOG(INFO) << "ONERECBlockLayerImpl begin init_layer: " << layer_id_ << ":" + // << is_decoder_; + init_layer(); +} + +void NpuOneRecBlockLayerImpl::load_state_dict(const StateDict& state_dict) { + // Choose appropriate weight mapping based on use_moe + const auto& weight_mapping = + [this]() -> const std::unordered_map& { + if (prefill_param_.use_moe) { + return ONEREC_DECODER_MOE_WEIGHT_MAPPING; + } else { + return is_decoder_ ? 
ONEREC_DECODER_WEIGHT_MAPPING + : ONEREC_ENCODER_WEIGHT_MAPPING; + } + }(); + + // Debug: Print all available weights in StateDict + LOG(INFO) << "[ONEREC DEBUG] Available weights in StateDict for " + << (is_decoder_ ? "decoder" : "encoder") << ":"; + for (const auto& [key, tensor] : state_dict) { + LOG(INFO) << "[ONEREC DEBUG] " << key << " -> shape: [" << tensor.sizes() + << "]"; + } + + // Debug: Print expected weight mappings + LOG(INFO) << "[ONEREC DEBUG] Expected weight mappings:"; + for (const auto& [name, index] : weight_mapping) { + LOG(INFO) << "[ONEREC DEBUG] " << name << " -> index: " << index; + } + + // Debug: Check actual weight name matching + LOG(INFO) << "[ONEREC DEBUG] Checking weight name matching:"; + for (const auto& [state_key, tensor] : state_dict) { + for (const auto& [mapping_name, index] : weight_mapping) { + if (absl::EndsWith(state_key, mapping_name)) { + LOG(INFO) << "[ONEREC DEBUG] MATCH: " << state_key << " matches " + << mapping_name; + } + } + } + + // Handle MoE expert weights separately if using MoE + if (prefill_param_.use_moe) { + // Process each expert weight in the state dict + for (const auto& [state_key, tensor] : state_dict) { + if (state_key.find(".ffn.experts.") != std::string::npos) { + process_expert_weights(state_dict, state_key, tensor); + } + } + + // Handle shared expert weights if present + for (const auto& [state_key, tensor] : state_dict) { + // Check for shared expert patterns + bool is_shared_expert = + (state_key.find(".ffn.shared_experts.") != std::string::npos || + state_key.find(".ffn.shared_expert.") != std::string::npos); + + if (is_shared_expert) { + // Process shared expert weights using the dedicated function + process_shared_expert_weights(state_dict, state_key, tensor); + + // Handle down_proj weights for ATB compatibility + if (state_key.find(".down_proj.weight") != std::string::npos || + state_key.find(".w2.weight") != std::string::npos) { + at_weight_tensors_[IN_MLP_DOWN_WEIGHT_SHARED_EXPERT] = tensor; + LOG(INFO) + << "[ONEREC DEBUG] Also stored shared expert down weight in " + "at_weight_tensors_, shape: [" + << tensor.sizes() << "]"; + } + } + + // // Handle shared expert gate weights + // if (state_key.find(".shared_expert_gate.weight") != std::string::npos) + // { + // at_weight_tensors_[IN_SHARED_EXPERT_GATE_WEIGHT] = tensor; + // LOG(INFO) << "[ONEREC DEBUG] Loaded shared expert gate weight, shape: + // [" + // << tensor.sizes() << "]"; + // } + } + } + + for (const auto& [name, index] : weight_mapping) { + LOG(INFO) << "[ONEREC DEBUG] Loading weight: " << name << " (index " + << index << ")" << ", " << at_weight_tensors_.size(); + auto initial_shape = at_weight_tensors_[index].sizes(); + + // Special handling for relative_attention_bias - it's optional and only + // exists in first layer + bool is_relative_bias = (index == IN_RELATIVE_ATTENTION_BIAS_WEIGHT); + bool weight_exists = false; + + // Check if the weight actually exists in state_dict + for (const auto& [state_key, tensor] : state_dict) { + if (absl::EndsWith(state_key, name)) { + weight_exists = true; + break; + } + } + + if (is_relative_bias && !weight_exists) { + LOG(INFO) << "[ONEREC DEBUG] Weight " << name << " (index " << index + << ") SKIPPED: not present in this layer (expected for " + "non-first layers)"; + continue; + } + + if (ONEREC_WEIGHT_SHARD.find(index) != ONEREC_WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, ONEREC_WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + auto final_shape = 
at_weight_tensors_[index].sizes(); + + // Debug: Check if weight was actually loaded + bool was_loaded = !(final_shape.size() == 2 && final_shape[0] == 1 && + initial_shape == final_shape); + LOG(INFO) << "[ONEREC DEBUG] Weight " << name << " (index " << index + << ") loaded: " << (was_loaded ? "YES" : "NO") << ", shape: [" + << final_shape << "]" << ", weight exists: " << weight_exists; + } + LOG(INFO) << "ONERECBlockLayerImpl end load state dict"; +} + +int64_t NpuOneRecBlockLayerImpl::init_layer() { + // init_attn_mask(); + name_ = + is_decoder_ ? "onerec_decoder_block_layer" : "onerec_encoder_block_layer"; + modelName_ = "onerec"; + LOG(INFO) << "begin init prefill param: " << prefill_param_.isPrefill + << " is_decoder_: " << is_decoder_ << " layer_id_: " << layer_id_; + CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_)); + // LOG(INFO) << "after init prefill param: " << decode_param_.isPrefill + // << " is_decoder_: " << is_decoder_; + // For ONEREC decoder, only use prefill_node_ for both prefill and decode + // stages if (is_decoder_) { + // LOG(INFO) << "begin init decode param: " << decode_param_.isPrefill; + + // // Initialize decoder_prefill_only_decode_node if + // enable_onerec_prefill_only is + // // true + // if (FLAGS_enable_rec_prefill_only) { + // LOG(INFO) << "begin init decoder_prefill_only_decode_param"; + // CHECK_OPERATION_STATUS_RETURN( + // init_node(decoder_prefill_only_decode_node_, + // decoder_prefill_only_decode_param_)); + // } else { + // CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_)); + // } + // } + return atb::NO_ERROR; +} + +int64_t NpuOneRecBlockLayerImpl::init_attn_mask() { + // attn_mask is now preprocessed in ONERECStack, no local initialization + // needed + return atb::NO_ERROR; +} + +int64_t NpuOneRecBlockLayerImpl::init_node( + atb_speed::Model::Node& node, + atb_speed::onerec::BlockLayerParam& param) { + atb::Operation* operation = nullptr; + atb::Status status = atb_speed::onerec::BlockLayer(param, &operation); + if (status != atb::NO_ERROR) { + LOG(ERROR) << "Failed to create ONEREC BlockLayer operation, status: " + << status; + return status; + } + + node.operation.reset(operation); + if (node.operation == nullptr) { + LOG(ERROR) << "node.operation is null after creation"; + return -1; + } + + uint32_t inputNum = node.operation->GetInputNum(); + uint32_t outputNum = node.operation->GetOutputNum(); + + // Debug logging for ONEREC tensor count issue + LOG(INFO) << "[ONEREC DEBUG] " << modelName_ + << " - ATB operation inputNum: " << inputNum + << ", outputNum: " << outputNum << ", is_decoder: " << is_decoder_; + + if (inputNum < 1) { + LOG(ERROR) << "Invalid input number: " << inputNum; + return -1; + } + + // For ONEREC encoder, we need at least 84 input tensors (79 weights + 5 + // non-weights from GetONERECEncoderTensorNames()) For ONEREC decoder, we need + // even more tensors + uint32_t required_tensors; + if (is_decoder_) { + // For decoder: check if using MoE variant + if (param.use_moe) { + required_tensors = + ONEREC_MOE_WEIGHT_COUNT_PER_LAYER + 18; // 97 + 4 + 14 = 115 + } else { + required_tensors = ONEREC_WEIGHT_COUNT_PER_LAYER + 14; // 79 + 14 = 93 + } + } else { + required_tensors = 84; // Encoder: 79 weights + 5 non-weights + } + if (inputNum < required_tensors) { + LOG(WARNING) << "[ONEREC DEBUG] " << modelName_ + << " - ATB operation provides only " << inputNum + << " input tensors, but we need at least " << required_tensors + << " tensors. 
This may cause out_of_range errors."; + } + + node.inTensors.resize(inputNum); + node.outTensors.resize(outputNum); + + // Set weight tensors + const uint64_t weight_count = prefill_param_.use_moe + ? ONEREC_MOE_WEIGHT_COUNT_PER_LAYER + : ONEREC_WEIGHT_COUNT_PER_LAYER; + for (size_t weightTensorId = 0; weightTensorId < weight_count; + ++weightTensorId) { + if (weightTensorId < inputNum) { + node.inTensors.at(weightTensorId) = &atb_weight_tensors_[weightTensorId]; + } + } + + node.variantPack.inTensors.reserve(inputNum); + node.variantPack.inTensors.resize(inputNum); + node.variantPack.outTensors.reserve(outputNum); + node.variantPack.outTensors.resize(outputNum); + + return atb::NO_ERROR; +} + +torch::Tensor NpuOneRecBlockLayerImpl::forward( + torch::Tensor& x, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + atb::Context* context, + AtbWorkspace& workspace, + std::vector event, + std::vector*> event_flag, + torch::Tensor* encoder_output, + int node_id, + const torch::Tensor& expert_array) { + atb::Status st; + + // Update BlockLayerParam with current ModelInputParams + + //@TODO: delete + // prefill_param_.inputParams = &input_params; + // decode_param_.inputParams = &input_params; + + if (input_params.rec_params && input_params.rec_params->rec_stage == + RecModelInputParams::RecStage::PREFILL) { + // Prefill stage + if (is_decoder_) { + if (FLAGS_enable_rec_prefill_only) { + if (prefill_param_.use_moe) { + build_decoder_moe_node_variant_pack(prefill_node_, + x, + attn_mask, + kv_cache, + input_params, + true, + encoder_output, + node_id, + expert_array); + } else { + build_decoder_node_variant_pack(prefill_node_, + x, + attn_mask, + kv_cache, + input_params, + true, + encoder_output, + node_id); + } + st = execute_node(prefill_node_, node_id, event, event_flag); + LOG_IF(FATAL, st != 0) + << modelName_ << " execute prefill layer fail, error code: " << st; + } + } else { + // Encoder prefill + build_encoder_node_variant_pack( + prefill_node_, x, attn_mask, input_params, true, node_id); + st = execute_node(prefill_node_, node_id, event, event_flag); + LOG_IF(FATAL, st != 0) + << modelName_ << " execute prefill layer fail, error code: " << st; + } + } else { + // Decode stage + if (is_decoder_) { + if (decode_param_.use_moe) { + build_decoder_moe_node_variant_pack(decode_node_, + x, + attn_mask, + kv_cache, + input_params, + false, + encoder_output, + node_id, + expert_array); + } else { + build_decoder_node_variant_pack(decode_node_, + x, + attn_mask, + kv_cache, + input_params, + false, + encoder_output, + node_id); + } + st = execute_node(decode_node_, node_id + 1000, event, event_flag); + LOG_IF(FATAL, st != 0) + << modelName_ << " execute decode layer fail, error code: " << st; + } else { + LOG(FATAL) << modelName_ << " encoder decode stage is not supported."; + } + } + + return at_placeholder; +} + +void NpuOneRecBlockLayerImpl::build_encoder_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + ModelInputParams& input_params, + bool is_prefill, + int layer_id) { + // ONEREC Encoder uses simplified tensor list, corresponding to + // onerec_encoder_input configuration onerec_encoder_input: {"in_input", + // "in_attention_mask", "in_seq_len", "in_token_offset", "in_layer_id"} Total: + // 79 weights + 5 non-weights = 84 tensors + + internalTensors = atb_speed::Utils::AtTensor2Tensor(x); + + // Debug logging for tensor array sizes + /* + LOG(INFO) << "[ONEREC DEBUG] build_encoder_node_variant_pack - " + 
"variantPack.inTensors.size(): " + << node.variantPack.inTensors.size() + << ", ONEREC_WEIGHT_COUNT_PER_LAYER: " << + ONEREC_WEIGHT_COUNT_PER_LAYER; + */ + + // Weight tensors (indices 0-78) + for (size_t i = 0; i < ONEREC_WEIGHT_COUNT_PER_LAYER; ++i) { + CHECK_THROW(node.inTensors.at(i) == nullptr, + modelName_ << "inTensor " << i << "is NULL"); + node.variantPack.inTensors.at(i) = *node.inTensors.at(i); + } + + // Non-weight tensor index definitions - based on onerec_encoder_input + // configuration + const int INPUT_TENSOR_IDX = ONEREC_WEIGHT_COUNT_PER_LAYER; // 79: "in_input" + const int ATTENTION_MASK_IDX = + INPUT_TENSOR_IDX + 1; // 80: "in_attention_mask" + const int TOKEN_OFFSET_IDX = ATTENTION_MASK_IDX + 1; // 81: "in_token_offset" + const int LAYER_ID_IDX = TOKEN_OFFSET_IDX + 1; // 82: "in_layer_id" + const int SEQ_LEN_IDX = LAYER_ID_IDX + 1; // 83: "in_seq_len" + + // "in_input" - Critical tensor + node.variantPack.inTensors.at(INPUT_TENSOR_IDX) = internalTensors; + + // "in_attention_mask" - Critical tensor (now pre-processed in ONERECStack) + // attn_mask is already contiguous and on correct device from ONERECStack + // preprocessing + node.variantPack.inTensors.at(ATTENTION_MASK_IDX) = + atb_speed::Utils::AtTensor2Tensor(attn_mask); + + // "in_token_offset" - Set to placeholder + node.variantPack.inTensors.at(TOKEN_OFFSET_IDX) = placeholder; + node.variantPack.inTensors.at(TOKEN_OFFSET_IDX).hostData = + placeholder_vec_.data(); + + // "in_layer_id" - Set to placeholder + node.variantPack.inTensors.at(LAYER_ID_IDX) = placeholder; + node.variantPack.inTensors.at(LAYER_ID_IDX).hostData = + placeholder_vec_.data(); + + // "in_seq_len" - Important tensor (now pre-processed in ONERECStack) + if (input_params.rec_params && + input_params.rec_params->encoder_seq_lens_tensor.defined()) { + // encoder_seq_lens_tensor is already contiguous and on correct device from + // ONERECStack preprocessing + node.variantPack.inTensors.at(SEQ_LEN_IDX) = + atb_speed::Utils::AtTensor2Tensor( + input_params.rec_params->encoder_seq_lens_tensor); + node.variantPack.inTensors.at(SEQ_LEN_IDX).hostData = + input_params.rec_params->encoder_seq_lens.data(); + } else if (input_params.rec_params) { + // Use placeholder to avoid tensor creation and sync + node.variantPack.inTensors.at(SEQ_LEN_IDX) = placeholder; + node.variantPack.inTensors.at(SEQ_LEN_IDX).hostData = + input_params.rec_params->encoder_seq_lens.data(); + } + + // Set output tensor + node.variantPack.outTensors.at(0) = internalTensors; +} + +void NpuOneRecBlockLayerImpl::build_decoder_moe_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill, + torch::Tensor* encoder_output, + int layer_id, + const torch::Tensor& expert_array) { + // ONEREC Decoder MoE tensor mapping - must match the complete tensor list + // from ConstructTensorMap for MoE configuration + + // Copy all weight tensors from node.inTensors to variantPack.inTensors + // For MoE, use ONEREC_MOE_WEIGHT_COUNT_PER_LAYER (74) instead of + // ONEREC_WEIGHT_COUNT_PER_LAYER (79) + for (size_t i = 0; i < ONEREC_MOE_WEIGHT_COUNT_PER_LAYER; ++i) { + CHECK_THROW(node.inTensors.at(i) == nullptr, + modelName_ << "inTensor " << i << " is NULL"); + node.variantPack.inTensors.at(i) = *node.inTensors.at(i); + } + + // Configure non-weight tensors starting from index + // ONEREC_MOE_WEIGHT_COUNT_PER_LAYER (74) Must match the exact order from + // GetONERECLayerInTensorCandidates in 
block_layer.cpp + // Start after weights and MoE routing tensors (expert_array, expert_group, + // one_hot, zero_hot) + setup_common_decoder_tensors(node, + x, + attn_mask, + input_params, + encoder_output, + ONEREC_MOE_WEIGHT_COUNT_PER_LAYER + 4); + + // Add MoE-specific tensors (expert_array, expert_group, one_hot, zero_hot) + // These tensors are required by ONEREC MoE kernel implementation + // They should directly follow the weight tensors as defined in kernel + int moe_tensor_start = + ONEREC_MOE_WEIGHT_COUNT_PER_LAYER; // Directly after weights + + // Set expert_array tensor + if (expert_array.defined()) { + node.variantPack.inTensors.at(moe_tensor_start) = + atb_speed::Utils::AtTensor2Tensor(expert_array); + } + + // Set expert_group_, one_hot_, zero_hot_ tensors + if (expert_group_.defined()) { + node.variantPack.inTensors.at(moe_tensor_start + 1) = + atb_speed::Utils::AtTensor2Tensor(expert_group_); + } + if (one_hot_.defined()) { + node.variantPack.inTensors.at(moe_tensor_start + 2) = + atb_speed::Utils::AtTensor2Tensor(one_hot_); + } + if (zero_hot_.defined()) { + node.variantPack.inTensors.at(moe_tensor_start + 3) = + atb_speed::Utils::AtTensor2Tensor(zero_hot_); + } +} + +// Private helper function to set common decoder tensors +int NpuOneRecBlockLayerImpl::setup_common_decoder_tensors( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + ModelInputParams& input_params, + torch::Tensor* encoder_output, + int start_tensor_idx) { + // Create internal tensor from input + internalTensors = atb_speed::Utils::AtTensor2Tensor(x); + + int idx = start_tensor_idx; + + // Input tensor + node.variantPack.inTensors.at(idx++) = internalTensors; + + // Attention mask + node.variantPack.inTensors.at(idx++) = + atb_speed::Utils::AtTensor2Tensor(attn_mask); + + // KV cache placeholders + node.variantPack.inTensors.at(idx++) = placeholder; + node.variantPack.inTensors.at(idx++) = placeholder; + + // Sequence length + if (input_params.kv_seq_lens.defined()) { + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens); + node.variantPack.inTensors.at(idx).hostData = + input_params.kv_seq_lens_vec.data(); + } else { + int32_t seq_len = std::max(static_cast(x.size(0)), 1); + seq_lens_vec_ = {seq_len}; + auto seq_lens_tensor = torch::tensor( + seq_lens_vec_, + torch::TensorOptions().dtype(torch::kInt32).device(device_)); + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(seq_lens_tensor); + node.variantPack.inTensors.at(idx).hostData = seq_lens_vec_.data(); + } + idx++; + + // Token offset and layer id placeholders + node.variantPack.inTensors.at(idx) = placeholder; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + node.variantPack.inTensors.at(idx) = placeholder; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + + // Block tables + if (!FLAGS_enable_rec_prefill_only && input_params.block_tables.defined()) { + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(input_params.block_tables); + } else { + node.variantPack.inTensors.at(idx) = placeholder; + node.variantPack.inTensors.at(idx).hostData = placeholder_vec_.data(); + } + idx++; + + // Cache slots + if (!FLAGS_enable_rec_prefill_only && + input_params.new_cache_slots.defined()) { + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); + } else { + node.variantPack.inTensors.at(idx) = placeholder; + 
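+    // When cache slots are unused, the host data also falls back to the
+    // shared placeholder buffer, matching the other placeholder inputs.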
node.variantPack.inTensors.at(idx).hostData = placeholder_vec_.data(); + } + idx++; + + // Encoder output + if (encoder_output != nullptr) { + encoder_output_contiguous_ = encoder_output->is_contiguous() + ? *encoder_output + : encoder_output->contiguous(); + node.variantPack.inTensors.at(idx) = + atb_speed::Utils::AtTensor2Tensor(encoder_output_contiguous_); + } else { + node.variantPack.inTensors.at(idx) = placeholder; + } + idx++; + + // Cross attention placeholders + for (int i = 0; i < 3; i++) { + node.variantPack.inTensors.at(idx) = placeholder; + node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data(); + } + + // Encoder sequence length + if (input_params.rec_params) { + node.variantPack.inTensors.at(idx) = atb_speed::Utils::AtTensor2Tensor( + input_params.rec_params->encoder_seq_lens_tensor); + node.variantPack.inTensors.at(idx++).hostData = + input_params.rec_params->encoder_seq_lens.data(); + } + + // Setup output tensor + node.variantPack.outTensors.at(0) = internalTensors; + return idx; +} + +void NpuOneRecBlockLayerImpl::build_decoder_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill, + torch::Tensor* encoder_output, + int layer_id) { + // ONEREC Decoder tensor mapping - must match the complete tensor list from + // ConstructTensorMap The operation expects all tensors configured by + // ConstructTensorMap, including optional features + + // Copy all weight tensors from node.inTensors to variantPack.inTensors + // The first 79 positions are for weight tensors - use same approach as + // encoder + for (size_t i = 0; i < ONEREC_WEIGHT_COUNT_PER_LAYER; ++i) { + CHECK_THROW(node.inTensors.at(i) == nullptr, + modelName_ << "inTensor " << i << " is NULL"); + node.variantPack.inTensors.at(i) = *node.inTensors.at(i); + } + + // Use common tensor setup function for shared logic + // All common tensors (including encoder_output and cross-attention tensors) + // are handled here + int tensor_idx = setup_common_decoder_tensors(node, + x, + attn_mask, + input_params, + encoder_output, + ONEREC_WEIGHT_COUNT_PER_LAYER); + + // Fill remaining tensors with placeholders (for optional features like lora, + // kv_quant, etc.) 
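+  // At this point tensor_idx covers the 79 layer weights plus: hidden-state
+  // input, attention mask, two KV-cache placeholders, kv_seq_lens, token
+  // offset, layer id, block tables, cache slots, encoder output, three
+  // cross-attention placeholders and the encoder seq_len (see
+  // setup_common_decoder_tensors above).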
+ // Record the number of filled placeholders + int placeholder_count = 0; + while (tensor_idx < node.variantPack.inTensors.size()) { + node.variantPack.inTensors.at(tensor_idx) = placeholder; + node.variantPack.inTensors.at(tensor_idx).hostData = + placeholder_vec_.data(); + tensor_idx++; + placeholder_count++; + } + /* + LOG(INFO) << "[ONEREC DEBUG] total fill " << placeholder_count << " + placeholders " + << tensor_idx << ":" << node.variantPack.inTensors.size(); + */ + + // Final validation: Check for tensors without deviceData + int invalid_tensors = 0; + for (size_t i = 0; i < node.variantPack.inTensors.size(); ++i) { + const auto& tensor = node.variantPack.inTensors.at(i); + if (!tensor.deviceData) { + LOG(ERROR) << "Input tensor[" << i << "] has no deviceData!"; + invalid_tensors++; + } + } + + if (invalid_tensors > 0) { + LOG(ERROR) + << "Found " << invalid_tensors + << " tensors without deviceData, this may cause ATB setup to fail."; + } +} + +void NpuOneRecBlockLayerImpl::resize_experts_weights( + int num_of_device_experts) { + experts_weights_["gate_proj.weight"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight"] = + std::vector(num_of_device_experts); + + // Initialize quantization weights if needed + experts_weights_["gate_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["gate_proj.weight_scale"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight_scale"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight_scale"] = + std::vector(num_of_device_experts); +} + +void NpuOneRecBlockLayerImpl::process_expert_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + std::lock_guard lock(experts_mutex_); + + int expert_id = extract_expert_index(name); + if (expert_id < 0) { + return; + } + + // Calculate local expert index (similar to Qwen3 and DeepSeek V2) + const int local_index = expert_id % num_experts_per_partition_; + + std::string weight_suffix = extract_endswith(name); + + // Map ONEREC weight names to standard names and use 2D indexing + std::string suffix; + if (weight_suffix == "gate_proj.weight" || weight_suffix == "w1.weight") { + suffix = "gate_proj.weight"; + } else if (weight_suffix == "up_proj.weight" || + weight_suffix == "w3.weight") { + suffix = "up_proj.weight"; + } else if (weight_suffix == "down_proj.weight" || + weight_suffix == "w2.weight") { + suffix = "down_proj.weight"; + } else if (weight_suffix == "gate_proj.weight_offset" || + weight_suffix == "w1.weight_offset") { + suffix = "gate_proj.weight_offset"; + } else if (weight_suffix == "up_proj.weight_offset" || + weight_suffix == "w3.weight_offset") { + suffix = "up_proj.weight_offset"; + } else if (weight_suffix == "down_proj.weight_offset" || + weight_suffix == "w2.weight_offset") { + suffix = "down_proj.weight_offset"; + } else if (weight_suffix == "gate_proj.weight_scale" || + weight_suffix == "w1.weight_scale") { + suffix = "gate_proj.weight_scale"; + } else if (weight_suffix == "up_proj.weight_scale" || + weight_suffix == "w3.weight_scale") { + suffix = "up_proj.weight_scale"; + } else if (weight_suffix == "down_proj.weight_scale" || + weight_suffix == "w2.weight_scale") { + suffix = "down_proj.weight_scale"; 
+ } else { + LOG(WARNING) << "[ONEREC WARNING] Unknown expert weight suffix: " + << weight_suffix; + return; + } + + // Use 2D indexing like Qwen3 and DeepSeek V2 + experts_weights_[suffix][local_index] = tensor.clone(); + + LOG(INFO) << "[ONEREC DEBUG] Stored expert " << expert_id + << " (local_index: " << local_index << ") weight: " << suffix + << ", shape: [" << tensor.sizes() << "]"; +} + +void NpuOneRecBlockLayerImpl::process_shared_expert_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + LOG(INFO) << "[ONEREC DEBUG] Processing shared expert weight: " << name + << ", shape: [" << tensor.sizes() << "]"; + + torch::Tensor tmp_tensor = tensor.to(device_); + + // Determine which shared expert weight this is + if (absl::StrContains(name, "gate_proj") || absl::StrContains(name, "w1")) { + shared_expert_gate_weights_.push_back(tmp_tensor); + LOG(INFO) << "[ONEREC DEBUG] Added shared expert gate weight, total: " + << shared_expert_gate_weights_.size(); + } else if (absl::StrContains(name, "up_proj") || + absl::StrContains(name, "w3")) { + shared_expert_up_weights_.push_back(tmp_tensor); + LOG(INFO) << "[ONEREC DEBUG] Added shared expert up weight, total: " + << shared_expert_up_weights_.size(); + } else if (absl::StrContains(name, "down_proj") || + absl::StrContains(name, "w2")) { + shared_expert_down_weights_.push_back(tmp_tensor); + LOG(INFO) << "[ONEREC DEBUG] Added shared expert down weight, total: " + << shared_expert_down_weights_.size(); + } else { + LOG(WARNING) << "[ONEREC WARNING] Unknown shared expert weight type: " + << name; + } +} + +int NpuOneRecBlockLayerImpl::extract_expert_index(const std::string& name) { + // Extract expert index from patterns like "experts.0.w1" or "experts.15.w2" + size_t experts_pos = name.find(".experts."); + if (experts_pos == std::string::npos) { + return -1; + } + + size_t start_pos = experts_pos + 9; // length of ".experts." + size_t end_pos = name.find(".", start_pos); + if (end_pos == std::string::npos) { + return -1; + } + + try { + return std::stoi(name.substr(start_pos, end_pos - start_pos)); + } catch (const std::exception& e) { + LOG(WARNING) << "[ONEREC DEBUG] Failed to extract expert index from: " + << name; + return -1; + } +} + +std::string NpuOneRecBlockLayerImpl::extract_endswith( + const std::string& input) { + // Find the last occurrence of "experts.{number}." + size_t experts_pos = input.find(".experts."); + if (experts_pos == std::string::npos) { + return ""; + } + + // Find the next dot after experts.{number} + size_t start_pos = experts_pos + 9; // length of ".experts." + size_t next_dot = input.find(".", start_pos); + if (next_dot == std::string::npos) { + return ""; + } + + // Extract everything after "experts.{number}." 
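+  // Example: "layer.2.ffn.experts.3.w1.weight" -> "w1.weight"
+  // (extract_expert_index() returns 3 for the same name).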
+ return input.substr(next_dot + 1); +} + +// Implementation of merge_experts_weights functions - reference from Qwen3 +void NpuOneRecBlockLayerImpl::merge_experts_weights() { + LOG(INFO) << "[ONEREC DEBUG] merge_experts_weights begin"; + + // Check if required weights exist + if (experts_weights_.count("gate_proj.weight") == 0 || + experts_weights_.count("up_proj.weight") == 0 || + experts_weights_.count("down_proj.weight") == 0) { + LOG(WARNING) + << "[ONEREC DEBUG] Missing required expert weights, skipping merge"; + return; + } + + LOG(INFO) << "[ONEREC DEBUG] merge gate_proj " + << experts_weights_["gate_proj.weight"].size() + << " and up_proj weights: " + << experts_weights_["up_proj.weight"].size(); + try { + // Convert 2D experts_weights_ to 1D vectors for merging + std::vector gate_weights_1d; + std::vector up_weights_1d; + + // Extract valid tensors from 2D structure + for (const auto& tensor : experts_weights_["gate_proj.weight"]) { + if (tensor.defined()) { + gate_weights_1d.push_back(tensor); + } + } + for (const auto& tensor : experts_weights_["up_proj.weight"]) { + if (tensor.defined()) { + up_weights_1d.push_back(tensor); + } + } + + LOG(INFO) << "[ONEREC DEBUG] Extracted " << gate_weights_1d.size() + << " gate weights and " << up_weights_1d.size() << " up weights"; + + torch::Tensor mlp_gateup_weight; + if (quantize_type_.compare("w8a8_dynamic") == 0) { + LOG(INFO) << "w8a8_dynamic"; + mlp_gateup_weight = merge_experts_weights( + gate_weights_1d, up_weights_1d, /*transpose=*/true); + + if (experts_weights_.count("gate_proj.weight_offset") > 0 && + experts_weights_.count("up_proj.weight_offset") > 0) { + std::vector gate_offset_1d, up_offset_1d; + for (const auto& tensor : experts_weights_["gate_proj.weight_offset"]) { + if (tensor.defined()) gate_offset_1d.push_back(tensor); + } + for (const auto& tensor : experts_weights_["up_proj.weight_offset"]) { + if (tensor.defined()) up_offset_1d.push_back(tensor); + } + at_weight_tensors_[IN_MOE_EXPERT_W1_WEIGHT] = + merge_experts_weights(gate_offset_1d, up_offset_1d); + } + + if (experts_weights_.count("gate_proj.weight_scale") > 0 && + experts_weights_.count("up_proj.weight_scale") > 0) { + std::vector gate_scale_1d, up_scale_1d; + for (const auto& tensor : experts_weights_["gate_proj.weight_scale"]) { + if (tensor.defined()) gate_scale_1d.push_back(tensor); + } + for (const auto& tensor : experts_weights_["up_proj.weight_scale"]) { + if (tensor.defined()) up_scale_1d.push_back(tensor); + } + at_weight_tensors_[IN_MOE_EXPERT_W3_WEIGHT] = + merge_experts_weights(gate_scale_1d, up_scale_1d); + } + } else { + LOG(INFO) << "w8a8_static"; + mlp_gateup_weight = merge_experts_weights( + gate_weights_1d, up_weights_1d, /*transpose=*/false); + } + at_weight_tensors_[IN_MOE_EXPERT_W1_WEIGHT] = + at_npu::native::npu_format_cast(mlp_gateup_weight, 2).contiguous(); + } catch (const std::exception& e) { + LOG(ERROR) << "[ERROR] Exception in gateup weight processing: " << e.what(); + throw; + } + + LOG(INFO) << "[ONEREC DEBUG] merge down_proj " + << experts_weights_["down_proj.weight"].size(); + try { + // Convert 2D down_proj weights to 1D vector + std::vector down_weights_1d; + for (const auto& tensor : experts_weights_["down_proj.weight"]) { + if (tensor.defined()) { + down_weights_1d.push_back(tensor); + } + } + + LOG(INFO) << "[ONEREC DEBUG] Extracted " << down_weights_1d.size() + << " down weights"; + + torch::Tensor mlp_down_weight = + merge_experts_weights(down_weights_1d, /*transpose=*/false); + + 
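+    // Note: mlp_down_weight stacks the per-expert down_proj tensors along a
+    // new leading expert dimension; the npu_format_cast(..., 2) below casts it
+    // to ND layout (format 2 is assumed to be ACL_FORMAT_ND) before binding it
+    // to the ATB expert weight slot.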
at_weight_tensors_[IN_MOE_EXPERT_W2_WEIGHT] = + at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); + + if (quantize_type_.compare("w8a8_dynamic") == 0) { + if (experts_weights_.count("down_proj.weight_offset") > 0) { + std::vector down_offset_1d; + for (const auto& tensor : experts_weights_["down_proj.weight_offset"]) { + if (tensor.defined()) down_offset_1d.push_back(tensor); + } + // Use a different tensor index for offset + at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = + merge_experts_weights(down_offset_1d); + } + if (experts_weights_.count("down_proj.weight_scale") > 0) { + std::vector down_scale_1d; + for (const auto& tensor : experts_weights_["down_proj.weight_scale"]) { + if (tensor.defined()) down_scale_1d.push_back(tensor); + } + // Use a different tensor index for scale + at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = + merge_experts_weights(down_scale_1d); + } + } + } catch (const std::exception& e) { + LOG(ERROR) << "[ERROR] Exception in down weight processing: " << e.what(); + throw; + } + LOG(INFO) << "[ONEREC DEBUG] end merge_experts_weights()"; +} + +torch::Tensor NpuOneRecBlockLayerImpl::merge_experts_weights( + std::vector& experts, + bool transpose) { + LOG(INFO) << "[ONEREC DEBUG] merge_experts_weights, experts size: " + << experts.size(); + torch::Tensor merged_tensor = torch::stack(experts, 0).to(device_); + // Bypass torch::stack operation, generate random tensor with correct shape + // torch::Tensor merged_tensor; + // if (!experts.empty()) { + // // Calculate merged shape based on the first expert's shape + // auto expert_shape = experts[0].sizes().vec(); + // std::vector merged_shape = + // {static_cast(experts.size())}; merged_shape.insert( + // merged_shape.end(), expert_shape.begin(), expert_shape.end()); + + // // Generate random tensor, maintaining data type and device + // merged_tensor = torch::randn(merged_shape, + // torch::TensorOptions() + // .dtype(experts[0].dtype()) + // .device(experts[0].device())) + // .to(device_); + // LOG(INFO) << "[ONEREC DEBUG] Generated random merged tensor with shape: " + // << merged_tensor.sizes(); + // } + + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + merged_tensor = merged_tensor.contiguous(); + experts.clear(); + LOG(INFO) << "[ONEREC DEBUG] merge_experts_weights, return tensor size: " + << merged_tensor.sizes(); + return merged_tensor; +} + +torch::Tensor NpuOneRecBlockLayerImpl::merge_experts_weights( + std::vector& experts_gate, + std::vector& experts_up, + bool transpose) { + LOG(INFO) << "[ONEREC DEBUG] merge_experts_weights, gate size: " + << experts_gate.size() << " up size: " << experts_up.size(); + for (size_t i = 0; i < experts_up.size(); ++i) { + experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); + } + torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device_); + // Bypass memory issues: directly generate a random tensor with correct shape + // torch::Tensor merged_tensor; + + // if (!experts_gate.empty() && !experts_up.empty()) { + // // Calculate the correct merged shape + // auto gate_sizes = experts_gate[0].sizes(); + // auto up_sizes = experts_up[0].sizes(); + + // // Calculate concatenated dimension 0 size (gate + up) + // int64_t concat_dim0 = gate_sizes[0] + up_sizes[0]; + + // // Build final tensor shape: [num_experts, concat_dim0, other_dims...] 
+ // std::vector final_shape; + // final_shape.push_back(experts_gate.size()); // num_experts + // final_shape.push_back(concat_dim0); // dimension 0 of gate + up + // for (int i = 1; i < gate_sizes.size(); ++i) { + // final_shape.push_back(gate_sizes[i]); + // } + + // // Generate random tensor with correct data type and device + // auto dtype = experts_gate[0].dtype(); + // merged_tensor = torch::randn( + // final_shape, torch::TensorOptions().dtype(dtype).device(device_)); + // } + + LOG(INFO) << "[ONEREC DEBUG] Generated random tensor with shape: " + << merged_tensor.sizes(); + + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + merged_tensor = merged_tensor.contiguous(); + experts_gate.clear(); + experts_up.clear(); + LOG(INFO) << "[ONEREC DEBUG] merge_experts_weights, return tensor size: " + << merged_tensor.sizes(); + return merged_tensor; +} + +void NpuOneRecBlockLayerImpl::merge_shared_experts_weights() { + LOG(INFO) << "[ONEREC DEBUG] merge_shared_experts_weights called"; + + // Check if we have shared expert weights to merge + if (shared_expert_gate_weights_.empty() && + shared_expert_up_weights_.empty() && + shared_expert_down_weights_.empty()) { + LOG(INFO) << "[ONEREC DEBUG] No shared expert weights to merge"; + return; + } + + // Merge shared expert gate and up weights (similar to regular experts) + if (!shared_expert_gate_weights_.empty() && + !shared_expert_up_weights_.empty()) { + LOG(INFO) << "[ONEREC DEBUG] Merging shared expert gate and up weights, " + "gate size: " + << shared_expert_gate_weights_.size() + << ", up size: " << shared_expert_up_weights_.size(); + + // Concatenate gate and up weights for shared expert + auto merged_gate_up = merge_experts_weights( + shared_expert_gate_weights_, shared_expert_up_weights_, false); + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT] = merged_gate_up; + + LOG(INFO) << "[ONEREC DEBUG] Shared expert gate+up merged tensor shape: " + << merged_gate_up.sizes(); + } else if (!shared_expert_gate_weights_.empty()) { + // Only gate weights available + auto merged_gate = + merge_experts_weights(shared_expert_gate_weights_, false); + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT] = merged_gate; + + LOG(INFO) << "[ONEREC DEBUG] Shared expert gate merged tensor shape: " + << merged_gate.sizes(); + } + + // Merge shared expert down weights + if (!shared_expert_down_weights_.empty()) { + LOG(INFO) << "[ONEREC DEBUG] Merging shared expert down weights, size: " + << shared_expert_down_weights_.size(); + + auto merged_down = + merge_experts_weights(shared_expert_down_weights_, false); + at_weight_tensors_[IN_MLP_DOWN_WEIGHT_SHARED_EXPERT] = merged_down; + + LOG(INFO) << "[ONEREC DEBUG] Shared expert down merged tensor shape: " + << merged_down.sizes(); + } + + // Clear the temporary storage vectors + shared_expert_gate_weights_.clear(); + shared_expert_up_weights_.clear(); + shared_expert_down_weights_.clear(); + + LOG(INFO) << "[ONEREC DEBUG] merge_shared_experts_weights completed"; +} +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/npu_onerec_block_layer_impl.h b/xllm/core/layers/npu/npu_onerec_block_layer_impl.h new file mode 100644 index 00000000..4a0bc848 --- /dev/null +++ b/xllm/core/layers/npu/npu_onerec_block_layer_impl.h @@ -0,0 +1,167 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "framework/model/model_input_params.h" +#include "framework/model_context.h" +#include "framework/state_dict/state_dict.h" +#include "npu_base_layer.h" +#include "xllm_kernels/core/include/atb_speed/base/hosttensor_binder.h" +#include "xllm_kernels/core/include/atb_speed/base/model.h" +#include "xllm_kernels/core/include/atb_speed/log.h" +#include "xllm_kernels/core/include/atb_speed/utils/model_factory.h" +#include "xllm_kernels/models/onerec/layer/block_layer.h" +#include "xllm_kernels/operations/fusion/utils.h" + +namespace xllm { +namespace layer { + +class NpuOneRecBlockLayerImpl : public NpuBaseLayer { + public: + explicit NpuOneRecBlockLayerImpl(const ModelContext& context, + bool is_decoder = false, + int layer_id = 0); + + ~NpuOneRecBlockLayerImpl() {}; + + virtual void load_state_dict(const StateDict& state_dict) override; + + void verify_loaded_weights(const std::string& prefix) const; + + virtual void merge_loaded_weights() override; + + virtual int64_t init_layer() override; + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + atb::Context* context, + AtbWorkspace& workspace, + std::vector event, + std::vector*> event_flag, + torch::Tensor* encoder_output = nullptr, + int node_id = 0, + const torch::Tensor& expert_array = torch::Tensor()); + + private: + void param_from_args(atb_speed::onerec::BlockLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args, + bool isPrefill, + const ModelInputParams* input_params = nullptr); + + void build_encoder_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + ModelInputParams& input_params, + bool is_prefill, + int layer_id = 0); + + void build_decoder_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill, + torch::Tensor* encoder_output = nullptr, + int layer_id = 0); + + void build_decoder_moe_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill, + torch::Tensor* encoder_output = nullptr, + int layer_id = 0, + const torch::Tensor& expert_array = torch::Tensor()); + + int64_t init_node(atb_speed::Model::Node& node, + atb_speed::onerec::BlockLayerParam& param); + int64_t init_attn_mask(); + int setup_common_decoder_tensors(atb_speed::Model::Node& node, + torch::Tensor& x, + at::Tensor& attn_mask, + ModelInputParams& input_params, + torch::Tensor* encoder_output = nullptr, + int start_tensor_idx = 0); + + // Expert weights processing functions + void resize_experts_weights(int num_of_device_experts); + void process_expert_weights(const StateDict& state_dict, + const std::string& state_key, + const torch::Tensor& tensor); + void process_shared_expert_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + void 
merge_experts_weights(); + void merge_shared_experts_weights(); + torch::Tensor merge_experts_weights(std::vector& experts, + bool transpose = false); + torch::Tensor merge_experts_weights(std::vector& experts_gate, + std::vector& experts_up, + bool transpose = false); + int extract_expert_index(const std::string& name); + std::string extract_endswith(const std::string& input); + + atb_speed::Model::Node prefill_node_; + atb_speed::Model::Node decode_node_; + atb_speed::Model::Node decoder_prefill_only_decode_node_; + std::string modelName_; + atb_speed::onerec::BlockLayerParam prefill_param_; + atb_speed::onerec::BlockLayerParam decode_param_; + atb_speed::onerec::BlockLayerParam decoder_prefill_only_decode_param_; + atb::Tensor internalTensors; + atb::Tensor placeholder; + + at::Tensor encoder_output_contiguous_; // Cache contiguous encoder_output to + // avoid repeated contiguous() calls + at::Tensor at_placeholder; + std::vector seq_lens_vec_; // Store sequence lengths for hostData + std::vector placeholder_vec_; // Store placeholder data for hostData + std::vector encoder_seq_lens_vec_; + + int device_id_; + bool is_decoder_; + int layer_id_; + + // MoE expert weights storage + std::unordered_map> experts_weights_; + std::mutex experts_mutex_; + int start_expert_id_; + int end_expert_id_; + int num_experts_per_partition_; + int ep_size_; + int ep_local_tp_rank_; + int ep_local_tp_size_; + + // Shared expert weights storage + std::vector shared_expert_gate_weights_; + std::vector shared_expert_up_weights_; + std::vector shared_expert_down_weights_; + + // MoE routing tensors + torch::Tensor expert_group_; + torch::Tensor one_hot_; + torch::Tensor zero_hot_; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/onerec_block_layer.h b/xllm/core/layers/onerec_block_layer.h new file mode 100644 index 00000000..c6246ed2 --- /dev/null +++ b/xllm/core/layers/onerec_block_layer.h @@ -0,0 +1,42 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#if defined(USE_NPU) +#include "npu/npu_onerec_block_layer_impl.h" +#endif + +namespace xllm { +namespace layer { + +#if defined(USE_NPU) +class OneRecBlockLayer + : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + using Impl __attribute__((__unused__)) = NpuOneRecBlockLayerImpl; + + OneRecBlockLayer(const ModelContext& context, + bool is_decoder = false, + int layer_id = 0) + : ModuleHolder(std::make_shared(context, + is_decoder, + layer_id)) {} +}; +#endif + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/CMakeLists.txt b/xllm/core/runtime/CMakeLists.txt index 54b10152..ebb43dcd 100644 --- a/xllm/core/runtime/CMakeLists.txt +++ b/xllm/core/runtime/CMakeLists.txt @@ -22,10 +22,12 @@ cc_library( dit_worker.h embed_worker_impl.h embed_vlm_worker_impl.h + rec_worker_impl.h engine.h llm_engine.h vlm_engine.h dit_engine.h + rec_engine.h worker_client.h xservice_client.h speculative_engine.h @@ -40,12 +42,14 @@ cc_library( worker_impl.cpp llm_worker_impl.cpp vlm_worker_impl.cpp + rec_worker_impl.cpp dit_worker.cpp embed_worker_impl.cpp embed_vlm_worker_impl.cpp llm_engine.cpp vlm_engine.cpp dit_engine.cpp + rec_engine.cpp worker_client.cpp xservice_client.cpp params_utils.cpp @@ -88,11 +92,13 @@ cc_library( master.h vlm_master.h dit_master.h + rec_master.h SRCS llm_master.cpp master.cpp vlm_master.cpp dit_master.cpp + rec_master.cpp DEPS :common :distributed_runtime diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index dd4a3d8f..fbce4155 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -37,6 +37,7 @@ class WorkerType { DIT, // DIT ELM, // Embedding LM EVLM, // Embedding VLM + REC, // Rec }; constexpr WorkerType(Value v) : value_(v) {} @@ -51,6 +52,8 @@ class WorkerType { value_ = ELM; } else if (str == "EVLM") { value_ = EVLM; + } else if (str == "REC") { + value_ = REC; } else { value_ = INVALID; } @@ -77,6 +80,8 @@ class WorkerType { return "ELM"; } else if (this->value_ == EVLM) { return "EVLM"; + } else if (this->value_ == REC) { + return "REC"; } else { return "INVALID"; } diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 7ce016a7..9202abc7 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -194,22 +194,6 @@ std::optional LLMWorkerImpl::step( } } - // if running in multi_stream_parallel step, all micro batches - // should be in same prefill stage, so, to judge empty_kv_cache, - // just use micro batch 0 here - if (options_.enable_speculative_decode() && !is_spec_draft_) { - if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) { - output.sample_output.embeddings = hidden_states; - } else if (concated_sampling_params.sample_idxes.defined()) { - // auto sample_idxes = - // concated_sampling_params.selected_token_idxes.index_select( - // /*dim=*/0, concated_sampling_params.sample_idxes); - auto embeddings = hidden_states.index_select( - /*dim=*/0, concated_sampling_params.sample_idxes); - output.sample_output.embeddings = embeddings; - } - } - auto ret = device_.synchronize_default_stream(); if (options_.kv_cache_transfer_mode() == "PUSH" && diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp index 0ade6f56..e9555fab 100644 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -34,6 
+34,8 @@ limitations under the License. #include "runtime/dit_master.h" #include "runtime/llm_engine.h" #include "runtime/llm_master.h" +#include "runtime/rec_engine.h" +#include "runtime/rec_master.h" #include "runtime/speculative_engine.h" #include "runtime/vlm_engine.h" #include "runtime/vlm_master.h" @@ -210,6 +212,39 @@ Master::Master(const Options& options, EngineType type) : options_(options) { eng_options.device_ip(options_.device_ip().value()); } engine_ = std::make_unique(eng_options); + } else if (type == EngineType::REC) { + runtime::Options eng_options; + eng_options.model_path(options_.model_path()) + .devices(devices) + .block_size(options_.block_size()) + .max_cache_size(options_.max_cache_size()) + .max_memory_utilization(options_.max_memory_utilization()) + .enable_prefix_cache(options_.enable_prefix_cache()) + .task_type(options_.task_type()) + .enable_mla(options_.enable_mla()) + .master_node_addr(options_.master_node_addr()) + .nnodes(options_.nnodes()) + .node_rank(options_.node_rank()) + .dp_size(options_.dp_size()) + .ep_size(options_.ep_size()) + .enable_chunked_prefill(options_.enable_chunked_prefill()) + .max_seqs_per_batch(options_.max_seqs_per_batch()) + .max_tokens_per_chunk_for_prefill( + options_.max_tokens_per_chunk_for_prefill()) + .instance_role(options_.instance_role()) + .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) + .transfer_listen_port(options_.transfer_listen_port()) + .enable_disagg_pd(options_.enable_disagg_pd()) + .enable_service_routing(options_.enable_service_routing()) + .enable_schedule_overlap(options_.enable_schedule_overlap()) + .enable_cache_upload(options_.enable_cache_upload()) + .host_blocks_factor(options_.host_blocks_factor()) + .enable_kvcache_store(options_.enable_kvcache_store()) + .store_protocol(options_.store_protocol()) + .store_master_server_entry(options_.store_master_server_entry()) + .store_metadata_connstring(options_.store_metadata_connstring()) + .enable_continuous_kvcache(options_.enable_continuous_kvcache()); + engine_ = std::make_unique(eng_options); } else { LOG(WARNING) << "Not supported llm engine type: " << static_cast(type); @@ -225,6 +260,8 @@ std::unique_ptr create_master(const std::string& backend, } else if (backend == "dit") { LOG(INFO) << "creating dit master"; return std::make_unique(options); + } else if (backend == "rec") { + return std::make_unique(options); } else { LOG(FATAL) << "Failed to create master, backend is" << backend; return nullptr; diff --git a/xllm/core/runtime/rec_engine.cpp b/xllm/core/runtime/rec_engine.cpp new file mode 100644 index 00000000..7df38788 --- /dev/null +++ b/xllm/core/runtime/rec_engine.cpp @@ -0,0 +1,341 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "rec_engine.h" + +#include + +#include +#include + +#include "common/metrics.h" +#include "framework/model/model_args.h" +#include "framework/model_loader.h" +#include "framework/parallel_state/parallel_state.h" +#include "util/pretty_print.h" +#include "util/timer.h" +#include "util/utils.h" +#include "worker.h" + +namespace xllm { + +RecEngine::RecEngine(const runtime::Options& options) : options_(options) { + const auto& devices = options_.devices(); + CHECK_GT(devices.size(), 0) << "At least one device is required"; + + CHECK(!devices[0].is_cpu()) << "CPU device is not supported"; + const auto device_type = devices[0].type(); + for (const auto device : devices) { + CHECK_EQ(device.type(), device_type) + << "All devices should be the same type"; + } + + // initialize process groups if there are multiple devices + if (devices.size() > 1) { + // create a process group for each device if there are multiple gpus + process_groups_ = parallel_state::create_npu_process_groups(devices); + } + + WorkerType worker_type = WorkerType::REC; + const int32_t world_size = static_cast(devices.size()); + for (size_t i = 0; i < devices.size(); ++i) { + const int32_t rank = static_cast(i); + ProcessGroup* pg = world_size > 1 ? process_groups_[i].get() : nullptr; + ParallelArgs parallel_args(rank, world_size, pg); + workers_.emplace_back(std::make_unique( + parallel_args, devices[i], options_, worker_type)); + } + + if (workers_.size() > 1) { + // test process group + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.emplace_back(worker->process_group_test_async()); + } + // wait up to 4 seconds for all futures to complete + folly::collectAll(futures).within(std::chrono::seconds(4)).get(); + } +} + +bool RecEngine::init() { + if (!init_model()) { + LOG(ERROR) << "Failed to init model from: " << options_.model_path(); + return false; + } + + auto kv_cache_cap = estimate_kv_cache_capacity(); + + if (!allocate_kv_cache(kv_cache_cap)) { + LOG(ERROR) << "Failed to allocate kv cache"; + return false; + } + + return true; +} + +bool RecEngine::init_model() { + const std::string& model_path = options_.model_path(); + auto model_loader = ModelLoader::create(model_path); + LOG(INFO) << "Initializing model from: " << model_path; + + // RecEngine does not use tokenizer + tokenizer_ = model_loader->tokenizer(); + CHECK(tokenizer_ != nullptr); + + args_ = model_loader->model_args(); + quant_args_ = model_loader->quant_args(); + tokenizer_args_ = model_loader->tokenizer_args(); + + // compute the number of local kv heads and head dim + const int world_size = static_cast(workers_.size()); + const int64_t n_heads = args_.n_heads(); + const int64_t n_kv_heads = args_.n_kv_heads().value_or(n_heads); + n_local_kv_heads_ = std::max(1, n_kv_heads / world_size); + head_dim_ = args_.head_dim(); + dtype_ = xllm::util::parse_dtype(args_.dtype(), options_.devices()[0]); + + // key + value for all layers + LOG(INFO) << "Block info, block_size: " << options_.block_size() + << ", n_local_kv_heads: " << n_local_kv_heads_ + << ", head_dim: " << head_dim_ << ", n_layers: " << args_.n_layers() + << ", dtype: " << dtype_; + + // RecEngine does not use tokenizer, skip vocab_size check + + LOG(INFO) << "Initializing model with " << args_; + LOG(INFO) << "Initializing model with quant args: " << quant_args_; + LOG(INFO) << "Initializing model with tokenizer args: " << tokenizer_args_; + + // init 
model for each worker in parallel + // multiple workers, call async init + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->init_model_async(model_path)); + } + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (const auto& result : results) { + if (!result.value()) { + return false; + } + } + + return true; +} + +Engine::KVCacheCapacity RecEngine::estimate_kv_cache_capacity() { + const int64_t max_cache_size = options_.max_cache_size(); + const double max_memory_utilization = options_.max_memory_utilization(); + + const auto& device = workers_[0]->device(); + // call worker to profile memory usage + std::vector>> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->estimate_kv_cache_capacity_async()); + } + + // pick smallest available memory from all devices + int64_t cache_size_in_bytes = std::numeric_limits::max(); + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (size_t i = 0; i < results.size(); ++i) { + const auto device = workers_[i]->device(); + if (!results[i].hasValue()) { + LOG(ERROR) << "Failed to profile memory usage for device: " << device; + continue; + } + auto [available_memory, total_memory] = results[i].value(); + LOG(INFO) << device + << ": available memory: " << readable_size(available_memory) + << ", total memory: " << readable_size(total_memory) + << ", Using max_memory_utilization: " << max_memory_utilization + << ", max_cache_size: " << readable_size(max_cache_size); + // apply memory cap from config if it is set + if (max_memory_utilization < 1.0) { + const int64_t buffer_memory = + total_memory * (1.0 - max_memory_utilization); + available_memory -= buffer_memory; + } + if (max_cache_size > 0) { + available_memory = std::min(available_memory, max_cache_size); + } + cache_size_in_bytes = std::min(cache_size_in_bytes, available_memory); + } + + KVCacheCapacity kv_cache_cap; + kv_cache_cap.cache_size_in_bytes = std::max(cache_size_in_bytes, int64_t(0)); + CHECK_GT(kv_cache_cap.cache_size_in_bytes, 0) + << "Available kv cache size must be greater than 0"; + + // compute kv cache slot size + const auto dtype_size = torch::scalarTypeToTypeMeta(dtype_).itemsize(); + // key + value for all layers + const int64_t slot_size = + 2 * n_local_kv_heads_ * head_dim_ * args_.n_layers() * dtype_size; + kv_cache_cap.slot_size = slot_size; + + // compute kv blocks num + const int32_t block_size = options_.block_size(); + const int64_t block_size_in_bytes = block_size * slot_size; + kv_cache_cap.n_blocks = cache_size_in_bytes / block_size_in_bytes; + CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; + + return kv_cache_cap; +} + +bool RecEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { + LOG(INFO) << "kv cache capacity: " + << "bytes: " << kv_cache_cap.cache_size_in_bytes + << ", blocks: " << kv_cache_cap.n_blocks + << ", slot_size: " << kv_cache_cap.slot_size; + + const int32_t block_size = options_.block_size(); + + // init kv cache for each worker + std::vector> kv_cache_shape; + kv_cache_shape.reserve(2); + kv_cache_shape.emplace_back(std::vector{ + kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_}); + kv_cache_shape.emplace_back(std::vector{ + kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_}); + + LOG(INFO) << "Initializing k cache with shape: [" << kv_cache_shape[0] << "]"; + LOG(INFO) << 
"Initializing v cache with shape: [" << kv_cache_shape[1] << "]"; + + // initialize block manager + BlockManagerPool::Options options; + options.num_blocks(kv_cache_cap.n_blocks) + .host_num_blocks(kv_cache_cap.n_blocks) + .block_size(block_size) + .enable_prefix_cache(options_.enable_prefix_cache()) + .enable_disagg_pd(options_.enable_disagg_pd()) + .enable_cache_upload(options_.enable_cache_upload()); + kv_cache_manager_ = std::make_unique(options); + + // init kv cache for each worker in parallel + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->allocate_kv_cache_async(kv_cache_shape)); + } + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + for (const auto& result : results) { + if (!result.value()) { + return false; + } + } + return true; +} + +// RecEngine executes model three times: prefill + 2 decode steps +ForwardOutput RecEngine::step(std::vector& batches) { + if (workers_.empty()) { + // empty worker, return + return {}; + } + + Timer timer; + auto forward_inputs = workers_[0]->prepare_inputs(batches[0]); + COUNTER_ADD(prepare_input_latency_microseconds, timer.elapsed_microseconds()); + + if (!forward_inputs.token_ids.defined()) { + // empty input, just return + return {}; + } + + timer.reset(); + // Prefill step: Run the first model execution + const auto& prefill_output = get_model_output(forward_inputs); + COUNTER_ADD(rec_first_token_latency_microseconds, + timer.elapsed_microseconds()); + + timer.reset(); + batches[0].process_sample_output(prefill_output.sample_output, false); + COUNTER_ADD(rec_sampling_latency_microseconds, timer.elapsed_microseconds()); + + // Decode steps: Run the model 2 more times for decoding + ForwardOutput decode_output; + + for (int i = 0; i < 2; ++i) { + timer.reset(); + forward_inputs = workers_[0]->prepare_inputs(batches[0]); + COUNTER_ADD(prepare_input_latency_microseconds, + timer.elapsed_microseconds()); + + timer.reset(); + decode_output = get_model_output(forward_inputs); + if (i == 0) { + COUNTER_ADD(rec_second_token_latency_microseconds, + timer.elapsed_microseconds()); + } else if (i == 1) { + COUNTER_ADD(rec_third_token_latency_microseconds, + timer.elapsed_microseconds()); + } + + timer.reset(); + batches[0].process_sample_output(decode_output.sample_output, false); + COUNTER_ADD(rec_sampling_latency_microseconds, + timer.elapsed_microseconds()); + } + + batches[0].finish(); + + // Return the final model output + return decode_output; +} + +void RecEngine::update_last_step_result(std::vector& batch) {} + +std::vector RecEngine::get_active_activation_memory() const { + // call worker to get active activation memory + std::vector> futures; + futures.reserve(workers_.size()); + for (auto& worker : workers_) { + futures.push_back(worker->get_active_activation_memory_async()); + } + + // wait for all futures to complete + auto results = folly::collectAll(futures).get(); + std::vector active_activation_memories; + active_activation_memories.reserve(workers_.size()); + for (auto& result : results) { + active_activation_memories.push_back(result.value()); + } + return active_activation_memories; +} + +ForwardOutput RecEngine::get_model_output(const ForwardInput& model_inputs) { + std::vector>> futures; + futures.reserve(workers_.size()); + // TODO to adapt multi stream parallel later + BatchedForwardInputs batched_fwd_inputs; + batched_fwd_inputs.micro_inputs = {model_inputs}; + for (auto& worker : workers_) { + 
futures.emplace_back(worker->step_async(batched_fwd_inputs)); + } + // wait for the all future to complete + auto results = folly::collectAll(futures).get(); + // return the result from the driver + auto forward_output = results.front().value(); + + DCHECK(forward_output.has_value()) << "Failed to execute model"; + return forward_output.value(); +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_engine.h b/xllm/core/runtime/rec_engine.h new file mode 100644 index 00000000..6b156b06 --- /dev/null +++ b/xllm/core/runtime/rec_engine.h @@ -0,0 +1,80 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +#include "common/macros.h" +#include "engine.h" +#include "framework/batch/batch.h" +#include "framework/block/block_manager_pool.h" +#include "framework/quant_args.h" +#include "framework/tokenizer/tokenizer.h" +#include "framework/tokenizer/tokenizer_args.h" +#include "worker.h" + +namespace xllm { + +class RecEngine : public Engine { + public: + // create an engine with the given devices + RecEngine(const runtime::Options& options); + + virtual ~RecEngine() = default; + + ForwardOutput step(std::vector& batch) override; + + const runtime::Options& options() const { return options_; } + + bool init() override; + + void update_last_step_result(std::vector& batch) override; + + // return the active activation memory + std::vector get_active_activation_memory() const override; + + private: + bool init_model(); + Engine::KVCacheCapacity estimate_kv_cache_capacity(); + bool allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap); + + // Helper methods for rec-specific execution + ForwardOutput get_model_output(const ForwardInput& model_inputs); + + private: + // options + runtime::Options options_; + + // dtype + torch::ScalarType dtype_; + + // quantization args + QuantArgs quant_args_; + + // a list of process groups, with each process group handling a single device + std::vector> process_groups_; + + // a list of workers, with each worker handling a partial of model + std::vector> workers_; + + // config for kv cache + int64_t n_local_kv_heads_ = 0; + int64_t head_dim_ = 0; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_master.cpp b/xllm/core/runtime/rec_master.cpp new file mode 100644 index 00000000..748593e6 --- /dev/null +++ b/xllm/core/runtime/rec_master.cpp @@ -0,0 +1,268 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rec_master.h" + +#include +#include +#include + +#include "absl/time/time.h" +#include "models/model_registry.h" +#include "runtime/rec_engine.h" +#include "runtime/xservice_client.h" +#include "scheduler/scheduler_factory.h" +#include "util/scope_guard.h" +#include "util/threadpool.h" +#include "util/utils.h" + +namespace xllm { + +RecMaster::RecMaster(const Options& options) + : Master(options, EngineType::REC) { + // Initialize with Rec engine type + // The rest of the initialization follows the same pattern as LLMMaster + CHECK(engine_->init()); + + model_args_ = engine_->model_args(); + + bool enable_decode_response_to_service = false; + if (options_.enable_service_routing()) { + XServiceClient* xservice_client = XServiceClient::get_instance(); + if (!xservice_client->init(options_.etcd_addr().value_or(""), + options_.xservice_addr().value_or(""), + options_.instance_name().value_or(""), + engine_->block_manager_pool())) { + LOG(FATAL) << "XServiceClient init fail!"; + return; + } + auto service_config = xservice_client->get_config(); + enable_decode_response_to_service = + service_config.enable_decode_response_to_service; + } + + ContinuousScheduler::Options scheduler_options; + scheduler_options.max_tokens_per_batch(options_.max_tokens_per_batch()) + .max_seqs_per_batch(options_.max_seqs_per_batch()) + .max_tokens_per_chunk_for_prefill( + options_.max_tokens_per_chunk_for_prefill()) + .num_speculative_tokens(options_.num_speculative_tokens()) + .dp_size(options_.dp_size()) + .enable_disagg_pd(options_.enable_disagg_pd()) + .enable_schedule_overlap(options_.enable_schedule_overlap()) + .enable_chunked_prefill(options_.enable_chunked_prefill()) + .instance_role(options_.instance_role()) + .kv_cache_transfer_mode(options_.kv_cache_transfer_mode()) + .enable_service_routing(options_.enable_service_routing()) + .enable_decode_response_to_service(enable_decode_response_to_service); + scheduler_ = create_fixsteps_scheduler(engine_.get(), scheduler_options); + + // OmniRec model does not have a tokenizer + chat_template_ = nullptr; + tokenizer_ = nullptr; + threadpool_ = + std::make_unique(options_.num_request_handling_threads()); +} + +void RecMaster::run() { + const bool already_running = running_.load(std::memory_order_relaxed); + if (already_running) { + LOG(WARNING) << "RecMaster is already running."; + return; + } + running_.store(true, std::memory_order_relaxed); + loop_thread_ = std::thread([this]() { + const auto timeout = absl::Milliseconds(5); + while (!stopped_.load(std::memory_order_relaxed)) { + // move scheduler forward + scheduler_->step(timeout); + } + running_.store(false, std::memory_order_relaxed); + }); + + // Engine run method is not available, remove this call +} + +RecMaster::~RecMaster() { + // set stop flag + stopped_.store(true, std::memory_order_relaxed); + // wait for the loop thread to finish + if (loop_thread_.joinable()) { + loop_thread_.join(); + } +} + +void RecMaster::handle_request(std::string prompt, + std::optional> prompt_tokens, + std::optional 
mm_data, + RequestParams sp, + OutputCallback callback) { + // add one pending request + scheduler_->incr_pending_requests(1); + auto cb = [callback = std::move(callback), + scheduler = scheduler_.get()](const RequestOutput& output) { + output.log_request_status(); + return callback(output); + }; + // add into the queue + threadpool_->schedule([this, + prompt = std::move(prompt), + prompt_tokens = std::move(prompt_tokens), + mm_data = std::move(mm_data), + sp = std::move(sp), + callback = std::move(cb)]() mutable { + AUTO_COUNTER(request_handling_latency_seconds_completion); + + // remove the pending request after scheduling + SCOPE_GUARD([this] { scheduler_->decr_pending_requests(); }); + + Timer timer; + // verify the prompt + if (!sp.verify_params(callback)) { + return; + } + + auto request = generate_request(std::move(prompt), + std::move(prompt_tokens), + std::move(mm_data), + sp, + callback); + if (!request) { + return; + } + + if (!scheduler_->add_request(request)) { + CALLBACK_WITH_ERROR(StatusCode::RESOURCE_EXHAUSTED, + "No available resources to schedule request"); + } + }); +} + +std::shared_ptr RecMaster::generate_request( + std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback) { + // For Rec model, prompt is expected to be empty and prompt_tokens should + // contain the actual data Skip prompt empty check as mentioned in + // requirements + + Timer timer; + std::vector local_prompt_tokens; + + if (prompt_tokens.has_value()) { + local_prompt_tokens = std::move(prompt_tokens.value()); + LOG(INFO) + << "[Rec DEBUG] generate_request - received prompt_tokens.size(): " + << local_prompt_tokens.size() + << ", prompt.length(): " << prompt.length(); + } else if (!mm_data.has_value()) { + // sparse LLM + LOG(ERROR) << "Rec model requires prompt_tokens/embedding to be provided"; + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "Rec model requires prompt_tokens/embedding to be provided"); + return nullptr; + } + + COUNTER_ADD(tokenization_latency_seconds, timer.elapsed_seconds()); + + int32_t max_context_len = model_args_.max_position_embeddings(); + if (!options_.enable_chunked_prefill()) { + max_context_len = + std::min(max_context_len, options_.max_tokens_per_batch()); + } + if (local_prompt_tokens.size() >= max_context_len) { + LOG(ERROR) << "Prompt is too long: " << local_prompt_tokens.size(); + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Prompt is too long"); + return nullptr; + } + + uint32_t max_tokens = sp.max_tokens; + if (max_tokens == 0) { + const uint32_t kDefaultMaxTokens = 5120; + max_tokens = kDefaultMaxTokens; + } + + // allocate enough capacity for prompt tokens, max tokens, and speculative + // tokens + size_t capacity = local_prompt_tokens.size() + max_tokens + + options_.num_speculative_tokens() + /*bonus_token*/ 1; + if (options_.enable_schedule_overlap()) { + capacity += options_.num_speculative_tokens() + 1; + } + const size_t best_of = sp.best_of.value_or(sp.n); + + RequestSamplingParam sampling_param; + sampling_param.frequency_penalty = sp.frequency_penalty; + sampling_param.presence_penalty = sp.presence_penalty; + sampling_param.repetition_penalty = sp.repetition_penalty; + sampling_param.temperature = sp.temperature; + sampling_param.top_p = sp.top_p; + sampling_param.top_k = sp.top_k; + sampling_param.logprobs = sp.logprobs; + sampling_param.top_logprobs = sp.top_logprobs; + sampling_param.is_embeddings = sp.is_embeddings; + sampling_param.beam_width = sp.beam_width; + if 
(best_of > sp.n) { + // enable logprobs for best_of to generate sequence logprob + sampling_param.logprobs = true; + } + // sampling_param.do_sample = sp.do_sample; + + bool stream = sp.streaming; + // results cannot be streamed when best_of != n + if (best_of != sp.n) { + stream = false; + } + // std::unordered_set stop_tokens; + // std::vector> stop_sequences; + // StoppingChecker stopping_checker( + // max_tokens, + // max_context_len - options_.num_speculative_tokens(), + // , + // model_args_.eos_token_id(), + // sp.ignore_eos, + // std::move(stop_tokens), + // std::move(stop_sequences)); + StoppingChecker stopping_checker; + RequestState req_state(std::move(prompt), + std::move(local_prompt_tokens), + mm_data.value_or(MMData{}), + std::move(sampling_param), + std::move(stopping_checker), + capacity, + sp.n, + best_of, + sp.logprobs, + stream, + sp.echo, + sp.skip_special_tokens, + options_.enable_schedule_overlap(), + callback, + nullptr, + sp.decode_address); + req_state.is_rec_model = true; + req_state.bos_token_id = model_args_.bos_token_id(); + auto request = std::make_shared(sp.request_id, + sp.x_request_id, + sp.x_request_time, + std::move(req_state), + sp.service_request_id); + return request; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_master.h b/xllm/core/runtime/rec_master.h new file mode 100644 index 00000000..60d20c42 --- /dev/null +++ b/xllm/core/runtime/rec_master.h @@ -0,0 +1,71 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include + +#include "framework/chat_template/jinja_chat_template.h" +#include "framework/model/model_args.h" +#include "runtime/master.h" +#include "runtime/rec_engine.h" +#include "scheduler/continuous_scheduler.h" +#include "scheduler/fixsteps_scheduler.h" +#include "util/threadpool.h" + +namespace xllm { + +class RecMaster : public Master { + public: + explicit RecMaster(const Options& options); + ~RecMaster(); + + // handle a request, the engine will execute the request asynchronously + // completion/encode + void handle_request(std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback); + + // start the handling loop + void run() override; + + private: + std::shared_ptr generate_request( + std::string prompt, + std::optional> prompt_tokens, + std::optional mm_data, + RequestParams sp, + OutputCallback callback); + + std::unique_ptr scheduler_; + // model args + ModelArgs model_args_; + std::unique_ptr threadpool_; + std::unique_ptr tokenizer_; + // chat template instance + std::unique_ptr chat_template_; + // thread for moving forward the scheduler + std::thread loop_thread_; + // flag to stop the loop + std::atomic stopped_{false}; + + // flag to indicate if the handler is running + std::atomic running_{false}; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_worker_impl.cpp b/xllm/core/runtime/rec_worker_impl.cpp new file mode 100644 index 00000000..b930ef40 --- /dev/null +++ b/xllm/core/runtime/rec_worker_impl.cpp @@ -0,0 +1,363 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "rec_worker_impl.h" + +#include + +#include +#include +#include + +#include "butil/file_util.h" +#include "butil/files/dir_reader_linux.h" +#include "butil/files/file_path.h" +#include "butil/strings/string_util.h" +#include "common/metrics.h" +#include "models/model_registry.h" +#include "util/env_var.h" +#include "util/utils.h" + +namespace xllm { + +RecWorkerImpl::RecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options) + : WorkerImpl(parallel_args, device, options) { + // Initialize filter mask stream for H2D operations + filter_mask_stream_ = device_.get_stream_from_pool(); + + // Initialize thread pool for async operations using environment variable + int thread_num = util::get_int_env(util::EXTRA_THREAD_NUM, 16); + thread_pool_ = std::make_shared(thread_num); +} + +bool RecWorkerImpl::init_model(const std::string& model_weights_path) { + auto model_loader = ModelLoader::create(model_weights_path); + + auto args = model_loader->model_args(); + auto quant_args = model_loader->quant_args(); + torch::ScalarType dtype = util::parse_dtype(args.dtype(), device_); + + if (options_.enable_speculative_decode() && FLAGS_enable_atb_spec_kernel) { + args.num_speculative_tokens(options_.num_speculative_tokens()); + } + + // create model context + dtype_ = dtype; + auto tensor_options = torch::dtype(dtype_).device(device_); + context_ = ModelContext(parallel_args_, args, quant_args, tensor_options); + + // init model, create model executor + bool status = this->init_model(context_); + if (!status) { + return false; + } + + this->load_model(std::move(model_loader)); + + status_ = Status::LOADED; + // TODO: replace path with flags after filter merge + butil::FilePath filter_bin_path = + butil::FilePath(model_weights_path).Append("replace me when merge"); + valid_path_filter_ = std::make_unique( + filter_bin_path.value(), args.vocab_size(), dtype_, device_); + + return true; +} + +bool RecWorkerImpl::init_model(ModelContext& context) { + CHECK(model_ == nullptr) << "Model is already initialized."; + device_.set_device(); + + // Try to create a causal LM model (Rec models are typically based on + // CausalLM) + model_ = create_llm_model(context); + + // Check if model creation was successful + CHECK(model_ != nullptr) << "Failed to create Rec model."; + model_executor_ = std::make_unique( + model_.get(), context.get_model_args(), device_, options_); + + if (FLAGS_enable_beam_search_kernel) { + beam_searcher_ = std::make_unique(); + } + return true; +} + +std::optional RecWorkerImpl::step( + const BatchedForwardInputs& inputs) { + device_.set_device(); + + // Timer for performance monitoring + auto start_time = std::chrono::high_resolution_clock::now(); + + std::vector flatten_tokens_micro_batches; + std::vector flatten_positions_micro_batches; + std::vector input_params_micro_batches; + auto& concated_sampling_params = inputs.concated_sampling_params; + + for (auto i = 0; i < inputs.micro_inputs.size(); ++i) { + flatten_tokens_micro_batches.push_back( + std::move(inputs.micro_inputs[i].token_ids)); + flatten_positions_micro_batches.push_back( + std::move(inputs.micro_inputs[i].positions)); + input_params_micro_batches.push_back( + std::move(inputs.micro_inputs[i].input_params)); + } + auto sampling_params = inputs.micro_inputs[0].sampling_params; + // Start async filter mask preparation early for overlap (if beam search is + // enabled) + std::future 
filter_mask_future;
+
+  if (!input_params_micro_batches.empty() &&
+      input_params_micro_batches[0].is_rec_model() &&
+      input_params_micro_batches[0].rec_params.has_value()) {
+    auto& rec_params = input_params_micro_batches[0].rec_params.value();
+    if (!rec_params.generated_tokens.empty()) {
+      filter_mask_future =
+          prepare_filter_mask_async(rec_params.generated_tokens);
+    }
+  }
+
+  // Check if we have encoder inputs (rec model with encoder/decoder)
+  torch::Tensor hidden_states;
+  bool has_encoder_inputs = false;
+
+  // Check if this is a rec model with encoder inputs
+  if (!input_params_micro_batches.empty() &&
+      input_params_micro_batches[0].is_rec_model() &&
+      input_params_micro_batches[0].rec_params.has_value()) {
+    auto& rec_params = input_params_micro_batches[0].rec_params.value();
+
+    // Check for encoder inputs
+    if ((rec_params.encoder_token_ids.defined() &&
+         rec_params.encoder_positions.defined()) ||
+        rec_params.encoder_sparse_embedding.defined()) {
+      has_encoder_inputs = true;
+
+      // Set hybrid mode if sparse embedding is defined
+      if (rec_params.encoder_sparse_embedding.defined()) {
+        input_params_micro_batches[0].rec_params->is_hybrid_mode = true;
+      }
+    }
+  }
+
+  // Two-stage forward: encoder then decoder
+  auto& rec_params = input_params_micro_batches[0].rec_params.value();
+
+  if (rec_params.rec_stage == RecModelInputParams::RecStage::PREFILL) {
+    // Check if this is the first prefill or a subsequent prefill
+    if (!rec_params.is_first_prefill) {
+      // Subsequent prefill: only run the decoder
+      input_params_micro_batches[0].rec_params->is_encoder_forward = false;
+      hidden_states = model_executor_->forward(flatten_tokens_micro_batches,
+                                               flatten_positions_micro_batches,
+                                               kv_caches_,
+                                               input_params_micro_batches);
+    } else if (has_encoder_inputs) {
+      // First prefill: run the encoder first, then the decoder
+
+      // 1. Run encoder forward
+      auto encoder_input_params = input_params_micro_batches;
+      encoder_input_params[0].rec_params->is_encoder_forward = true;
+
+      std::vector<torch::Tensor> encoder_tokens;
+      std::vector<torch::Tensor> encoder_positions;
+
+      if (rec_params.is_hybrid_mode &&
+          rec_params.encoder_sparse_embedding.defined()) {
+        encoder_tokens.push_back(rec_params.encoder_sparse_embedding);
+      } else {
+        encoder_tokens.push_back(rec_params.encoder_token_ids);
+      }
+      encoder_positions.push_back(rec_params.encoder_positions);
+
+      // Run encoder
+      hidden_states = model_executor_->forward(
+          encoder_tokens, encoder_positions, kv_caches_, encoder_input_params);
+
+      // 2. Run decoder forward
+      encoder_input_params[0].rec_params->is_encoder_forward = false;
+      hidden_states = model_executor_->forward(flatten_tokens_micro_batches,
+                                               flatten_positions_micro_batches,
+                                               kv_caches_,
+                                               encoder_input_params);
+
+    } else {
+      // Rec model without encoder inputs (decoder-only): not supported yet
+      LOG(ERROR) << "RecWorkerImpl does not support decoder-only models yet.";
+    }
+  } else {
+    // Decode stage: run the decoder only (not used in the current flow).
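+    // The decoder consumes the flattened decode tokens/positions of every
+    // micro batch and reuses the already-populated KV caches.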
+ hidden_states = model_executor_->forward(flatten_tokens_micro_batches, + flatten_positions_micro_batches, + kv_caches_, + input_params_micro_batches); + } + + torch::Tensor logits; + if (sampling_params.selected_token_idxes.defined()) { + logits = + model_->logits(hidden_states, sampling_params.selected_token_idxes); + } + + ForwardOutput output; + + if (!driver_) { + return std::nullopt; + } + + // Get filter mask result from async preparation if available + torch::Tensor filter_mask; + if (filter_mask_future.valid()) { + // Get the result from async preparation (this will block if not ready) + filter_mask = filter_mask_future.get(); + } + + // Driver prepare model output + + if (sampling_params.selected_token_idxes.defined()) { + // auto sample_logits = + // logits.index_select(/*dim=*/0, + // concated_sampling_params.sample_idxes); + + // Apply filter mask if available + // TODO: fix filter + // if (filter_mask.defined()) { + // // Ensure filter_mask has the same batch size as sample_logits + // if (filter_mask.size(0) == sample_logits.size(0)) { + // sample_logits = sample_logits + filter_mask; + // } else { + // // If dimensions don't match, select the appropriate rows from + // // filter_mask + // auto selected_filter_mask = filter_mask.index_select( + // /*dim=*/0, concated_sampling_params.sample_idxes); + // sample_logits = sample_logits + selected_filter_mask; + // } + // } + + auto sample_output = sampler_->forward(logits, sampling_params); + output.logits = logits; + + // Set sample output to output + output.sample_output = sample_output; + + // Carry over the sampling params + output.do_sample = sampling_params.do_sample; + output.logprobs = sampling_params.logprobs; + output.max_top_logprobs = sampling_params.max_top_logprobs; + } + + // Transfer sample output tensors to CPU for batch.cpp access + if (output.sample_output.next_tokens.defined()) { + output.sample_output.next_tokens = + safe_to(output.sample_output.next_tokens, torch::kCPU, true); + } + if (output.sample_output.logprobs.defined()) { + output.sample_output.logprobs = + safe_to(output.sample_output.logprobs, torch::kCPU, true); + } + if (output.sample_output.top_tokens.defined()) { + output.sample_output.top_tokens = + safe_to(output.sample_output.top_tokens, torch::kCPU, true); + } + if (output.sample_output.top_logprobs.defined()) { + output.sample_output.top_logprobs = + safe_to(output.sample_output.top_logprobs, torch::kCPU, true); + } + + // Synchronize at the end like in llm_worker_impl + auto ret = device_.synchronize_default_stream(); + + // Record execution latency + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time); + COUNTER_ADD(execution_latency_seconds_model, duration.count() / 1000000.0); + + return output; +} + +ForwardInput RecWorkerImpl::prepare_inputs(Batch& batch) { + // Use the rec-specific input preparation method + return batch.prepare_rec_forward_input(options_.num_decoding_tokens(), + 0, // min_decoding_batch_size + context_.get_model_args()); +} + +std::future RecWorkerImpl::prepare_filter_mask_async( + const std::vector>& generated_tokens) { + // Create promise/future pair for async result + auto promise = std::make_shared>(); + auto future = promise->get_future(); + + // Submit async task to thread pool + thread_pool_->schedule([this, generated_tokens, promise]() -> void { + try { + // Set stream guard for H2D operations + c10::StreamGuard streamGuard = filter_mask_stream_->set_stream_guard(); + + 
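+      // The mask is assembled on the CPU first and then copied to the device
+      // with a non-blocking H2D transfer on filter_mask_stream_, which is
+      // synchronized before the result is handed back through the promise.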
torch::Tensor cpu_mask; + + // Use ValidPathFilter if available, otherwise create placeholder mask + if (valid_path_filter_ && !generated_tokens.empty()) { + // Use ValidPathFilter to generate the actual filter mask + cpu_mask = valid_path_filter_->forward(generated_tokens); + + // If ValidPathFilter returns empty tensor, create placeholder + if (!cpu_mask.defined()) { + int batch_size = generated_tokens.size(); + int vocab_size = 32000; // Default vocab size + cpu_mask = torch::zeros({batch_size, vocab_size}, torch::kFloat32); + } + } else if (!generated_tokens.empty()) { + // Fallback: create placeholder mask when ValidPathFilter is not + // available + int batch_size = generated_tokens.size(); + int vocab_size = 32000; // Default vocab size + cpu_mask = torch::zeros({batch_size, vocab_size}, torch::kFloat32); + + // Apply some basic filtering logic (placeholder) + for (int i = 0; i < batch_size; ++i) { + // Set some tokens to -inf to filter them out + cpu_mask[i] + .slice(0, 0, 1000) + .fill_(-std::numeric_limits::infinity()); + } + } else { + // Return empty tensor if no generated tokens + promise->set_value(torch::Tensor()); + return; + } + + // Copy to device using the dedicated H2D stream + torch::Tensor device_mask = cpu_mask.to(device_, /*non_blocking=*/true); + + // Synchronize the H2D stream to ensure copy is complete + filter_mask_stream_->synchronize(); + + // Set the result in the promise + promise->set_value(device_mask); + } catch (const std::exception& e) { + // Set exception in promise if something goes wrong + promise->set_exception(std::current_exception()); + } + }); + + return future; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/rec_worker_impl.h b/xllm/core/runtime/rec_worker_impl.h new file mode 100644 index 00000000..9c5739fc --- /dev/null +++ b/xllm/core/runtime/rec_worker_impl.h @@ -0,0 +1,76 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include +#include + +#include "framework/batch/batch.h" +#include "framework/model_context.h" +#include "framework/sampling/valid_path_filter.h" +#include "platform/stream.h" +#include "runtime/forward_params.h" +#include "util/threadpool.h" +#include "worker_impl.h" + +namespace xllm { + +// Rec specific worker implementation +class RecWorkerImpl : public WorkerImpl { + public: + RecWorkerImpl(const ParallelArgs& parallel_args, + const torch::Device& device, + const runtime::Options& options); + + // Override init_model for Rec specific implementation + bool init_model(const std::string& model_weights_path) override; + + // Override init_model with ModelContext for Rec specific implementation + bool init_model(ModelContext& context) override; + + // Override step for Rec specific implementation + std::optional step( + const BatchedForwardInputs& inputs) override; + + // Override prepare_inputs for Rec specific implementation + ForwardInput prepare_inputs(Batch& batch) override; + + private: + // Helper method for filter mask preparation (placeholder for future + // implementation) + torch::Tensor prepare_filter_mask( + const std::vector>& generated_tokens); + + // Async filter mask preparation with overlap + std::future prepare_filter_mask_async( + const std::vector>& generated_tokens); + + // Stream for H2D memory copy operations + std::unique_ptr filter_mask_stream_; + + // ThreadPool for async operations + std::shared_ptr thread_pool_; + + // ValidPathFilter for beam search filtering + std::unique_ptr valid_path_filter_; + + // BeamSearcher for beam search functionality + std::unique_ptr beam_searcher_; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/worker.cpp b/xllm/core/runtime/worker.cpp index 3ab9b6e6..b641ebe8 100644 --- a/xllm/core/runtime/worker.cpp +++ b/xllm/core/runtime/worker.cpp @@ -32,6 +32,7 @@ limitations under the License. #include "runtime/embed_vlm_worker_impl.h" #include "runtime/embed_worker_impl.h" #include "runtime/llm_worker_impl.h" +#include "runtime/rec_worker_impl.h" #include "runtime/speculative_worker_impl.h" #include "runtime/vlm_worker_impl.h" #include "util/timer.h" @@ -51,6 +52,8 @@ Worker::Worker(const ParallelArgs& parallel_args, impl_ = new EmbedWorkerImpl(parallel_args, device, options); } else if (worker_type == WorkerType::EVLM) { impl_ = new EmbedVLMWorkerImpl(parallel_args, device, options); + } else if (worker_type == WorkerType::REC) { + impl_ = new RecWorkerImpl(parallel_args, device, options); } else { LOG(ERROR) << "Unknown worker type, please check logic"; } diff --git a/xllm/core/scheduler/CMakeLists.txt b/xllm/core/scheduler/CMakeLists.txt index d694b3b1..999c6ce1 100644 --- a/xllm/core/scheduler/CMakeLists.txt +++ b/xllm/core/scheduler/CMakeLists.txt @@ -17,6 +17,7 @@ cc_library( scheduler.h dit_scheduler.h prefill_only_scheduler.h + fixsteps_scheduler.h scheduler_factory.h decode_priority_queue.h perf_model.h @@ -27,6 +28,7 @@ cc_library( disagg_pd_scheduler.cpp pd_ooc_scheduler.cpp async_response_processor.cpp + fixsteps_scheduler.cpp dit_scheduler.cpp prefill_only_scheduler.cpp scheduler_factory.cpp diff --git a/xllm/core/scheduler/fixsteps_scheduler.cpp b/xllm/core/scheduler/fixsteps_scheduler.cpp new file mode 100644 index 00000000..ccd076ee --- /dev/null +++ b/xllm/core/scheduler/fixsteps_scheduler.cpp @@ -0,0 +1,309 @@ +/* Copyright 2025 The xLLM Authors. 
All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "fixsteps_scheduler.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include "common/metrics.h"
+#include "framework/batch/batch.h"
+#include "framework/batch/batch_factory.h"
+#include "framework/request/request.h"
+#include "framework/request/sequence.h"
+#include "runtime/engine.h"
+
+namespace xllm {
+
+namespace {
+constexpr size_t kRequestQueueSize = 100000;
+}  // namespace
+
+FixStepsScheduler::FixStepsScheduler(Engine* engine, const Options& options)
+    : ContinuousScheduler(engine, options) {}
+
+bool FixStepsScheduler::add_request(std::shared_ptr<Request>& request) {
+  CHECK(request != nullptr);
+  CHECK(!request->sequences().empty());
+
+  if (request_queue_.write(request)) {  //.get()
+    // take over the ownership of the request
+    // request.release();
+    return true;
+  }
+  // queue is full
+  return false;
+}
+
+void FixStepsScheduler::handle_prefill_requests(
+    size_t& remaining_token_budget,
+    size_t& remaining_seq_budget,
+    std::vector<std::shared_ptr<Request>>& finished_requests) {
+  // Handle new request prompt first.
+  // Include those requests that are preempted by others.
+  //
+  // Schedule the prefill requests in the waiting priority queue until budgets
+  // are exhausted.
+  // When the KV Cache usage reaches the threshold, prefill requests will no
+  // longer be scheduled to avoid frequent preemption.
+  //
+  // NOTE: preempted requests will be pushed into waiting_priority_queue,
+  // they may contain many sequences, so we should check here.
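+  // The loop below admits whole requests from the waiting queue: every
+  // unfinished sequence of a request must fit into the remaining token and
+  // sequence budgets, otherwise the request stays in the queue for a later
+  // scheduling step.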
+  bool budget_exhausted = false;
+  bool blocks_exhausted = false;
+  while (!waiting_priority_queue_.empty() && remaining_seq_budget > 0 &&
+         remaining_token_budget > 0 &&
+         kv_cache_manager_->kv_cache_utilization() <
+             FLAGS_prefill_scheduling_memory_usage_threshold) {
+    std::shared_ptr<Request> request(waiting_priority_queue_.top());
+    if (request->finished() || request->cancelled()) {
+      // kv_cache_manager_->deallocate(request.get());
+      // release the ownership of the request
+      finished_requests.emplace_back(request);
+      // remove the request from the priority queue
+      waiting_priority_queue_.pop();
+      continue;
+    }
+
+    const size_t num_sequences = request->sequences().size();
+    if (!request->preempted()) {
+      CHECK(num_sequences == 1)
+          << "Waiting request should have only one sequence.";
+    }
+
+    // TODO: FIXME later
+    // Optimization of the scheduling algorithm under multiple sequences
+    size_t allocated_tokens = 0;
+    size_t allocated_seqs = 0;
+    double allocated_estimate_latency = 0;
+    bool can_schedule = true;
+    std::vector<Sequence*> prefill_sequences;
+    std::vector<size_t> prefill_sequences_budget;
+    prefill_sequences.reserve(request->sequences().size());
+    prefill_sequences_budget.reserve(request->sequences().size());
+    for (auto& prefill_sequence : request->sequences()) {
+      if (prefill_sequence->finished()) {
+        continue;
+      }
+
+      size_t num_tokens = prefill_sequence->num_need_compute_tokens();
+      if (remaining_token_budget < allocated_tokens + num_tokens ||
+          remaining_seq_budget < allocated_seqs + 1) {
+        can_schedule = false;
+        budget_exhausted = true;
+        break;
+      }
+
+      prefill_sequences_budget.emplace_back(num_tokens);
+      prefill_sequences.emplace_back(prefill_sequence.get());
+      allocated_tokens += num_tokens;
+      allocated_seqs += 1;
+    }
+
+    if (!can_schedule) {
+      for (auto& seq : prefill_sequences) {
+        // release shared blocks
+        kv_cache_manager_->deallocate(seq);
+      }
+      break;
+    }
+
+    if (prefill_sequences.empty()) {
+      continue;
+    }
+
+    remaining_token_budget -= allocated_tokens;
+    remaining_seq_budget -= allocated_seqs;
+    waiting_priority_queue_.pop();
+    running_requests_.emplace_back(request);
+    running_sequences_.insert(running_sequences_.end(),
+                              prefill_sequences.begin(),
+                              prefill_sequences.end());
+    running_sequences_budgets_.insert(running_sequences_budgets_.end(),
+                                      prefill_sequences_budget.begin(),
+                                      prefill_sequences_budget.end());
+  }
+
+  if (running_sequences_.empty() && !waiting_priority_queue_.empty() &&
+      running_queue_->empty()) {
+    LOG(ERROR)
+        << "Request prompt is too long, not enough budget/memory to schedule "
+           "a single sequence.";
+    // not enough memory to schedule a single sequence, just finish the request
+    std::shared_ptr<Request> request(waiting_priority_queue_.top());
+    waiting_priority_queue_.pop();
+    // block_manager_->release_blocks_for(request.get());
+    response_processor_->process_failed_request(
+        request,
+        {StatusCode::RESOURCE_EXHAUSTED,
+         "Not enough budget to schedule a single sequence."});
+  }
+}
+
+std::vector<Batch> FixStepsScheduler::prepare_batch() {
+  Timer timer;
+  // propagate new requests to waiting_priority_queue_
+  // Include those requests that are preempted by others.
+  std::shared_ptr<Request> request;
+  // read from request queue then push to waiting priority queue
+  while (request_queue_.read(request)) {
+    CHECK(request);
+
+    // expand sequences to the target number if prefix cache is disabled.
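+    // Expanding here, before scheduling, lets every sequence of the request
+    // be budgeted and batched individually in handle_prefill_requests().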
+ if (!enable_prefix_cache_) { + // expand sequences to the target number + request->expand_sequences(false); + } + + if (request->sequences()[0]->kv_state().kv_cache_tokens_num() == 0) { + waiting_priority_queue_.push(request); + } else { + // request from prefill instance in disagge pd mode. + running_requests_.emplace_back(request); + } + } + + // handle finished/cancelled requests + std::vector> finished_requests; + for (auto it = running_requests_.rbegin(); it != running_requests_.rend(); + ++it) { + if (*it == nullptr) { + continue; + } + std::shared_ptr request = *it; + request->update_connection_status(); + if (request->finished() || request->cancelled()) { + // kv_cache_manager_->deallocate(request.get()); + // release the ownership of the request + finished_requests.emplace_back(request); + // finished request is set to nullptr + *it = nullptr; + } + } + + // clear previous batch + running_requests_.clear(); + running_sequences_.clear(); + running_sequences_budgets_.clear(); + + // remaining budget for the current batch + size_t remaining_token_budget = options_.max_tokens_per_batch(); + size_t remaining_seq_budget = std::max(options_.max_seqs_per_batch(), 1); + size_t num_preempted_requests = 0; + + handle_prefill_requests( + remaining_token_budget, remaining_seq_budget, finished_requests); + + // only forward once, no decode requests + // handle_decode_requests( + // remaining_token_budget, remaining_seq_budget, num_preempted_requests); + + if (!finished_requests.empty()) { + response_processor_->process_completed_requests(finished_requests); + } + + // update the batch + auto batches = BatchFactory::get_instance(options_.dp_size()) + ->create_rec_batches( + running_requests_, + running_sequences_, + running_sequences_budgets_, + kv_cache_manager_->get_copy_in_cache_block_infos(), + kv_cache_manager_->get_copy_out_cache_block_infos(), + kv_cache_manager_->get_swap_cache_block_infos()); + + // update metrics before returning + if (!batches[0].empty()) { + // only update the scheduling latency when there are requests to process + COUNTER_ADD(scheduling_latency_seconds, timer.elapsed_seconds()); + } + + GAUGE_SET(num_pending_requests, + pending_requests_.load(std::memory_order_relaxed)); + GAUGE_SET(num_running_requests, running_requests_.size()); + GAUGE_SET(num_waiting_requests, + waiting_priority_queue_.size() + running_queue_->size()); + + GAUGE_ADD(num_preempted_requests, num_preempted_requests); + + GAUGE_SET(num_running_sequences, running_sequences_.size()); + + GAUGE_SET(kv_cache_utilization_perc, + kv_cache_manager_->kv_cache_utilization()); + if (!FLAGS_enable_continuous_kvcache) { + GAUGE_SET(num_blocks_in_prefix_cache, + kv_cache_manager_->num_blocks_in_prefix_cache().size()); + GAUGE_SET(num_free_blocks, kv_cache_manager_->num_free_blocks().size()); + GAUGE_SET(num_used_blocks, kv_cache_manager_->num_used_blocks().size()); + } + return batches; +} + +std::vector FixStepsScheduler::schedule_request( + const absl::Duration& timeout) { + const auto deadline = absl::Now() + timeout; + std::vector batch; + while (true) { + batch = prepare_batch(); + bool all_empty = + std::all_of(batch.begin(), batch.end(), [](const Batch& one_batch) { + return one_batch.empty(); + }); + if (!all_empty) { + return batch; + } + const auto now = absl::Now(); + if (now > deadline) { + break; + } + // wait for new requests to arrive + constexpr uint64_t kStepSleepTimeMs = 1; + const auto time_to_sleep = + std::min(absl::Milliseconds(kStepSleepTimeMs), deadline - now); + 
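+    // Sleep at most 1 ms per iteration so newly arrived requests are picked
+    // up promptly without busy-waiting, while still honoring the deadline.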
absl::SleepFor(time_to_sleep); + } + // return an empty batch + return batch; +} + +// step the scheduler forward by one step +// may get blocked if there are no requests to process +void FixStepsScheduler::step(const absl::Duration& timeout) { + if (!options_.enable_schedule_overlap()) { + // get a new batch of requests + std::vector batch = schedule_request(timeout); + bool all_empty = + std::all_of(batch.begin(), batch.end(), [](const Batch& one_batch) { + return one_batch.empty(); + }); + if (all_empty) { + return; + } + engine_->step(batch); + kv_cache_manager_->reset_copy_content(); + } else { + LOG(ERROR) << "FixStepsScheduler::step() not supported with " + "enable_schedule_overlap"; + } +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/scheduler/fixsteps_scheduler.h b/xllm/core/scheduler/fixsteps_scheduler.h new file mode 100644 index 00000000..1fcfc3b3 --- /dev/null +++ b/xllm/core/scheduler/fixsteps_scheduler.h @@ -0,0 +1,62 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "async_response_processor.h" +#include "common/macros.h" +#include "common/types.h" +#include "framework/batch/batch.h" +#include "framework/request/request.h" +#include "framework/request/sequence.h" +#include "runtime/xservice_client.h" +#include "scheduler.h" +#include "scheduler/continuous_scheduler.h" + +namespace xllm { +class Engine; + +class FixStepsScheduler final : public ContinuousScheduler { + public: + FixStepsScheduler(Engine* engine, const Options& options); + virtual ~FixStepsScheduler() = default; + + bool add_request(std::shared_ptr& request) override; + + // step the scheduler forward by one step + // may get blocked if there are no requests to process + void step(const absl::Duration& timeout) override; + + private: + std::vector schedule_request(const absl::Duration& timeout); + + // build a batch of requests from the priority queue + virtual std::vector prepare_batch(); + + void handle_prefill_requests( + size_t& remaining_token_budget, + size_t& remaining_seq_budget, + std::vector>& finished_requests); +}; + +} // namespace xllm diff --git a/xllm/core/scheduler/scheduler_factory.cpp b/xllm/core/scheduler/scheduler_factory.cpp index 8be5a8b8..de85bd13 100644 --- a/xllm/core/scheduler/scheduler_factory.cpp +++ b/xllm/core/scheduler/scheduler_factory.cpp @@ -19,6 +19,7 @@ limitations under the License. 
#include "scheduler/continuous_scheduler.h" #include "scheduler/disagg_pd_scheduler.h" #include "scheduler/dit_scheduler.h" +#include "scheduler/fixsteps_scheduler.h" #include "scheduler/pd_ooc_scheduler.h" #include "scheduler/prefill_only_scheduler.h" #include "scheduler/zero_eviction_scheduler.h" @@ -51,6 +52,12 @@ std::unique_ptr create_continuous_scheduler( return std::make_unique(engine, options); } +std::unique_ptr create_fixsteps_scheduler( + Engine* engine, + ContinuousScheduler::Options options) { + return std::make_unique(engine, options); +} + std::unique_ptr create_dit_scheduler( DiTEngine* engine, DiTScheduler::Options options) { diff --git a/xllm/core/scheduler/scheduler_factory.h b/xllm/core/scheduler/scheduler_factory.h index daf153ba..0fa452dd 100644 --- a/xllm/core/scheduler/scheduler_factory.h +++ b/xllm/core/scheduler/scheduler_factory.h @@ -18,6 +18,7 @@ limitations under the License. #include "runtime/xservice_client.h" #include "scheduler/continuous_scheduler.h" #include "scheduler/dit_scheduler.h" +#include "scheduler/fixsteps_scheduler.h" namespace xllm { @@ -25,6 +26,10 @@ std::unique_ptr create_continuous_scheduler( Engine* engine, ContinuousScheduler::Options options); +std::unique_ptr create_fixsteps_scheduler( + Engine* engine, + ContinuousScheduler::Options options); + std::unique_ptr create_dit_scheduler( DiTEngine* engine, DiTScheduler::Options options); diff --git a/xllm/core/util/CMakeLists.txt b/xllm/core/util/CMakeLists.txt index 3318822e..cd2fbd8d 100644 --- a/xllm/core/util/CMakeLists.txt +++ b/xllm/core/util/CMakeLists.txt @@ -31,6 +31,7 @@ cc_library( SRCS device_name_utils.cpp env_var.cpp + hash_util.cpp json_reader.cpp net.cpp pretty_print.cpp @@ -50,7 +51,9 @@ cc_library( Boost::serialization absl::synchronization ${Python_LIBRARIES} + proto::xllm_proto :platform + SMHasherSupport ) target_link_libraries(util PRIVATE OpenSSL::SSL OpenSSL::Crypto) add_dependencies(util brpc-static) @@ -70,8 +73,3 @@ cc_test( ) target_link_libraries(util_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto) add_dependencies(util_test brpc-static) - - - - - diff --git a/xllm/core/util/env_var.cpp b/xllm/core/util/env_var.cpp index 928b792d..c68766cd 100644 --- a/xllm/core/util/env_var.cpp +++ b/xllm/core/util/env_var.cpp @@ -22,6 +22,9 @@ limitations under the License. namespace xllm { namespace util { +// Environment variable keys +const std::string EXTRA_THREAD_NUM = "EXTRA_THREAD_NUM"; + bool get_bool_env(const std::string& key, bool defaultValue) { const char* val = std::getenv(key.c_str()); if (val == nullptr) { diff --git a/xllm/core/util/env_var.h b/xllm/core/util/env_var.h index c10a61fe..2ce843c0 100644 --- a/xllm/core/util/env_var.h +++ b/xllm/core/util/env_var.h @@ -20,6 +20,9 @@ limitations under the License. namespace xllm { namespace util { +// Environment variable keys +extern const std::string EXTRA_THREAD_NUM; + bool get_bool_env(const std::string& key, bool defaultValue); // Get an integer value from an environment variable. diff --git a/xllm/core/util/hash_util.cpp b/xllm/core/util/hash_util.cpp new file mode 100644 index 00000000..f335f545 --- /dev/null +++ b/xllm/core/util/hash_util.cpp @@ -0,0 +1,55 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "hash_util.h"
+
+#include
+#include
+
+#include "third_party/smhasher/src/MurmurHash3.h"
+
+// Use a constant seed instead of FLAGS to avoid circular dependency
+constexpr uint32_t MURMUR_HASH3_SEED = 0;
+
+namespace xllm {
+
+void murmur_hash3(const uint8_t* pre_hash_value,
+                  const Slice<int32_t>& token_ids,
+                  uint8_t* hash_value) {
+  if (pre_hash_value == nullptr) {
+    MurmurHash3_x64_128(reinterpret_cast<const void*>(token_ids.data()),
+                        sizeof(int32_t) * token_ids.size(),
+                        MURMUR_HASH3_SEED,
+                        hash_value);
+  } else {
+    uint8_t key[1024];
+
+    int32_t data_len =
+        sizeof(int32_t) * token_ids.size() + MURMUR_HASH3_VALUE_LEN;
+    CHECK_GT(sizeof(key), data_len) << "key size is too small";
+
+    memcpy(key, pre_hash_value, MURMUR_HASH3_VALUE_LEN);
+    memcpy(key + MURMUR_HASH3_VALUE_LEN,
+           reinterpret_cast<const void*>(token_ids.data()),
+           sizeof(int32_t) * token_ids.size());
+
+    MurmurHash3_x64_128(reinterpret_cast<const void*>(key),
+                        data_len,
+                        MURMUR_HASH3_SEED,
+                        hash_value);
+  }
+}
+
+}  // namespace xllm
diff --git a/xllm/core/util/hash_util.h b/xllm/core/util/hash_util.h
index 31393d5b..e4886ee4 100644
--- a/xllm/core/util/hash_util.h
+++ b/xllm/core/util/hash_util.h
@@ -21,6 +21,8 @@ limitations under the License.
 #include
 #include
+#include "slice.h"
+
 namespace xllm {
 
 constexpr uint32_t MURMUR_HASH3_VALUE_LEN = 16;
@@ -62,4 +64,8 @@ struct FixedStringKeyEqual {
   }
 };
 
+void murmur_hash3(const uint8_t* pre_hash_value,
+                  const Slice<int32_t>& token_ids,
+                  uint8_t* hash_value);
+
 }  // namespace xllm
diff --git a/xllm/core/util/tensor_helper.h b/xllm/core/util/tensor_helper.h
index 714810c9..418d1cef 100644
--- a/xllm/core/util/tensor_helper.h
+++ b/xllm/core/util/tensor_helper.h
@@ -50,6 +50,58 @@ inline torch::Tensor create_2d_tensor(const std::vector<std::vector<T> >& vec,
                                       torch::ScalarType dtype) {
   return tensor;
 };
 
+// Optimized variant of create_2d_tensor with a fast path for empty / all-zero 2D vectors
+template <typename T>
+inline torch::Tensor create_2d_tensor_optimized(
+    const std::vector<std::vector<T> >& vec,
+    torch::ScalarType dtype) {
+  if (vec.empty()) {
+    return {};
+  }
+
+  const size_t n_rows = vec.size();
+  const size_t n_cols = vec.empty() ? 0 : vec[0].size();
+
+  // Special-case optimization for all-zero matrices
+  bool all_zero = true;
+  for (const auto& row : vec) {
+    for (const auto& val : row) {
+      if (val != T(0)) {
+        all_zero = false;
+        break;
+      }
+    }
+    if (!all_zero) break;
+  }
+
+  if (all_zero) {
+    // Creating a zero tensor directly is more efficient
+    return torch::zeros(
+        {static_cast<int64_t>(n_rows), static_cast<int64_t>(n_cols)},
+        torch::TensorOptions()
+            .dtype(dtype)
+            .device(torch::kCPU)
+            .pinned_memory(true));
+  }
+
+  // Otherwise use an optimized memory-copy path
+  auto tensor =
+      torch::empty({static_cast<int64_t>(n_rows), static_cast<int64_t>(n_cols)},
+                   torch::TensorOptions()
+                       .dtype(dtype)
+                       .device(torch::kCPU)
+                       .pinned_memory(true));
+
+  // Optimization: bulk memcpy instead of creating a torch::tensor per row
+  T* tensor_data = tensor.data_ptr<T>();
+  for (int64_t i = 0; i < n_rows; ++i) {
+    CHECK_EQ(vec[i].size(), n_cols);
+    // Copy memory directly to avoid creating temporary tensors
+    std::memcpy(tensor_data + i * n_cols, vec[i].data(), n_cols * sizeof(T));
+  }
+  return tensor;
+};
+
 inline torch::Tensor safe_to(const torch::Tensor& t,
                              const torch::TensorOptions& options,
                              bool non_blocking = false) {
diff --git a/xllm/core/util/utils.cpp b/xllm/core/util/utils.cpp
index 0182b687..d3c25164 100644
--- a/xllm/core/util/utils.cpp
+++ b/xllm/core/util/utils.cpp
@@ -148,5 +148,103 @@ std::vector cal_vec_split_index(uint32_t vec_size,
   return split_index;
 }
 
+torch::Dtype convert_rec_type_to_torch(proto::DataType data_type) {
+  // Future extensions go here.
+  switch (data_type) {
+    case proto::DataType::FLOAT:
+      return torch::kFloat32;
+
+    case proto::DataType::BFLOAT16:
+      return torch::kBFloat16;
+
+    case proto::DataType::BOOL:
+      return torch::kBool;
+
+    case proto::DataType::UINT8:
+      return torch::kUInt8;
+
+    // case proto::DataType::UINT32:
+    //   return torch::kUInt32;
+
+    case proto::DataType::INT8:
+      return torch::kInt8;
+
+    case proto::DataType::INT16:
+      return torch::kInt16;
+
+    default:
+      throw std::runtime_error("Unsupported data type: " +
+                               std::to_string(static_cast<int>(data_type)));
+  }
+}
+
+torch::Tensor convert_rec_tensor_to_torch(
+    const proto::InferInputTensor& input_tensor) {
+  std::vector<int64_t> shape;
+  shape.reserve(input_tensor.shape_size());
+  for (int i = 0; i < input_tensor.shape_size(); ++i) {
+    shape.push_back(input_tensor.shape(i));
+  }
+
+  if (!input_tensor.has_contents()) {
+    throw std::runtime_error("Input tensor '" + input_tensor.name() +
+                             "' has no contents");
+  }
+
+  const auto& contents = input_tensor.contents();
+  torch::Dtype dtype = convert_rec_type_to_torch(input_tensor.data_type());
+
+  switch (dtype) {
+    case torch::kFloat32: {
+      // Directly use protobuf's float array
+      const auto& data = contents.fp32_contents();
+      return torch::from_blob(
+                 const_cast<float*>(data.data()),
+                 shape,
+                 torch::dtype(torch::kFloat32).requires_grad(false))
+          .clone();  // Clone to ensure independent memory
+    }
+    // float16 is not supported yet.
+ // case torch::kFloat16: { + // // Need type conversion (protobuf usually stores float16 as uint16) + // const auto& data = contents.bytes_contents(); + // std::vector half_data; + // half_data.reserve(data.size()); + // for (auto val : data) { + // half_data.push_back(static_cast(val)); + // } + // return torch::tensor(half_data, torch::dtype(torch::kFloat16)) + // .view(shape); + // } + + case torch::kInt32: { + const auto& data = contents.int_contents(); + return torch::from_blob(const_cast(data.data()), + shape, + torch::dtype(torch::kInt32)) + .clone(); + } + + case torch::kInt64: { + const auto& data = contents.int64_contents(); + return torch::from_blob(const_cast(data.data()), + shape, + torch::dtype(torch::kInt64)) + .clone(); + } + + case torch::kBool: { + const auto& data = contents.bool_contents(); + return torch::tensor(std::vector(data.begin(), data.end()), + torch::dtype(torch::kBool)) + .view(shape); + } + + default: + throw std::runtime_error("Unhandled data type conversion for: " + + std::to_string(static_cast(dtype))); + } +} + } // namespace util } // namespace xllm diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h index 51491972..3c95ca84 100644 --- a/xllm/core/util/utils.h +++ b/xllm/core/util/utils.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "rec.pb.h" #include "slice.h" namespace xllm { @@ -71,5 +72,8 @@ bool match_suffix(const Slice& data, const Slice& suffix); std::vector cal_vec_split_index(uint32_t vec_size, uint32_t part_num); +torch::Tensor convert_rec_tensor_to_torch( + const proto::InferInputTensor& input_tensor); + } // namespace util } // namespace xllm diff --git a/xllm/models/model_registry.cpp b/xllm/models/model_registry.cpp index 1fab6325..8546e998 100644 --- a/xllm/models/model_registry.cpp +++ b/xllm/models/model_registry.cpp @@ -21,6 +21,7 @@ limitations under the License. #include #include "models.h" +#include "rec/onerec.h" namespace { diff --git a/xllm/models/rec/onerec.h b/xllm/models/rec/onerec.h new file mode 100644 index 00000000..d88c8463 --- /dev/null +++ b/xllm/models/rec/onerec.h @@ -0,0 +1,1054 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include + +#include + +#include "core/util/utils.h" + +#ifdef USE_NPU +#include +#include + +#include "atb_speed/log.h" +#include "core/layers/attention_mask.h" +#include "core/layers/lm_head.h" +#include "core/layers/onerec_block_layer.h" +#include "core/layers/rms_norm.h" +#include "core/layers/word_embedding.h" +#endif + +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_args.h" +#include "core/layers/linear.h" +#include "core/util/tensor_helper.h" +#include "models/model_registry.h" + +namespace xllm { + +// Helper function to pad encoder output from [ntokens, hidden_size] to [bs, +// max_seq_len, hidden_size] +inline torch::Tensor pad_encoder_output(const torch::Tensor& encoder_output, + const ModelInputParams& input_params) { + const int64_t bs = input_params.rec_params->bs; + const int64_t hidden_size = encoder_output.size(1); + + // Get actual sequence lengths and max sequence length from input_params + const auto& seq_lens = input_params.rec_params->encoder_seq_lens; + const int64_t max_seq_len = input_params.rec_params->encoder_max_seq_len; + + // Split encoder_output into individual sequences + std::vector seq_list; + seq_list.reserve(bs); + + int64_t token_offset = 0; + for (int64_t i = 0; i < bs; ++i) { + const int64_t seq_len = seq_lens[i]; + seq_list.emplace_back(encoder_output.narrow(0, token_offset, seq_len)); + token_offset += seq_len; + } + + // Use PyTorch's built-in padding function for better performance + auto padded_output = torch::nn::utils::rnn::pad_sequence( + seq_list, /*batch_first=*/true, /*padding_value=*/0.0); + + // Ensure the output has the correct max_seq_len dimension + if (padded_output.size(1) < max_seq_len) { + auto extra_padding = + torch::zeros({bs, max_seq_len - padded_output.size(1), hidden_size}, + encoder_output.options()); + padded_output = torch::cat({padded_output, extra_padding}, /*dim=*/1); + } + + return padded_output; +} + +#ifdef USE_NPU +class OneRecBlockImpl : public torch::nn::Module { + public: + OneRecBlockImpl(const ModelContext& context, + int layer_idx = 0, + bool is_decoder = true) { + // register submodules + block_layer_ = register_module( + "block_layer", layer::OneRecBlockLayer(context, is_decoder, layer_idx)); + } + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + atb::Context* context, + AtbWorkspace& work_space, + std::vector event, + std::vector*> event_flag, + int layer_id, + const torch::Tensor& encoder_output = torch::Tensor(), + const torch::Tensor& expert_array = torch::Tensor()) { + // ONEREC now passes position_bias through attn_mask with ALIBI mask type + // Pass encoder_output to the underlying block_layer_ + return block_layer_->forward( + x, + attn_mask, + kv_cache, + input_params, + context, + work_space, + event, + event_flag, + encoder_output.defined() ? 
const_cast(&encoder_output) + : nullptr, + layer_id, + expert_array); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + // call each submodule's load_state_dict function + block_layer_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const { + block_layer_->verify_loaded_weights(prefix); + } + + void merge_loaded_weights() { block_layer_->merge_loaded_weights(); } + + private: + layer::OneRecBlockLayer block_layer_{nullptr}; +}; +TORCH_MODULE(OneRecBlock); + +// ONEREC position bias computation similar to ONERECAttention.compute_bias +// Optimized for different stages: +// - Encoder: only prefill stage, bidirectional attention, can handle long +// sequences (~2000 tokens) +// - Decoder prefill: single token, no causal mask needed +// - Decoder decode: incremental generation, only compute bias for last query +// position +inline torch::Tensor compute_onerec_position_bias( + int64_t query_length, + int64_t key_length, + int64_t num_heads, + bool is_decoder, + layer::WordEmbedding& position_bias_embedding, + atb::Context* context, + AtbWorkspace& workspace, + int64_t num_buckets = 32, + int64_t max_distance = 128, + const torch::TensorOptions& options = torch::kFloat32, + bool is_decode_stage = false, + const ModelInputParams* input_params = nullptr) { + auto device = options.device(); + auto dtype = options.dtype(); + + // For decoder decode stage, we need full key_length but only last query + int64_t actual_query_length = is_decode_stage ? key_length : query_length; + + // Ensure minimum valid dimensions to avoid empty tensors + if (actual_query_length <= 0) { + LOG(WARNING) << "[ONEREC DEBUG] actual_query_length <= 0 (" + << actual_query_length << "), using 1"; + actual_query_length = 1; + } + if (key_length <= 0) { + LOG(WARNING) << "[ONEREC DEBUG] key_length <= 0 (" << key_length + << "), using 1"; + key_length = 1; + } + + // LOG(INFO) << "[ONEREC DEBUG] compute_onerec_position_bias - query_length: " + // << query_length << ", key_length: " << key_length + // << ", actual_query_length: " << actual_query_length + // << ", is_decode_stage: " << is_decode_stage; + + // Create position indices + auto context_position = + torch::arange(actual_query_length, + torch::dtype(torch::kLong).device(device)) + .unsqueeze(1); + auto memory_position = + torch::arange(key_length, torch::dtype(torch::kLong).device(device)) + .unsqueeze(0); + + // Calculate relative position: memory_position - context_position + auto relative_position = memory_position - context_position; + + // Convert to relative position buckets (similar to ONEREC's + // _relative_position_bucket) + auto relative_buckets = torch::zeros_like(relative_position); + + if (!is_decoder) { + // Bidirectional for encoder + num_buckets = num_buckets / 2; + relative_buckets += (relative_position > 0).to(torch::kLong) * num_buckets; + relative_position = torch::abs(relative_position); + } else { + // Unidirectional for decoder + relative_position = + -torch::min(relative_position, torch::zeros_like(relative_position)); + } + + // Half buckets for exact increments + auto max_exact = num_buckets / 2; + auto is_small = relative_position < max_exact; + + // Logarithmic buckets for larger distances + auto relative_position_if_large = + max_exact + (torch::log(relative_position.to(torch::kFloat) / max_exact) / + std::log(static_cast(max_distance) / max_exact) * + (num_buckets - max_exact)) + .to(torch::kLong); + + relative_position_if_large = + 
torch::min(relative_position_if_large, + torch::full_like(relative_position_if_large, num_buckets - 1)); + + relative_buckets += + torch::where(is_small, relative_position, relative_position_if_large); + + // Use the learned position bias embedding table + // AtbWordEmbedding expects 1D input tensor, so we need to flatten + // relative_buckets + auto original_shape = relative_buckets.sizes(); + auto flattened_buckets = relative_buckets.flatten(); + + auto values = position_bias_embedding(flattened_buckets, 0); + + // Handle AtbWordEmbedding output: since unpadInputs=true, it returns 2D + // [num_tokens, hidden_size] We need to reshape it to [query_length, + // key_length, num_heads] + if (values.dim() == 2) { + if (values.size(0) == flattened_buckets.size(0)) { + // values is [flattened_size, num_heads], reshape to [query_length, + // key_length, num_heads] + values = + values.view({original_shape[0], original_shape[1], values.size(1)}); + // LOG(INFO) << "[ONEREC DEBUG] Reshaped 2D values from embedding: " + // << values.sizes(); + } else { + LOG(FATAL) << "[ONEREC DEBUG] Unexpected 2D values size: " + << values.sizes() + << ", expected first dim: " << flattened_buckets.size(0); + } + } else if (values.dim() == 1) { + // values is [flattened_size], add num_heads dimension and reshape + values = + values.unsqueeze(-1).expand({flattened_buckets.size(0), num_heads}); + values = values.view({original_shape[0], original_shape[1], num_heads}); + // LOG(INFO) << "[ONEREC DEBUG] Expanded and reshaped 1D values: " + // << values.sizes(); + } else { + LOG(FATAL) << "[ONEREC DEBUG] Unexpected values tensor dimension: " + << values.dim() << ", sizes: " << values.sizes(); + } + + // Debug: Log tensor dimensions before permute + // LOG(INFO) << "[ONEREC DEBUG] Before permute - values.sizes(): " << + // values.sizes() + // << ", relative_buckets.sizes(): " << relative_buckets.sizes() + // << ", query_length: " << query_length + // << ", key_length: " << key_length << ", num_heads: " << num_heads + // << ", is_decoder: " << is_decoder; + + // Now values should be [query_length, key_length, num_heads] after reshaping + // LOG(INFO) << "[ONEREC DEBUG] After embedding reshape - values.sizes(): " + // << values.sizes() << ", expected: [" << actual_query_length << + // "," + // << key_length << ", " << num_heads << "]"; + + if (values.dim() == 3) { + // values is [query_length, key_length, num_heads], permute to [num_heads, + // query_length, key_length] ATB ALIBI mask type requires 3D tensor, not + // 4D, so we don't add batch dimension + values = values.permute({2, 0, 1}); + // LOG(INFO) << "[ONEREC DEBUG] 3D values after permute - values.sizes(): " + // << values.sizes(); + // LOG(INFO) << "position bias after permute " << values + // << ", value device: " << values.device(); + } else if (values.dim() == 2) { + // Fallback: if still 2D, assume it's [query_length, key_length] and add + // num_heads dimension + values = values.unsqueeze(-1).expand( + {values.size(0), values.size(1), num_heads}); + values = values.permute({2, 0, 1}); + // LOG(INFO) << "[ONEREC DEBUG] Fallback 2D handling - values.sizes(): " + // << values.sizes(); + } else { + LOG(FATAL) << "[ONEREC DEBUG] Unexpected values tensor dimension: " + << values.dim() << ", sizes: " << values.sizes(); + } + + // For decoder decode stage, handle batch with different sequence progress + if (is_decode_stage && input_params != nullptr && + !input_params->kv_seq_lens_vec.empty()) { + // In decode stage with batch processing, each sequence may have 
different + // progress Use max(kv_cu_seq_lens_vec) for query_length and key_length, + // then slice for each sequence + /* + int batch_size = input_params->kv_cu_seq_lens_vec.size(); + std::vector req_bias_vec; + req_bias_vec.reserve(batch_size); + for (int i = 0; i < batch_size; i++) { + // Each sequence takes its corresponding column from the position bias + // matrix + int seq_kv_len = input_params->kv_cu_seq_lens_vec[i]; + // Take the last query row and slice to the sequence's kv length + // values is now 3D [num_heads, query_length, key_length] + auto req_bias_slice = + values.slice(1, -1, values.size(1)).slice(2, 0, seq_kv_len); + req_bias_vec.emplace_back(req_bias_slice); + } + values = torch::cat(req_bias_vec, 2); // Concatenate along key dimension + */ + int seq_kv_len = input_params->kv_seq_lens_vec[0]; + // Take the last query row and slice to the sequence's kv length + // values is now 3D [num_heads, query_length, key_length] + values = values.slice(1, -1, values.size(1)).slice(2, 0, seq_kv_len); + } else if (is_decode_stage) { + // Original logic for single sequence or when input_params is not available + // values is now 3D [num_heads, query_length, key_length] + values = values.slice(1, -1, values.size(1)); // Take last query row + } + + return values; +} + +#endif + +class OneRecStackImpl : public torch::nn::Module { + public: + OneRecStackImpl(const ModelContext& context, + bool is_decode, + layer::WordEmbedding& embed_tokens) { +#ifdef USE_NPU + auto args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + + hidden_size_ = args.hidden_size(); + // register submodules + blocks_ = register_module("block", torch::nn::ModuleList()); + uint32_t num_layers = is_decode ? args.n_layers() : args.n_encoder_layers(); + layers_.reserve(num_layers); + + is_decoder_ = is_decode; + use_absolute_position_embedding_ = args.use_absolute_position_embedding(); + use_moe_ = args.use_moe(); + num_experts_per_tok_ = args.num_experts_per_tok(); + relative_attention_num_buckets_ = args.relative_attention_num_buckets(); + relative_attention_max_distance_ = args.relative_attention_max_distance(); + work_space_ = AtbWorkspace(options.device()); + + // share the word embedding + embed_tokens_ = embed_tokens; + num_heads_ = is_decode ? 
args.decoder_n_heads() : args.n_heads(); + + // Initialize position bias embedding table for relative attention only when + // not using absolute position embedding This replaces the random embedding + // table in compute_onerec_position_bias + if (!use_absolute_position_embedding_) { + position_bias_embedding_ = register_module("position_bias_embedding", + layer::WordEmbedding(context)); + } + + norm_ = register_module("final_layer_norm", layer::RmsNorm(context)); + + // Initialize rotary position embeddings (for compatibility) + cos_pos_ = torch::Tensor(); + sin_pos_ = torch::Tensor(); + + // Initialize attention mask + int32_t mask_value = -9984; + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); + max_seq_len_ = args.max_position_embeddings(); + atb::Status st = atb::CreateContext(&context_); + LOG_IF(ERROR, st != 0) << "ContextFactory create atb::Context fail"; + device_id = options.device().index(); + void* stream = c10_npu::getCurrentNPUStream(device_id).stream(); + LOG_IF(ERROR, stream == nullptr) << "get current stream fail"; + // context_->SetExecuteStream(atb_speed::Utils::GetCurrentStream()); + context_->SetExecuteStream(stream); + context_->SetAsyncTilingCopyStatus(true); + for (int32_t i = 0; i < num_layers; i++) { + auto block = OneRecBlock(context, i, is_decode); + layers_.push_back(block); + blocks_->push_back(block); + } + +#endif + } + + ~OneRecStackImpl() { + atb::Status st = atb::DestroyContext(context_); + LOG_IF(ERROR, st != 0) << "DestroyContext atb::Context fail"; + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params, + const torch::Tensor& encoder_output = torch::Tensor()) { +#ifdef USE_NPU + // Get embeddings + torch::Tensor h; + if (input_params.rec_params->is_hybrid_mode && !is_decoder_) { + h = tokens; + } else { + if (input_params.rec_params->decoder_context_embedding.defined()) { + // use context embedding replacing bos + prompt tokens + if (tokens.sizes() == 0) { + h = input_params.rec_params->decoder_context_embedding; + } else { + h = embed_tokens_(tokens, 0); + + // Reshape tensors for interleaving + // decoder_context_embedding: [bs * group_width * seq_len2, + // hidden_size] h: [bs * group_width * seq_len1, hidden_size] + auto& context_emb = + input_params.rec_params->decoder_context_embedding; + auto& token_emb = h; + const int64_t hidden_size = context_emb.size(3); + const int64_t bs = input_params.rec_params->bs; + const int64_t group_width = input_params.rec_params->group_width; + + const int64_t context_total_tokens = context_emb.size(2); + const int64_t token_total_tokens = token_emb.size(0); + + // Assume bs * group_width is the same for both tensors + // We need to determine seq_len1 and seq_len2 from the tensor shapes + // For now, assume seq_len2 is provided via input_params.seq_len or + // can be inferred + const int64_t bs_group = bs * group_width; + const int64_t seq_len1 = token_total_tokens / bs_group; + + // Use pre-allocated combined space from batch.cpp + // context_emb is already in shape [bs, group_width, total_len, + // hidden_size] The first seq_len2 part has been filled with + // context_embedding + const int64_t total_len = context_total_tokens; + const int64_t seq_len2 = total_len - seq_len1; + + // token_emb shape is [bs * group_width * seq_len1, hidden_size] + // Need to reshape 
to [bs, group_width, seq_len1, + // hidden_size] for corresponding copying + auto token_embedding_reshaped = + token_emb.view({bs, group_width, seq_len1, hidden_size}); + + // Copy token_embedding to the last seq_len1 part of context_emb + // Use narrow to slice from seq_len2 position in dimension 2, taking + // seq_len1 length + context_emb.narrow(2, seq_len2, seq_len1) + .copy_(token_embedding_reshaped); + + // Reshape to final shape + h = context_emb.view({-1, hidden_size}).clone(); + } + if (!h.is_contiguous()) { + h = h.contiguous(); + } + + } else { + h = embed_tokens_(tokens, 0); + } + } + + // Ensure encoder_output is on NPU device if provided + torch::Tensor npu_encoder_output = encoder_output; + if (encoder_output.defined() && + encoder_output.device().type() != h.device().type()) { + npu_encoder_output = encoder_output.to(h.device()); + } + + // Since unpadInputs=true in AtbWordEmbeddingImpl, h is 2D: [total_tokens, + // hidden_size] We need to get sequence info from input_params instead auto + // total_tokens = h.size(0); auto hidden_size = h.size(1); Get batch_size + // and seq_length from input_params + auto batch_size = input_params.num_sequences; + auto seq_length = input_params.q_max_seq_len; + + // Determine stage based on input_params + bool is_prefill = input_params.rec_params->rec_stage == + RecModelInputParams::RecStage::PREFILL; + + // Compute sequence lengths for position bias calculation + auto [query_length, key_length] = + compute_sequence_lengths(seq_length, is_prefill, input_params); + + ModelInputParams& input_params_new = + const_cast(input_params); + bool is_decode_stage = is_decoder_ && !is_prefill; + + // Compute attention mask based on MoE usage + torch::Tensor effective_attn_mask; + if (use_absolute_position_embedding_) { + effective_attn_mask = + create_moe_attention_mask(query_length, h, is_decoder_); + } else { + effective_attn_mask = compute_position_bias_mask( + query_length, key_length, h, is_decode_stage, input_params); + } + + // Pre-process attention mask for better performance + torch::Tensor preprocessed_attn_mask = + preprocess_attention_mask(effective_attn_mask, h); + torch::Tensor preprocessed_encoder_seq_lens_tensor; + + // Pre-process encoder_seq_lens_tensor if defined + if (input_params.rec_params->encoder_seq_lens_tensor.defined()) { + auto target_device = h.device(); + if (input_params.rec_params->encoder_seq_lens_tensor.device() != + target_device) { + auto flattened_tensor = + input_params.rec_params->encoder_seq_lens_tensor.flatten(); + preprocessed_encoder_seq_lens_tensor = + flattened_tensor.to(target_device, torch::kInt).contiguous(); + } else { + auto flattened_tensor = + input_params.rec_params->encoder_seq_lens_tensor.flatten().to( + torch::kInt); + preprocessed_encoder_seq_lens_tensor = + flattened_tensor.is_contiguous() ? 
flattened_tensor + : flattened_tensor.contiguous(); + } + // Update input_params to use preprocessed tensor + input_params_new.rec_params->encoder_seq_lens_tensor = + preprocessed_encoder_seq_lens_tensor; + } else { + // Even if not defined, copy the original tensor to input_params_new + input_params_new.rec_params->encoder_seq_lens_tensor = + input_params.rec_params->encoder_seq_lens_tensor; + } + + // Create expert_array tensor for MoE support + torch::Tensor expert_array; + if (use_moe_) { + int64_t input_length = h.size(0); + expert_array = torch::arange( + 0, + input_length * num_experts_per_tok_, + torch::TensorOptions().dtype(torch::kInt32).device(h.device())); + } + + for (size_t i = 0; i < layers_.size(); i++) { + if (input_params.layer_synchronizer) { + input_params.layer_synchronizer->synchronize_layer(i); + } + + //@TODO: init + std::vector events; + std::vector*> event_flags; + + auto& layer = layers_[i]; + // Use reference to kv_caches[i] to match KVCache& parameter type + KVCache& kv_cache_ref = kv_caches[i]; + layers_[i]->forward( + h, + cos_pos_, + sin_pos_, + effective_attn_mask, // Pass position_bias as attn_mask + kv_cache_ref, + input_params_new, + context_, + work_space_, + events, + event_flags, + i, + npu_encoder_output, + expert_array); + } + h = norm_(h, 0); + return h; + +#endif + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); + // Load position bias embedding weights (from first layer's relative + // attention bias) + if (!use_absolute_position_embedding_) { + position_bias_embedding_->load_state_dict(state_dict.get_dict_with_prefix( + "block.0.layer.0.SelfAttention.relative_attention_bias.")); + } + // call each layer's load_state_dict function + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->load_state_dict( + state_dict.get_dict_with_prefix("block." + std::to_string(i) + ".")); + } + norm_->load_state_dict( + state_dict.get_dict_with_prefix("final_layer_norm.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); + if (!use_absolute_position_embedding_) { + position_bias_embedding_->verify_loaded_weights( + prefix + "block.0.layer.0.SelfAttention.relative_attention_bias."); + } + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->verify_loaded_weights(prefix + "block." 
+ std::to_string(i) + + "."); + } + norm_->verify_loaded_weights(prefix + "final_layer_norm."); + } + +#ifdef USE_NPU + void merge_loaded_weights() { + // test + embed_tokens_->merge_loaded_weights(); + if (!use_absolute_position_embedding_) { + position_bias_embedding_->merge_loaded_weights(); + } + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->merge_loaded_weights(); + } + norm_->merge_loaded_weights(); + } + + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; + } +#endif + + private: + int64_t hidden_size_; + +#ifdef USE_CUDA + // parameter members, must be registered + ParallelEmbedding embed_tokens_{nullptr}; + // attention handler + std::unique_ptr handler_{nullptr}; + layer::RMSNorm norm_{nullptr}; +#endif +#ifdef USE_NPU + torch::Tensor cos_pos_; + torch::Tensor sin_pos_; + torch::Tensor position_bias_; + atb::Context* context_; + int max_seq_len_ = 0; + int device_id = 0; + bool is_decoder_; + bool use_absolute_position_embedding_ = false; + bool use_moe_ = false; + int64_t relative_attention_num_buckets_ = 32; + int64_t relative_attention_max_distance_ = 128; + int64_t num_heads_ = 4; + int32_t num_experts_per_tok_ = 2; + AtbWorkspace work_space_; + layer::AttentionMask attn_mask_; + layer::WordEmbedding embed_tokens_{nullptr}; + layer::WordEmbedding position_bias_embedding_{nullptr}; + layer::RmsNorm norm_{nullptr}; +#endif + + torch::nn::ModuleList blocks_{nullptr}; + // hold same data but different type as blocks_ to avoid type cast + std::vector layers_; + + // Helper functions for position bias computation + std::pair compute_sequence_lengths( + int64_t seq_length, + bool is_prefill, + const ModelInputParams& input_params) const; + torch::Tensor create_moe_attention_mask(int64_t seq_length, + const torch::Tensor& h, + bool is_decoder) const; + torch::Tensor compute_position_bias_mask( + int64_t query_length, + int64_t key_length, + const torch::Tensor& h, + bool is_decode_stage, + const ModelInputParams& input_params); + torch::Tensor preprocess_attention_mask( + const torch::Tensor& effective_attn_mask, + const torch::Tensor& h) const; +}; +TORCH_MODULE(OneRecStack); + +class OneRecForConditionalGenerationImpl : public torch::nn::Module { + public: + OneRecForConditionalGenerationImpl(const ModelContext& context) { +#ifdef USE_NPU + auto args = context.get_model_args(); + auto options = context.get_tensor_options(); + + device_id = options.device().index(); + work_space_ = AtbWorkspace(options.device()); + use_moe_ = args.use_moe(); + + shared_ = register_module("shared", layer::WordEmbedding(context)); + + // Only initialize encoder when use_moe is false + bool is_decode = false; + encoder_ = + register_module("encoder", OneRecStack(context, is_decode, shared_)); + + is_decode = true; + decoder_ = + register_module("decoder", OneRecStack(context, is_decode, shared_)); + + lm_head_ = register_module("lm_head", layer::LmHead(context)); + + atb::Status st = atb::CreateContext(&context_); + LOG_IF(ERROR, st != 0) << "ContextFactory create atb::Context fail"; + + void* stream = c10_npu::getCurrentNPUStream(device_id).stream(); + LOG_IF(ERROR, stream == nullptr) << "get current stream fail"; + tie_word_embeddings_ = args.tie_word_embeddings(); + scale_factor_ = 1 / sqrt(args.hidden_size()); + context_->SetExecuteStream(stream); + context_->SetAsyncTilingCopyStatus(true); + +#endif + } + + ~OneRecForConditionalGenerationImpl() { + atb::Status st = 
atb::DestroyContext(context_); + LOG_IF(ERROR, st != 0) << "DestroyContext atb::Context fail"; + } + // Encoder forward pass - processes encoder input tokens + // encoder_tokens: [num_encoder_tokens] encoder input tokens + // encoder_positions: [num_encoder_tokens] encoder token positions + // returns: [num_encoder_tokens, hidden_size] encoder hidden states + torch::Tensor encode_forward(const torch::Tensor& encoder_tokens, + const torch::Tensor& encoder_positions, + const ModelInputParams& input_params) { + // Run encoder with encoder input tokens + std::vector encoder_kv_caches; // Encoder doesn't use KV cache + + auto encoder_output = encoder_( + encoder_tokens, encoder_positions, encoder_kv_caches, input_params); + encoder_output = pad_encoder_output(encoder_output, input_params); + encoder_output_ = encoder_output; + return encoder_output; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params, + const torch::Tensor& encoder_output = torch::Tensor()) { + // ONEREC decoder forward pass with cross-attention to encoder output + auto decoder_output = + decoder_(tokens, positions, kv_caches, input_params, encoder_output_); + return decoder_output; + } + + torch::Tensor forward(std::vector tokens, + std::vector positions, + std::vector& kv_caches, + const std::vector& input_params) { + if (input_params[0].rec_params->is_encoder_forward) { + return encode_forward(tokens[0], positions[0], input_params[0]); + } + return forward( + tokens[0], positions[0], kv_caches, input_params[0], encoder_output_); + } + + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // returns: [num_tokens, vocab_size] + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + // select tokens if provided + auto h = hidden_states; + if (tie_word_embeddings_) { + h = hidden_states * scale_factor_; + } +#ifdef USE_NPU + return lm_head_(h, seleted_idxes, 0); + +#endif + } + + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + + virtual void update_expert_weight(int32_t layer_id) { return; } + + // TODO load model + void load_model(std::unique_ptr loader) { +#ifdef USE_NPU + for (const auto& state_dict : loader->get_state_dicts()) { + shared_->load_state_dict(state_dict->get_dict_with_prefix("shared.")); + encoder_->load_state_dict(state_dict->get_dict_with_prefix("encoder.")); + decoder_->load_state_dict(state_dict->get_dict_with_prefix("decoder.")); + if (tie_word_embeddings_) { + lm_head_->load_state_dict(state_dict->get_dict_with_prefix("shared.")); + } else { + lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); + } + } + // verify + shared_->verify_loaded_weights("shared."); + encoder_->verify_loaded_weights("encoder."); + decoder_->verify_loaded_weights("decoder."); + lm_head_->verify_loaded_weights("lm_head."); + + shared_->merge_loaded_weights(); + encoder_->merge_loaded_weights(); + decoder_->merge_loaded_weights(); + lm_head_->merge_loaded_weights(); + LOG(INFO) << "load model done"; +#endif + } + +#ifdef USE_NPU + layer::LmHead get_lm_head() { return lm_head_; } + + void set_lm_head(layer::LmHead& head) { lm_head_ = head; } + + std::vector get_word_embedding() { return {shared_}; } + + void set_word_embedding(std::vector& embedding) { + shared_ = embedding[0]; + } +#endif + + private: +#ifdef USE_NPU + float scale_factor_; + bool tie_word_embeddings_{false}; + bool 
+#endif
+  int device_id = 0;
+  layer::WordEmbedding shared_{nullptr};
+  OneRecStack encoder_{nullptr};
+  OneRecStack decoder_{nullptr};
+  layer::LmHead lm_head_{nullptr};
+  AtbWorkspace work_space_;
+  atb::Context* context_;
+  torch::Tensor
+      encoder_output_;  // Store encoder output for decoder cross-attention
+};
+TORCH_MODULE(OneRecForConditionalGeneration);
+
+#ifdef USE_NPU
+// Implementation of OneRecStackImpl helper functions
+inline std::pair<int64_t, int64_t> OneRecStackImpl::compute_sequence_lengths(
+    int64_t seq_length,
+    bool is_prefill,
+    const ModelInputParams& input_params) const {
+  int64_t query_length = seq_length;
+  int64_t key_length = seq_length;
+
+  if (is_decoder_) {
+    // OneRec decoder logic
+    if (is_prefill) {
+      // Decoder prefill: query_length and key_length are both the decoder
+      // input length for self-attention
+      query_length = seq_length;
+      key_length = seq_length;
+    } else {
+      // Decoder decode: a single new token attends to the accumulated KV
+      // length
+      query_length = 1;
+      if (!input_params.kv_seq_lens_vec.empty()) {
+        auto max_kv_len =
+            *std::max_element(input_params.kv_seq_lens_vec.begin(),
+                              input_params.kv_seq_lens_vec.end());
+        key_length = max_kv_len;
+      } else {
+        key_length = seq_length;
+      }
+      // Expand query_length to the full key length so the position bias can
+      // be sliced per sequence downstream
+      query_length = key_length;
+    }
+  } else {
+    // OneRec encoder logic: always prefill stage with the full input sequence
+    // and bidirectional attention.
+    // Use encoder_max_seq_len instead of q_max_seq_len for correct position
+    // bias calculation.
+    auto encoder_seq_length = input_params.rec_params->encoder_max_seq_len;
+    query_length = encoder_seq_length;
+    key_length = encoder_seq_length;
+  }
+
+  return {query_length, key_length};
+}
+
+inline torch::Tensor OneRecStackImpl::create_moe_attention_mask(
+    int64_t seq_length,
+    const torch::Tensor& h,
+    bool is_decoder) const {
+  // When use_moe is true, skip position bias computation and use triangular
+  // mask directly
+  if (!is_decoder) {
+    auto effective_attn_mask =
+        torch::ones({num_heads_, seq_length, seq_length}, h.options());
+    return effective_attn_mask;
+  }
+  auto mask_value = -9984.0f;
+  // Create upper triangular mask (offset=1 to exclude diagonal)
+  auto upper_tri_mask =
+      torch::triu(torch::ones({seq_length, seq_length},
+                              torch::dtype(h.dtype()).device(h.device())),
+                  1);
+  // Expand mask to match dimensions [num_heads, seq_len, seq_len]
+  auto expanded_mask =
+      upper_tri_mask.unsqueeze(0).expand({num_heads_, seq_length, seq_length});
+
+  // Create base mask filled with zeros
+  auto effective_attn_mask =
+      torch::zeros({num_heads_, seq_length, seq_length},
+                   torch::dtype(h.dtype()).device(h.device()));
+  // Apply triangular mask
+  effective_attn_mask.masked_fill_(expanded_mask.to(torch::kBool), mask_value);
+  return effective_attn_mask;
+}
+
+inline torch::Tensor OneRecStackImpl::compute_position_bias_mask(
+    int64_t query_length,
+    int64_t key_length,
+    const torch::Tensor& h,
+    bool is_decode_stage,
+    const ModelInputParams& input_params) {
+  // Compute position bias for the first layer
+  auto layer_position_bias =
+      compute_onerec_position_bias(query_length,
+                                   key_length,
+                                   num_heads_,
+                                   is_decoder_,
+                                   position_bias_embedding_,
+                                   context_,
+                                   work_space_,
+                                   relative_attention_num_buckets_,
+                                   relative_attention_max_distance_,
+                                   torch::dtype(h.dtype()).device(h.device()),
+                                   is_decode_stage,
+                                   &input_params);
+
+  // Generate appropriate attention mask based on encoder/decoder type
+  auto effective_attn_mask = layer_position_bias.is_contiguous()
+                                 ? layer_position_bias
+                                 : layer_position_bias.contiguous();
+
+  if (is_decoder_ && FLAGS_enable_rec_prefill_only) {
+    // Use torch::triu to create upper triangular mask and apply it
+    auto mask_value = -9984.0f;
+    // Create upper triangular mask (offset=1 to exclude diagonal)
+    auto upper_tri_mask =
+        torch::triu(torch::ones({query_length, query_length},
+                                effective_attn_mask.options()),
+                    1);
+    // Expand mask to match effective_attn_mask dimensions
+    // [num_heads, seq_len, seq_len]
+    auto expanded_mask = upper_tri_mask.unsqueeze(0).expand(
+        {num_heads_, query_length, query_length});
+
+    // Apply mask to all heads using broadcasting (single operation)
+    effective_attn_mask.masked_fill_(expanded_mask.to(torch::kBool),
+                                     mask_value);
+  }
+
+  return effective_attn_mask;
+}
+
+inline torch::Tensor OneRecStackImpl::preprocess_attention_mask(
+    const torch::Tensor& effective_attn_mask,
+    const torch::Tensor& h) const {
+  if (!effective_attn_mask.defined()) {
+    return torch::Tensor();
+  }
+
+  // Check device compatibility
+  auto target_device = h.device();
+  if (effective_attn_mask.device() != target_device) {
+    LOG(WARNING) << "[ONEREC Optimization] Moving attn_mask from device "
+                 << effective_attn_mask.device() << " to " << target_device;
+    return effective_attn_mask.to(target_device).contiguous();
+  } else {
+    return effective_attn_mask.is_contiguous()
+               ? effective_attn_mask
+               : effective_attn_mask.contiguous();
+  }
+}
+#endif
+
+// register the causal model
+REGISTER_CAUSAL_MODEL(onerec, OneRecForConditionalGeneration);
+
+// register the model args
+REGISTER_MODEL_ARGS(onerec, [&] {
+  LOAD_ARG_OR(model_type, "model_type", "onerec");
+  LOAD_ARG_OR(dtype, "torch_dtype", "bfloat16");
+  LOAD_ARG(n_kv_heads, "num_key_value_heads");
+  LOAD_ARG(decoder_n_kv_heads, "decoder_num_key_value_heads");
+  LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
+  LOAD_ARG_OR(n_heads, "num_heads", 4);
+  LOAD_ARG_OR(head_dim, "d_kv", 4);
+  LOAD_ARG_OR_FUNC(
+      decoder_n_heads, "decoder_num_heads", [&] { return args->n_heads(); });
+  LOAD_ARG_OR_FUNC(
+      decoder_head_dim, "decoder_d_kv", [&] { return args->head_dim(); });
+  // vocabulary size
+  LOAD_ARG_OR(vocab_size, "vocab_size", 8200);
+  LOAD_ARG_OR(n_layers, "num_decoder_layers", 4);
+  LOAD_ARG_OR(n_encoder_layers, "num_layers", 12);
+  LOAD_ARG_OR(rms_norm_eps, "layer_norm_epsilon", 1e-6);
+  LOAD_ARG_OR(max_position_embeddings, "max_length", 500);
+  LOAD_ARG_OR(intermediate_size, "d_ff", 256);
+  LOAD_ARG_OR(hidden_size, "d_model", 128);
+  LOAD_ARG_OR(use_absolute_position_embedding,
+              "use_absolute_position_embedding",
+              false);
+  LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", true);
+  // MoE args; config keys reuse the DeepSeek-V2 naming.
+  LOAD_ARG_OR(use_moe, "use_moe", false);
+  LOAD_ARG_OR(moe_score_func, "moe_score_func", "softmax");
+  LOAD_ARG_OR(moe_route_scale, "moe_route_scale", 1.0);
+  LOAD_ARG_OR(n_routed_experts, "moe_num_experts", 8);
+  LOAD_ARG_OR(moe_use_shared_experts, "moe_use_shared_experts", false);
+  LOAD_ARG_OR(n_shared_experts, "moe_num_shared_experts", 0);
+  LOAD_ARG_OR(num_experts_per_tok, "moe_topk", 2);
+  LOAD_ARG_OR(moe_intermediate_size, "moe_inter_dim", 1024);
+
+  LOAD_ARG_OR(
+      relative_attention_num_buckets, "relative_attention_num_buckets", 32);
+  LOAD_ARG_OR(
+      relative_attention_max_distance, "relative_attention_max_distance", 128);
+  LOAD_ARG_OR(bos_token_id, "bos_token_id", 0);
+  LOAD_ARG_OR(eos_token_id, "eos_token_id", 128001);
+
+  // LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
+  //   return args->hidden_size() / args->n_heads();
+  // });
+});
+
+REGISTER_TOKENIZER_ARGS(onerec, [&] { SET_ARG(tokenizer_type, "rec"); });
+
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/proto/CMakeLists.txt b/xllm/proto/CMakeLists.txt
index 38be2b75..5014c409 100644
--- a/xllm/proto/CMakeLists.txt
+++ b/xllm/proto/CMakeLists.txt
@@ -6,6 +6,7 @@ proto_library(
   SRCS
     tensor.proto
     common.proto
+    rec.proto
     completion.proto
     chat.proto
     multimodal.proto
diff --git a/xllm/proto/completion.proto b/xllm/proto/completion.proto
index ccddd0e6..37b16738 100644
--- a/xllm/proto/completion.proto
+++ b/xllm/proto/completion.proto
@@ -4,6 +4,7 @@ option go_package = "jd.com/jd-infer/xllm;xllm";
 package xllm.proto;
 
 import "common.proto";
+import "rec.proto";
 
 // Next ID: 26
 message CompletionRequest {
@@ -95,6 +96,9 @@ message CompletionRequest {
   optional Priority priority = 28;
 
   optional int32 beam_width = 29;
+
+  // tensor for rec embedding.
+  repeated InferInputTensor input_tensors = 30;
 }
 
 message LogProbs {
@@ -142,5 +146,8 @@ message CompletionResponse {
 
   // usage statistics for the completion request.
   Usage usage = 6;
+
+  // for rec output
+  repeated InferOutputTensor output_tensors = 7;
 }
diff --git a/xllm/proto/rec.proto b/xllm/proto/rec.proto
new file mode 100644
index 00000000..5504b865
--- /dev/null
+++ b/xllm/proto/rec.proto
@@ -0,0 +1,119 @@
+syntax = "proto3";
+option go_package = "jd.com/jd-infer/xllm;xllm";
+package xllm.proto;
+import "common.proto";
+
+option cc_enable_arenas = true;
+option cc_generic_services = true;
+enum DataType {
+  UNDEFINED = 0;
+  // Basic types.
+  FLOAT = 1;   // float
+  UINT8 = 2;   // uint8_t
+  INT8 = 3;    // int8_t
+  UINT16 = 4;  // uint16_t
+  INT16 = 5;   // int16_t
+  INT32 = 6;   // int32_t
+  INT64 = 7;   // int64_t
+  STRING = 8;  // string
+  BOOL = 9;    // bool
+  // IEEE754 half-precision floating-point format (16 bits wide).
+  // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
+  FLOAT16 = 10;
+  DOUBLE = 11;
+  UINT32 = 12;
+  UINT64 = 13;
+  COMPLEX64 = 14;   // complex with float32 real and imaginary components
+  COMPLEX128 = 15;  // complex with float64 real and imaginary components
+  // Non-IEEE floating-point format based on IEEE754 single-precision
+  // floating-point number truncated to 16 bits.
+  // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+  BFLOAT16 = 16;
+  // Non-IEEE floating-point format based on papers
+  // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
+  // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
+  // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
+  // The computation usually happens inside a block quantize / dequantize
+  // fused by the runtime.
+  FLOAT8E4M3FN = 17;    // float 8, mostly used for coefficients, supports nan, not inf
+  FLOAT8E4M3FNUZ = 18;  // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+  FLOAT8E5M2 = 19;      // follows IEEE 754, supports nan, inf, mostly used for gradients
+  FLOAT8E5M2FNUZ = 20;  // follows IEEE 754, supports nan, not inf, mostly used for gradients, no negative zero
+  // 4-bit integer data types
+  UINT4 = 21;  // Unsigned integer in range [0, 15]
+  INT4 = 22;   // Signed integer in range [-8, 7], using two's-complement representation
+  // 4-bit floating point data types
+  FLOAT4E2M1 = 23;
+  // E8M0 type used as the scale for microscaling (MX) formats:
+  // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  FLOAT8E8M0 = 24;
+  // Future extensions go here.
+}
+// The data contained in a tensor represented by the repeated type
+// that matches the tensor's data type. Protobuf oneof is not used
+// because oneofs cannot contain repeated fields.
+message InferTensorContents
+{
+  // Representation for BOOL data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated bool bool_contents = 1;
+  // Representation for INT8, INT16, and INT32 data types. The size
+  // must match what is expected by the tensor's shape. The contents
+  // must be the flattened, one-dimensional, row-major order of the
+  // tensor elements.
+  repeated int32 int_contents = 2;
+  // Representation for INT64 data types. The size must match what
+  // is expected by the tensor's shape. The contents must be the
+  // flattened, one-dimensional, row-major order of the tensor elements.
+  repeated int64 int64_contents = 3;
+  // Representation for UINT8, UINT16, and UINT32 data types. The size
+  // must match what is expected by the tensor's shape. The contents
+  // must be the flattened, one-dimensional, row-major order of the
+  // tensor elements.
+  repeated uint32 uint_contents = 4;
+  // Representation for UINT64 data types. The size must match what
+  // is expected by the tensor's shape. The contents must be the
+  // flattened, one-dimensional, row-major order of the tensor elements.
+  repeated uint64 uint64_contents = 5;
+  // Representation for FP32 data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated float fp32_contents = 6;
+  // Representation for FP64 data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated double fp64_contents = 7;
+  // Representation for BYTES data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated bytes bytes_contents = 8;
+}
+// An input tensor for an inference request.
+message InferInputTensor
+{
+  // The tensor name.
+  string name = 1;
+  // The tensor data type.
+  DataType data_type = 2;
+  // The tensor shape.
+  repeated int64 shape = 3;
+  // The tensor contents using a data-type format. This field must
+  // not be specified if "raw" tensor contents are being used for
+  // the inference request.
+  InferTensorContents contents = 4;
+}
+// An output tensor returned for an inference request.
+message InferOutputTensor
+{
+  // The tensor name.
+  string name = 1;
+  // The tensor data type.
+  DataType datatype = 2;
+  // The tensor shape.
+  repeated int64 shape = 3;
+  // The tensor contents using a data-type format. This field must
+  // not be specified if "raw" tensor contents are being used for
+  // the inference response.
+  InferTensorContents contents = 4;
+}
\ No newline at end of file
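
Usage sketch (illustrative, not part of the patch): with the new fields in completion.proto, rec embeddings ride along on the existing CompletionRequest as InferInputTensor entries. Below is a minimal client-side sketch using the protoc-generated C++ API for the messages defined above; the tensor name "user_emb", the [1, 128] shape, and the zero-filled values are assumptions for illustration only, nothing in the proto mandates them.

  #include "completion.pb.h"
  #include "rec.pb.h"

  // Build a CompletionRequest carrying one FP32 embedding tensor.
  xllm::proto::CompletionRequest make_rec_request() {
    xllm::proto::CompletionRequest req;

    // Append one rec input tensor (field `input_tensors = 30`).
    auto* tensor = req.add_input_tensors();
    tensor->set_name("user_emb");  // illustrative name
    tensor->set_data_type(xllm::proto::FLOAT);
    tensor->add_shape(1);          // illustrative shape [1, 128]
    tensor->add_shape(128);

    // Flattened, row-major values go into the typed contents field.
    auto* contents = tensor->mutable_contents();
    for (int i = 0; i < 128; ++i) {
      contents->add_fp32_contents(0.0f);  // placeholder values
    }
    return req;
  }

The server side reads the same fields from the request and, per the response additions above, returns results in CompletionResponse.output_tensors as InferOutputTensor entries.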