From f56435a6389137c5f813d4eaa1ac9670948e1a6f Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 17 Sep 2025 23:24:14 +0800 Subject: [PATCH 01/43] eagle impl with top-1 proposal Signed-off-by: fishbell --- samples/cpp/text_generation/CMakeLists.txt | 3 +- .../text_generation/eagle_speculative_lm.cpp | 114 ++++ .../cpp/text_generation/greedy_causal_lm.cpp | 2 +- .../text_generation/eagle_speculative_lm.py | 67 +++ .../genai/continuous_batching_pipeline.hpp | 5 + .../include/openvino/genai/llm_pipeline.hpp | 6 + .../include/openvino/genai/perf_metrics.hpp | 2 +- .../src/continuous_batching/model_runner.hpp | 200 ++++++- src/cpp/src/continuous_batching/pipeline.cpp | 57 +- src/cpp/src/lora/adapter.cpp | 64 +-- src/cpp/src/safe_tensor_wrapper.cpp | 49 ++ src/cpp/src/safe_tensor_wrapper.hpp | 31 + src/cpp/src/{lora => }/safetensors.c | 0 src/cpp/src/sampling/sampler.cpp | 5 + src/cpp/src/sampling/sampler.hpp | 5 + src/cpp/src/sequence_group.cpp | 3 + src/cpp/src/sequence_group.hpp | 12 +- ...batching_for_speculative_decoding_impl.cpp | 77 ++- ...batching_for_speculative_decoding_impl.hpp | 45 ++ .../speculative_decoding_impl.cpp | 534 ++++++++++++++++++ .../speculative_decoding_impl.hpp | 95 ++++ .../update_request_structs.hpp | 8 +- .../accuracy/CMakeLists.txt | 4 + .../continuous_batching_eagle_decoding.cpp | 150 +++++ ...ntinuous_batching_speculative_decoding.cpp | 11 +- 25 files changed, 1465 insertions(+), 84 deletions(-) create mode 100644 samples/cpp/text_generation/eagle_speculative_lm.cpp create mode 100755 samples/python/text_generation/eagle_speculative_lm.py create mode 100644 src/cpp/src/safe_tensor_wrapper.cpp create mode 100644 src/cpp/src/safe_tensor_wrapper.hpp rename src/cpp/src/{lora => }/safetensors.c (100%) create mode 100644 tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp diff --git a/samples/cpp/text_generation/CMakeLists.txt b/samples/cpp/text_generation/CMakeLists.txt index ebaf32c7f4..0fa9d9eb6d 100644 --- a/samples/cpp/text_generation/CMakeLists.txt +++ b/samples/cpp/text_generation/CMakeLists.txt @@ -29,7 +29,8 @@ set (SAMPLE_LIST lora_greedy_causal_lm multinomial_causal_lm prompt_lookup_decoding_lm - speculative_decoding_lm) + speculative_decoding_lm + eagle_speculative_lm) foreach(sample IN LISTS SAMPLE_LIST) add_sample_executable(${sample}) diff --git a/samples/cpp/text_generation/eagle_speculative_lm.cpp b/samples/cpp/text_generation/eagle_speculative_lm.cpp new file mode 100644 index 0000000000..df74aaf59c --- /dev/null +++ b/samples/cpp/text_generation/eagle_speculative_lm.cpp @@ -0,0 +1,114 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "openvino/genai/llm_pipeline.hpp" +#include "openvino/genai/speculative_decoding/perf_metrics.hpp" + +template +void print_perf_metrics(T& perf_metrics, std::string model_name) { + std::cout << "\n" << model_name << std::endl; + auto generation_duration = perf_metrics.get_generate_duration().mean; + std::cout << " Generate time: " << generation_duration << " ms" << std::endl; + std::cout << " TTFT: " << perf_metrics.get_ttft().mean << " ± " << perf_metrics.get_ttft().std << " ms" + << std::endl; + std::cout << " TPOT: " << perf_metrics.get_tpot().mean << " ± " << perf_metrics.get_tpot().std << " ms/token" + << std::endl; + std::cout << " Num generated token: " << perf_metrics.get_num_generated_tokens() << " tokens" << std::endl; + if (model_name == "Total") { + std::cout << " Total iteration number: " << 
perf_metrics.raw_metrics.m_new_token_times.size() << std::endl;
+    } else {
+        std::cout << " Total iteration number: " << perf_metrics.raw_metrics.m_durations.size() << std::endl;
+    }
+    if (perf_metrics.get_num_input_tokens() > 0) {
+        std::cout << " Input token size: " << perf_metrics.get_num_input_tokens() << std::endl;
+    }
+}
+
+int main(int argc, char* argv[]) try {
+    if (4 != argc) {
+        throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MAIN_MODEL_DIR> <EAGLE_MODEL_DIR> '<PROMPT>'");
+    }
+
+    std::string main_model_path = argv[1];
+    std::string eagle_model_path = argv[2];
+    std::string prompt = argv[3];
+
+    // Configure devices - main and eagle models can run on different devices
+    std::string main_device = "GPU", eagle_device = "GPU";  // CPU can be used as well
+
+    // Eagle speculative decoding settings
+    ov::genai::GenerationConfig config = ov::genai::greedy();
+    config.max_new_tokens = 100;
+    config.num_assistant_tokens = 5;
+
+    ov::genai::SchedulerConfig scheduler_config;
+    scheduler_config.dynamic_split_fuse = false;  // Eagle speculative decoding does not support dynamic_split_fuse mode
+    // Create pipeline with eagle speculative decoding enabled
+    ov::genai::LLMPipeline pipe(
+        main_model_path,
+        main_device,
+        ov::genai::draft_model(eagle_model_path, eagle_device),
+        ov::genai::scheduler_config(scheduler_config),
+        ov::genai::eagle3_mode(true)
+    );
+    // Setup performance measurement
+    auto start_time = std::chrono::high_resolution_clock::now();
+
+    // Optional: Create a streaming callback for real-time token display
+    auto streamer = [](std::string subword) {
+        std::cout << subword << std::flush;
+        return ov::genai::StreamingStatus::RUNNING;
+    };
+
+    // Run generation with eagle speculative decoding
+    std::cout << "Generating with Eagle Speculative decoding:" << std::endl;
+    auto result = pipe.generate(prompt, config, streamer);
+    std::cout << std::endl;
+
+    // Calculate and display performance metrics
+    auto end_time = std::chrono::high_resolution_clock::now();
+    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
+    std::cout << "\nGeneration completed in " << duration.count() << " ms" << std::endl;
+
+    auto sd_perf_metrics = std::dynamic_pointer_cast<ov::genai::SDPerModelsPerfMetrics>(result.extended_perf_metrics);
+    if (sd_perf_metrics) {
+        print_perf_metrics(result.perf_metrics, "Total");
+        print_perf_metrics(sd_perf_metrics->main_model_metrics, "MAIN MODEL");
+        std::cout << " accepted token: " << sd_perf_metrics->get_num_accepted_tokens() << " tokens" << std::endl;
+        std::cout << " compress rate: "
+                  << sd_perf_metrics->main_model_metrics.get_num_generated_tokens() * 1.0f /
+                         sd_perf_metrics->main_model_metrics.raw_metrics.m_durations.size()
+                  << std::endl;
+        print_perf_metrics(sd_perf_metrics->draft_model_metrics, "DRAFT MODEL");
+    }
+    std::cout << std::endl;
+
+    // Run without Eagle for comparison (currently disabled)
+    /*
+    std::cout << "\n-----------------------------" << std::endl;
+    std::cout << "Generating without Eagle Speculative decoding:" << std::endl;
+
+    // Disable Eagle mode
+    config.eagle_model = false;
+
+    start_time = std::chrono::high_resolution_clock::now();
+    pipe.generate(prompt, config, streamer);
+    std::cout << std::endl;
+
+    end_time = std::chrono::high_resolution_clock::now();
+    duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
+    std::cout << "\nStandard generation completed in " << duration.count() << " ms" << std::endl;
+    */
+
+} catch (const std::exception& error) {
+    try {
+        std::cerr << error.what() << '\n';
+    } catch (const std::ios_base::failure&) {}
+    return EXIT_FAILURE;
+} catch (...)
{ + try { + std::cerr << "Non-exception object thrown\n"; + } catch (const std::ios_base::failure&) {} + return EXIT_FAILURE; +} \ No newline at end of file diff --git a/samples/cpp/text_generation/greedy_causal_lm.cpp b/samples/cpp/text_generation/greedy_causal_lm.cpp index ca5e193da1..341ba26d93 100644 --- a/samples/cpp/text_generation/greedy_causal_lm.cpp +++ b/samples/cpp/text_generation/greedy_causal_lm.cpp @@ -9,7 +9,7 @@ int main(int argc, char* argv[]) try { std::string models_path = argv[1]; std::string prompt = argv[2]; - std::string device = "CPU"; // GPU can be used as well + std::string device = "GPU"; // GPU can be used as well ov::genai::LLMPipeline pipe(models_path, device); ov::genai::GenerationConfig config; diff --git a/samples/python/text_generation/eagle_speculative_lm.py b/samples/python/text_generation/eagle_speculative_lm.py new file mode 100755 index 0000000000..2ce64446be --- /dev/null +++ b/samples/python/text_generation/eagle_speculative_lm.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import openvino_genai +import queue + +def streamer(subword): + print(subword, end='', flush=True) + # Return flag corresponds whether generation should be stopped. + return openvino_genai.StreamingStatus.RUNNING + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('model_dir') + parser.add_argument('draft_model_dir') + parser.add_argument('prompt') + args = parser.parse_args() + + # User can run main and draft model on different devices. + # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in openvino_genai.draft_model` for draft. + main_device = 'GPU' # CPU can be used as well + draft_device = 'GPU' + scheduler_config = openvino_genai.SchedulerConfig() + scheduler_config.dynamic_split_fuse = False # Eagle speculative decoding does not support dynamic_split_fuse mode + draft_model = openvino_genai.draft_model(args.draft_model_dir, draft_device) + + pipe = openvino_genai.LLMPipeline(args.model_dir, main_device, scheduler_config = scheduler_config, draft_model=draft_model, eagle3_mode = True) + + config = openvino_genai.GenerationConfig() + config.max_new_tokens = 100 + # Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded + # add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration + config.num_assistant_tokens = 5 + # add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold` + # config.assistant_confidence_threshold = 0.4 + + # Since the streamer is set, the results will be printed + # every time a new token is generated and put into the streamer queue. 
+ res = pipe.generate([args.prompt], config, streamer) + print() + if (res.extended_perf_metrics): + main_model_metrics = res.extended_perf_metrics.main_model_metrics + print(f"MAIN MODEL") + print(f" Generate time: {main_model_metrics.get_generate_duration().mean:.2f} ms" ) + print(f" TTFT: {main_model_metrics.get_ttft().mean:.2f} ± {main_model_metrics.get_ttft().std:.2f} ms" ) + print(f" TTST: {main_model_metrics.get_ttst().mean:.2f} ± {main_model_metrics.get_ttst().std:.2f} ms/token") + print(f" TPOT: {main_model_metrics.get_tpot().mean:.2f} ± {main_model_metrics.get_tpot().std:.2f} ms/iteration") + print(f" AVG Latency: {main_model_metrics.get_latency().mean:.2f} ± {main_model_metrics.get_latency().std:.2f} ms/token") + print(f" Num generated token: {main_model_metrics.get_num_generated_tokens()} tokens") + print(f" Total iteration number: {len(main_model_metrics.raw_metrics.m_durations)}") + print(f" Num accepted token: {res.extended_perf_metrics.get_num_accepted_tokens()} tokens") + + draft_model_metrics = res.extended_perf_metrics.draft_model_metrics + print(f"DRAFT MODEL" ) + print(f" Generate time: {draft_model_metrics.get_generate_duration().mean:.2f} ms" ) + print(f" TTFT: {draft_model_metrics.get_ttft().mean:.2f} ms") + print(f" TTST: {draft_model_metrics.get_ttst().mean:.2f} ms/token") + print(f" TPOT: {draft_model_metrics.get_tpot().mean:.2f} ± {draft_model_metrics.get_tpot().std:.2f} ms/token") + print(f" AVG Latency: {draft_model_metrics.get_latency().mean:.2f} ± {draft_model_metrics.get_latency().std:.2f} ms/iteration") + print(f" Num generated token: {draft_model_metrics.get_num_generated_tokens()} tokens") + print(f" Total iteration number: {len(draft_model_metrics.raw_metrics.m_durations)}") + print() + +if '__main__' == __name__: + main() diff --git a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp index 1a84192ead..488562a869 100644 --- a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp +++ b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp @@ -65,13 +65,18 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline { class ContinuousBatchingImpl; class ContinuousBatchingForSpeculativeDecodingImpl; + class ContinuousBatchingForEagleDecodingImpl; class ContinuousBatchingForPromptLookupImpl; class SpeculativeDecodingImpl; + class EagleDecodingImpl; class PromptLookupImpl; friend class ContinuousBatchingForSpeculativeDecodingImpl; + friend class ContinuousBatchingForPromptLookupImpl; + friend class ContinuousBatchingForEagleDecodingImpl; friend class SpeculativeDecodingImpl; + friend class EagleDecodingImpl; friend class PromptLookupImpl; std::shared_ptr m_impl; diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index eea94591c3..f99c76b3c1 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -345,6 +345,12 @@ static constexpr ov::Property prompt_lookup{"prompt_lookup"}; */ static constexpr ov::Property enable_save_ov_model{"enable_save_ov_model"}; +/** +* @brief enable eagle3_mode property serves to activate eagle3 speculative decoding. +* Set `true` to activate this mode. +* And create LLMPipeline instance with this config. 
+*/ +static constexpr ov::Property eagle3_mode{"eagle3_mode"}; } // namespace genai } // namespace ov diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index 763067e946..565a069803 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -134,7 +134,7 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics { MeanStdPair detokenization_duration = {-1.0f, -1.0f}; size_t num_generated_tokens; - size_t num_input_tokens; + size_t num_input_tokens = 0; float get_load_time(); // Load time in ms. size_t get_num_generated_tokens(); diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index 1c757da35a..a417b598e9 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -49,6 +49,12 @@ class ModelRunner { // Output shape: [1, conversation length, hidden_size]. EmbeddingsModel::Ptr m_embedding; + bool m_is_hidden_state_export_needed = false; // need to export hidden state after inference + bool m_is_hidden_state_import_needed = false; // need to import hidden state from another model runner + bool m_is_hidden_state_internal_needed = false; // need to use internal hidden state, e.g, eagle2 + std::map, std::pair> m_sequence_hidden_state_mapping; // pre-requisite: main/draft have same seq group and running seq grouped id + // a container which use sequence group id and request id as key to store hidden states + std::map m_initial_hidden_states; // shape: [N, seq_len, hidden_size] public: /** * Constructs the ModelRunner. @@ -95,6 +101,18 @@ class ModelRunner { return m_request; } + void set_hidden_state_export_needed(bool is_needed) { + m_is_hidden_state_export_needed = is_needed; + } + + void set_hidden_state_import_needed(bool is_needed) { + m_is_hidden_state_import_needed = is_needed; + } + + void set_hidden_state_internal_needed(bool is_needed) { + m_is_hidden_state_internal_needed = is_needed; + } + void set_embedding_model(const EmbeddingsModel::Ptr& embedder) { m_embedding = embedder; } @@ -121,6 +139,47 @@ class ModelRunner { m_cache_rotation_deltas_for_each_layer = std::move(rotation_deltas_for_each_layer); } + ov::Tensor get_hidden_state(size_t request_id, size_t seq_grouped_id) const { + if (m_hidden_states.get_size() == 0) { + return ov::Tensor(); + } + + auto key = std::make_pair(request_id, seq_grouped_id); + auto it = m_sequence_hidden_state_mapping.find(key); + if (it == m_sequence_hidden_state_mapping.end()) { + return ov::Tensor(); + } + + size_t start_idx = it->second.first; + size_t length = it->second.second; + + auto shape = m_hidden_states.get_shape(); + if (shape.size() < 2) { + return ov::Tensor(); + } + + size_t hidden_size = shape[shape.size() - 1]; + + ov::Coordinate start_coord(shape.size(), 0); + ov::Coordinate end_coord(shape.size(), 0); + + start_coord[0] = start_idx; + end_coord[0] = start_idx + length; + + for (size_t i = 1; i < shape.size(); ++i) { + start_coord[i] = 0; + end_coord[i] = shape[i]; + } + + return ov::Tensor(m_hidden_states, start_coord, end_coord); + } + + void set_initial_hidden_state(size_t request_id, const ov::Tensor& hidden_state) { + // m_initial_hidden_states.clear(); + //auto key = std::make_pair(request_id, seq_grouped_id); + m_initial_hidden_states[request_id] = hidden_state; + } + /** * Runs the forward inference call on the underlying LLM's ov::InferRequest, scheduling for inferencing tokens for given sequences * 
taking into account the supplied scheduler output struct. @@ -129,6 +188,7 @@ class ModelRunner { * @return An ov::Tensor with next-token logit scores for each sequence processed during this `forward` call. */ ov::Tensor forward(const std::vector & sequence_groups, const Scheduler::Output& scheduler_output) { + m_sequence_hidden_state_mapping.clear(); size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size(); size_t batch_size_in_sequences = 0; @@ -164,7 +224,29 @@ class ModelRunner { // block_indices are handled in a special fashion below block_indices_begins(ov::element::i32, {batch_size_in_sequences + 1}), max_context_len(ov::element::i32, {}); - + ov::Tensor hidden_state_input; + float* hidden_state_data = nullptr; + if (m_is_hidden_state_import_needed || m_is_hidden_state_internal_needed) { + if (hidden_size == 0) { + for (const auto& entry : m_initial_hidden_states) { + const auto& stored_hidden_state = entry.second; + if (stored_hidden_state.get_size() > 0) { + auto shape = stored_hidden_state.get_shape(); + if (shape.size() >= 2) { + hidden_size = shape[shape.size() - 1]; + if (!m_is_hidden_state_import_needed) + hidden_size /= 3; + break; + } + } + } + } + if (hidden_size > 0) { + hidden_state_input = ov::Tensor(ov::element::f32, {total_num_tokens, 1, hidden_size}); + hidden_state_data = hidden_state_input.data(); + std::memset(hidden_state_data, 0, total_num_tokens * hidden_size * sizeof(float)); + } + } ov::Tensor score_aggregation_window(ov::element::i32, {batch_size_in_sequences}); ov::Tensor generated_ids_embeds; @@ -205,6 +287,7 @@ class ModelRunner { matmul_gathering_is_available = true; } catch (const ov::Exception&) {} + size_t current_token_idx = 0; std::map> seq_id_to_skipped_blocks_map; for (size_t i = 0; i < num_sequence_groups; ++i) { @@ -236,6 +319,75 @@ class ModelRunner { output_seq_len = 0; Sequence::CPtr sequence = running_sequences[seq_idx]; + if (m_is_hidden_state_export_needed) { + size_t start_token_idx = current_token_idx; + size_t sequence_length = num_scheduled_tokens; + + auto key = std::make_pair(sequence_group->get_request_id(), sequence->get_grouped_id()); + m_sequence_hidden_state_mapping[key] = std::make_pair(start_token_idx, sequence_length); + } + if (m_is_hidden_state_import_needed && hidden_state_data && hidden_size > 0) { + //auto key = std::make_pair(sequence_group->get_request_id(), sequence->get_grouped_id()); + auto it = m_initial_hidden_states.find(sequence_group->get_request_id()); + + if (it != m_initial_hidden_states.end()) { + const auto& stored_hidden_state = it->second; + + if (stored_hidden_state.get_size() > 0) { + auto stored_shape = stored_hidden_state.get_shape(); + + if (stored_shape.size() >= 2) { + size_t stored_seq_len = stored_shape[0]; + size_t stored_hidden_size = stored_shape[stored_shape.size() - 1]; + + if (stored_hidden_size == hidden_size) { + if (stored_seq_len == total_num_tokens) { + hidden_state_input = stored_hidden_state; // all tokens from eagle is accepted + } else { + size_t copy_length = std::min(stored_seq_len, num_scheduled_tokens); + + size_t source_start_idx = + stored_seq_len >= copy_length ? 
stored_seq_len - copy_length : 0; + + const float* source_data = stored_hidden_state.data(); + float* target_data = hidden_state_data + current_token_idx * hidden_size; + + for (size_t token_offset = 0; token_offset < copy_length; ++token_offset) { + size_t source_offset = (source_start_idx + token_offset) * hidden_size; + size_t target_offset = token_offset * hidden_size; + + std::copy_n(source_data + source_offset, + hidden_size, + target_data + target_offset); + } + } + } + } + } + } + } else { + // fill hidden_state_data with m_hidden_states + if (hidden_state_data) { + std::memset(hidden_state_data + current_token_idx * hidden_size, + 0, + num_scheduled_tokens * hidden_size * sizeof(float)); + auto hidden_state = running_sequences[seq_idx]->get_hidden_state(); + if (hidden_state.get_size() > 0) { + auto shape = hidden_state.get_shape(); + if (shape.size() >= 2 && shape[shape.size() - 1] == hidden_size) { + size_t seq_len = shape[0]; + size_t copy_length = std::min(seq_len, num_scheduled_tokens); + const float* source_data = hidden_state.data(); + float* target_data = hidden_state_data + current_token_idx * hidden_size; + for (size_t token_offset = 0; token_offset < copy_length; ++token_offset) { + size_t source_offset = (seq_len - token_offset - 1) * hidden_size; + size_t target_offset = token_offset * hidden_size; + std::copy_n(source_data + source_offset, hidden_size, target_data + target_offset); + } + } + } + } + } for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) { // compute token for current sequence if (sequence_group_type == SequenceGroupType::TOKENS) { @@ -310,7 +462,7 @@ class ModelRunner { *score_aggregation_window_data = 1; } } - + current_token_idx += num_scheduled_tokens; position_ids_data += num_scheduled_tokens; past_lens_data += 1; subsequence_begins_data += 1; @@ -329,7 +481,31 @@ class ModelRunner { m_request.set_tensor("token_type_ids", token_type_ids); } } - + if (hidden_state_input && hidden_state_input.get_size() > 0) { + if (m_is_hidden_state_import_needed) { + try { + m_request.set_tensor("target_hidden_state_input", hidden_state_input); + auto shape = hidden_state_input.get_shape(); + shape[-1] = shape [-1]/3; + ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); + auto fake_data = fake_tensor.data(); + std::memset(fake_data, 0, fake_tensor.get_byte_size()); + m_request.set_tensor("internal_hidden_state_input", fake_tensor); + } catch (const ov::Exception& e) { + } + } else { + try { + m_request.set_tensor("internal_hidden_state_input", hidden_state_input); + auto shape = hidden_state_input.get_shape(); + shape[-1] = shape [-1] * 3; + ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); + auto fake_data = fake_tensor.data(); + std::memset(fake_data, 0, fake_tensor.get_byte_size()); + m_request.set_tensor("target_hidden_state_input", fake_tensor); + } catch (const ov::Exception& e) { + } + } + } // typical LLM parameters m_request.set_tensor("position_ids", position_ids); @@ -373,6 +549,23 @@ class ModelRunner { _reset_cache_rotation_coefficients(); + if (m_is_hidden_state_export_needed) { + try { + m_hidden_states = m_request.get_tensor("last_hidden_state"); + for (size_t i = 0; i < num_sequence_groups; ++i) { + size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i]; + SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id]; + std::vector running_sequences = 
sequence_group->get_running_sequences(); + for (size_t seq_idx = 0; seq_idx < running_sequences.size(); ++seq_idx) { + Sequence::Ptr sequence = running_sequences[seq_idx]; + sequence->update_hidden_state( + get_hidden_state(sequence_group->get_request_id(), sequence->get_grouped_id())); + } + } + } catch (const ov::Exception&) { + m_hidden_states = ov::Tensor(); + } + } // return logits return m_request.get_tensor("logits"); } @@ -434,6 +627,7 @@ class ModelRunner { } private: + ov::Tensor m_hidden_states; // Fills indices for sequences in the order defined by scheduler_output void _fill_indices_from_block_tables( const std::vector& dst_tensor_names, diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index 41cbc0d07b..a73cceabc6 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -15,6 +15,7 @@ #include "continuous_batching/timer.hpp" #include "utils.hpp" #include "visual_language/inputs_embedder.hpp" +#include "safe_tensor_wrapper.hpp" using namespace ov::genai; @@ -29,6 +30,16 @@ extract_draft_model_from_config(ov::AnyMap& config) { return draft_model; } +bool +extact_eagle_mode_from_config(ov::AnyMap& config) { + bool eagle_mode = false; + if (config.find(ov::genai::eagle3_mode.name()) != config.end()) { + eagle_mode = config.at(ov::genai::eagle3_mode.name()).as(); + config.erase(ov::genai::eagle3_mode.name()); + } + return eagle_mode; +} + bool extract_prompt_lookup_from_config(ov::AnyMap& config) { bool res = false; @@ -55,6 +66,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); + auto eagle_mode = extact_eagle_mode_from_config(properties_without_draft_model); auto model = utils::read_model(models_path, properties); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); @@ -70,6 +82,19 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); + } else if (draft_model_desr.model != nullptr && eagle_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + // Eagle speculative decoding does not support dynamic_split_fuse mode + // because it requires hidden state interaction from main model to draft model + // to be implemented future + OPENVINO_ASSERT(scheduler_config.dynamic_split_fuse == false, "Eagle speculative decoding does not support dynamic_split_fuse mode"); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr); + // parse d2t from safe tensors + ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(models_path / "eagle3.safetensor")); + if (constant_tensors.find("d2t") != 
constant_tensors.end()) { // d2t map can be optional + std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + } } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); @@ -94,12 +119,11 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - + auto eagle_mode = extact_eagle_mode_from_config(properties_without_draft_model); auto model = utils::read_model(models_path, properties_without_draft_model); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); auto generation_config = utils::from_config_json_if_exists(models_path); - std::shared_ptr embedder; if (std::filesystem::exists(models_path / "openvino_text_embeddings_model.xml")) { embedder = std::make_shared(models_path, device, properties_without_draft_model_without_gguf); @@ -109,6 +133,21 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); + } else if (draft_model_desr.model != nullptr && eagle_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + // Eagle speculative decoding does not support dynamic_split_fuse mode + // because it requires hidden state interaction from main model to draft model + // to be implemented future + OPENVINO_ASSERT(scheduler_config.dynamic_split_fuse == false, "Eagle speculative decoding does not support dynamic_split_fuse mode"); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr); + // parse d2t from safe tensors + if (std::filesystem::exists(models_path / "eagle3.safetensor")) { + ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(models_path / "eagle3.safetensor")); + if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional + std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + } + } } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); @@ -135,6 +174,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); + auto 
eagle_mode = extact_eagle_mode_from_config(properties_without_draft_model); auto model = utils::singleton_core().read_model(model_str, weights_tensor); auto rt_info = model->get_rt_info(); @@ -152,6 +192,19 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model, generation_config); + } else if (draft_model_desr.model != nullptr && eagle_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + // Eagle speculative decoding does not support dynamic_split_fuse mode + // because it requires hidden state interaction from main model to draft model + // to be implemented future + OPENVINO_ASSERT(scheduler_config.dynamic_split_fuse == false, "Eagle speculative decoding does not support dynamic_split_fuse mode"); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr); + if (eagle_mode) { + // parse d2t from safe tensors + //ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(models_path / "eagle3.safetensor")); + //std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + } } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); diff --git a/src/cpp/src/lora/adapter.cpp b/src/cpp/src/lora/adapter.cpp index 2b186f3fad..1225912d46 100644 --- a/src/cpp/src/lora/adapter.cpp +++ b/src/cpp/src/lora/adapter.cpp @@ -40,13 +40,10 @@ #include "openvino/genai/lora_adapter.hpp" #include "utils.hpp" +#include "safe_tensor_wrapper.hpp" #include "lora/common.hpp" #include "lora/names_mapping.hpp" -extern "C" { - #include "safetensors.h" -} - // FIXME: Remove or move to a dedicated common header #ifdef NDEBUG #define DEBUG_PRINT(X) do {} while(false) @@ -69,65 +66,6 @@ using ConstantVector = std::vector>; using LoRANode = LoRAParts>; using LoRAPartsParser = LoRAParts(const std::string& name)>>; -// Converts Safetensors element type to OV element type. Only part of the types are supported. -ov::element::Type safetensors_to_ov_element_type (int dtype) { - switch(dtype) { - case SAFETENSORS_F32: - return ov::element::f32; - case SAFETENSORS_F16: - return ov::element::f16; - case SAFETENSORS_BF16: - return ov::element::bf16; - default: - OPENVINO_THROW("Not supported safetensors dtype: ", dtype); - } -} - -using ConstantMap = std::map>; - -// Safetensor file parser that deallocates temporary buffers automatically. -// Drop-in replacement for the third party safetensors_File struct. -struct AutoSafetensor: public safetensors_File { - ~AutoSafetensor () { - std::free(tensors); - std::free(metadata); - } -}; - -// The key in the map is a tensor name and the Constant uses a region of memory from the memory block. -// Each Constant holds a shared pointer to the block in the runtime info. -// The memory block will be deallocated when the last Constant is destroyed. 
-ConstantMap safetensor_to_constant_map(const ov::Tensor& safetensor) { - AutoSafetensor safe_tensors_file{}; - - OPENVINO_ASSERT(safetensors_file_init(safetensor.data(), safetensor.get_byte_size(), &safe_tensors_file) == nullptr, - "Cannot parse safetensor as a Safetensors file format. Safetensors file format is supported only" - ); - - ConstantMap tensors; - for (int i = 0; i < safe_tensors_file.num_tensors; i++) { - safetensors_TensorDescriptor tensor = safe_tensors_file.tensors[i]; - std::string name(tensor.name.ptr, tensor.name.ptr + tensor.name.len); - ov::Shape shape(tensor.shape, tensor.shape + tensor.n_dimensions); - void* ptr = tensor.ptr; // FIXME: needs a non-constant pointer because Tensor doesn't accept a constant pointer - - auto type = safetensors_to_ov_element_type(tensor.dtype); - auto constant = - std::make_shared(type, shape, ptr, nullptr); // wraps existing memory, no ownership - constant->get_rt_info()["__safetensors_buffer_holder"] = safetensor; // to automatically deallocate underlying memory buffer when last constant that holds it is destroyed - tensors[name] = constant; - } - return tensors; -} - -// Reads a file with a given filename expecting Safetensors file format. -// The file data is mmaped to tensor. -ConstantMap read_safetensors(const std::filesystem::path& filename) { - auto safetensor = ov::read_tensor_data(filename); - - return safetensor_to_constant_map(safetensor); -} - // Default LoRA tensor name patterns observed in the existing LoRA adapters, captures the prefix that should correspond // to a layer name in the base model LoRAPartsParser default_lora_patterns () { diff --git a/src/cpp/src/safe_tensor_wrapper.cpp b/src/cpp/src/safe_tensor_wrapper.cpp new file mode 100644 index 0000000000..c46adc863a --- /dev/null +++ b/src/cpp/src/safe_tensor_wrapper.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +#include "safe_tensor_wrapper.hpp" + +ov::element::Type safetensors_to_ov_element_type (int dtype) { + switch(dtype) { + case SAFETENSORS_F32: + return ov::element::f32; + case SAFETENSORS_F16: + return ov::element::f16; + case SAFETENSORS_BF16: + return ov::element::bf16; + case SAFETENSORS_I64: + return ov::element::i64; + case SAFETENSORS_BOOL: + return ov::element::boolean; + default: + OPENVINO_THROW("Not supported safetensors dtype: ", dtype); + } +} + +ConstantMap safetensor_to_constant_map(const ov::Tensor& safetensor) { + AutoSafetensor safe_tensors_file{}; + + OPENVINO_ASSERT(safetensors_file_init(safetensor.data(), safetensor.get_byte_size(), &safe_tensors_file) == nullptr, + "Cannot parse safetensor as a Safetensors file format. 
Safetensors file format is supported only" + ); + + ConstantMap tensors; + for (int i = 0; i < safe_tensors_file.num_tensors; i++) { + safetensors_TensorDescriptor tensor = safe_tensors_file.tensors[i]; + std::string name(tensor.name.ptr, tensor.name.ptr + tensor.name.len); + ov::Shape shape(tensor.shape, tensor.shape + tensor.n_dimensions); + void* ptr = tensor.ptr; // FIXME: needs a non-constant pointer because Tensor doesn't accept a constant pointer + + auto type = safetensors_to_ov_element_type(tensor.dtype); + auto constant = + std::make_shared(type, shape, ptr, nullptr); // wraps existing memory, no ownership + constant->get_rt_info()["__safetensors_buffer_holder"] = safetensor; // to automatically deallocate underlying memory buffer when last constant that holds it is destroyed + tensors[name] = constant; + } + return tensors; +} + +ConstantMap read_safetensors(const std::filesystem::path& filename) { + auto safetensor = ov::read_tensor_data(filename); + + return safetensor_to_constant_map(safetensor); +} \ No newline at end of file diff --git a/src/cpp/src/safe_tensor_wrapper.hpp b/src/cpp/src/safe_tensor_wrapper.hpp new file mode 100644 index 0000000000..074ce3f1dd --- /dev/null +++ b/src/cpp/src/safe_tensor_wrapper.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +#include "openvino/runtime/core.hpp" +#include "openvino/op/constant.hpp" +extern "C" { + #include "safetensors.h" +} + +using namespace ov::op; +// Converts Safetensors element type to OV element type. Only part of the types are supported. +ov::element::Type safetensors_to_ov_element_type (int dtype); + +using ConstantMap = std::map>; + +// Safetensor file parser that deallocates temporary buffers automatically. +// Drop-in replacement for the third party safetensors_File struct. +struct AutoSafetensor: public safetensors_File { + ~AutoSafetensor () { + std::free(tensors); + std::free(metadata); + } +}; + +// The key in the map is a tensor name and the Constant uses a region of memory from the memory block. +// Each Constant holds a shared pointer to the block in the runtime info. +// The memory block will be deallocated when the last Constant is destroyed. +ConstantMap safetensor_to_constant_map(const ov::Tensor& safetensor); + +// Reads a file with a given filename expecting Safetensors file format. +// The file data is mmaped to tensor. +ConstantMap read_safetensors(const std::filesystem::path& filename); \ No newline at end of file diff --git a/src/cpp/src/lora/safetensors.c b/src/cpp/src/safetensors.c similarity index 100% rename from src/cpp/src/lora/safetensors.c rename to src/cpp/src/safetensors.c diff --git a/src/cpp/src/sampling/sampler.cpp b/src/cpp/src/sampling/sampler.cpp index f34a8e251f..e8363fa792 100644 --- a/src/cpp/src/sampling/sampler.cpp +++ b/src/cpp/src/sampling/sampler.cpp @@ -853,6 +853,11 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr } } } + if (!is_validation_mode_enabled && m_d2t) { // compute token offset for draft model in speculative sampling + ov::Tensor d2t_tensor = m_d2t->get_tensor_view(); + auto d2t = d2t_tensor.data(); + sampled_token.m_index = sampled_token.m_index + (d2t? 
d2t[sampled_token.m_index] : 0); + } // flag to add sampled token to generated sequence or extend logit processors only bool is_extend_sequence = logit_token_offset == 0 || is_generate_n_tokens || !is_validation_passed; if (is_validation_mode_enabled && !is_extend_sequence) { diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index ffbbcac3e3..96cf497793 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -99,6 +99,7 @@ class Sampler { Tokenizer m_tokenizer; ThreadPool m_thread_pool; + std::shared_ptr m_d2t; // Tensor to store d2t mapping for eagle model public: Sampler(const Sampler& rhs) = delete; Sampler(Sampler&& rhs) = delete; @@ -125,6 +126,10 @@ class Sampler { // pair with map with backend name and corresponding compiler init time, and vector of compile times for each concrete grammar std::pair, std::vector> get_structured_output_times(); void clear_structured_output_compile_times(); + + void set_d2t_for_decoding(std::shared_ptr& d2t) { + m_d2t = d2t; + }; }; class Sampler::GroupBeamSearcher { diff --git a/src/cpp/src/sequence_group.cpp b/src/cpp/src/sequence_group.cpp index 94dc9160d4..3a3b27eec7 100644 --- a/src/cpp/src/sequence_group.cpp +++ b/src/cpp/src/sequence_group.cpp @@ -28,6 +28,9 @@ size_t Sequence::_make_hash(size_t content_length) { // get tokens corresponding to current block if (sequence_group->get_sequence_group_type() == SequenceGroupType::TOKENS) { const auto prompt_ids = sequence_group->get_prompt_ids(); + if (content_length > prompt_ids.size() + m_generated_ids.size()) { + std::cout << "break" << std::endl; + } OPENVINO_ASSERT(content_length <= prompt_ids.size() + m_generated_ids.size()); if (block_start_idx < prompt_ids.size()) { content.insert(content.end(), prompt_ids.begin() + block_start_idx, prompt_ids.begin() + std::min(prompt_ids.size(), content_length)); diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index 664ac665cf..5ae6767111 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -44,6 +44,7 @@ class Sequence { LogProbs m_generated_log_probs; uint64_t m_grouped_id; uint64_t m_id = _get_next_global_sequence_id(); + ov::Tensor m_hidden_state = ov::Tensor(); SequenceStatus m_status = SequenceStatus::RUNNING; GenerationFinishReason m_finish_reason = GenerationFinishReason::NONE; float m_cumulative_log_prob = 0.0f; @@ -67,6 +68,7 @@ class Sequence { Sequence(const Sequence& seq, const uint64_t id) : m_generated_ids(seq.m_generated_ids), m_grouped_id(id), + m_hidden_state(seq.m_hidden_state), m_status(seq.m_status), m_cumulative_log_prob(seq.m_cumulative_log_prob), m_sequence_group(seq.m_sequence_group), @@ -134,6 +136,14 @@ class Sequence { m_generated_ids.push_back(token_id); } + void update_hidden_state(ov::Tensor tensor) { + m_hidden_state = tensor; + } + + ov::Tensor& get_hidden_state() { + return m_hidden_state; + } + // removes n last tokens and updates cumulative log prob // used to remove stop_string from the output void remove_last_tokens(int n) { @@ -561,7 +571,7 @@ class SequenceGroup : public std::enable_shared_from_this { m_num_validation_tokens = k; } - size_t get_num_tokens_to_validate() { + size_t get_num_tokens_to_validate() const { return m_num_validation_tokens; } diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index def8f88372..b3c07e4c38 100644 --- 
a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -58,12 +58,48 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::get_ge for (const auto& sequence : request->get_running_sequences()) { const auto& sequence_id = sequence->get_grouped_id(); OPENVINO_ASSERT(!generated_request.count(sequence_id)); - generated_request.insert({{sequence_id, { sequence->get_generated_ids(), sequence->get_generated_log_probs() } }}); + generated_request.insert({{sequence_id, { sequence->get_generated_ids(), sequence->get_generated_log_probs(), sequence->get_hidden_state() } }}); } } return result; } +ov::Tensor truncate_hidden_state_from_end(const ov::Tensor& hidden_state, size_t tokens_to_remove) { + if (hidden_state.get_size() == 0 || tokens_to_remove == 0) { + return hidden_state; + } + + auto shape = hidden_state.get_shape(); + if (shape.size() < 2) { + return hidden_state; + } + + size_t seq_len_dim = 0; + size_t current_seq_len = shape[seq_len_dim]; + + if (tokens_to_remove >= current_seq_len) { + ov::Shape new_shape = shape; + new_shape[seq_len_dim] = 0; + return ov::Tensor(hidden_state.get_element_type(), new_shape); + } + + size_t new_seq_len = current_seq_len - tokens_to_remove; + + ov::Coordinate start_coord(shape.size(), 0); + ov::Coordinate end_coord(shape.size(), 0); + + for (size_t i = 0; i < shape.size(); ++i) { + start_coord[i] = 0; + if (i == seq_len_dim) { + end_coord[i] = new_seq_len; + } else { + end_coord[i] = shape[i]; + } + } + + return ov::Tensor(hidden_state, start_coord, end_coord); +} + // { min_len_of_prefix, min_length_of_candidate } std::pair get_prefix_len( @@ -222,6 +258,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update std::vector running_sequences = request->get_running_sequences(); OPENVINO_ASSERT(running_sequences.size() > 0); size_t min_generated_tokens, min_candidate_len; + size_t validate_length = 0; if (running_sequences.front()->get_generated_len() == 0 && !request->get_num_tokens_to_validate()) { m_sampler->create_logit_processor(request_id, request->get_sampling_parameters(), request->get_prompt_ids()); auto& logit_processor = m_sampler->get_logit_processor(request_id); @@ -229,6 +266,9 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update min_generated_tokens = result.inserted_tokens_cnt; running_sequences = request->get_running_sequences(); min_candidate_len = result.inserted_tokens_cnt; + if (eagle_mode_enabled && !m_is_validation_mode_enabled) + m_model_runner->set_initial_hidden_state(request_id, + candidates.begin()->second.hidden_states); } else { // update existing sequences by the candidates auto& logit_processor = m_sampler->get_logit_processor(request_id); @@ -247,6 +287,21 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update candidate_token_ids.resize(min_candidate_len); candidate_token_log_probs.resize(min_candidate_len); result.inserted_tokens_cnt = insert_tokens_to_sequence(running_sequence, candidate_token_ids, candidate_token_log_probs, logit_processor, is_update_logit_processor); + // handle hidden states for eagle mode + if (eagle_mode_enabled && !m_is_validation_mode_enabled && result.inserted_tokens_cnt > 0) { // update hidden states for draft model + // at least there should be one bonus token from main + auto& hidden_state = candidate_sequence.hidden_states; + ov::Tensor pruned_hidden_state = 
truncate_hidden_state_from_end(hidden_state, result.removed_tokens_cnt); + m_model_runner->set_initial_hidden_state(request_id, + pruned_hidden_state); + validate_length = pruned_hidden_state.get_shape().size() > 0 ? pruned_hidden_state.get_shape()[0] : 0; + } + if (!m_is_validation_mode_enabled && result.inserted_tokens_cnt > 0) { + std::cout << "main update draft: request id: " << request_id << "removed tokens: " << result.removed_tokens_cnt << ", inserted tokens: " << result.inserted_tokens_cnt << std::endl; + } + if (m_is_validation_mode_enabled && result.inserted_tokens_cnt > 0) { + std::cout << "draft update main: request id: " << request_id << "removed tokens: " << result.removed_tokens_cnt << ", inserted tokens: " << result.inserted_tokens_cnt << std::endl; + } } // we should update a logit processor just for draft model to generate the same tokens // logit processors of main model will be updated in sampler while validation mode @@ -261,14 +316,22 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update updated_context_len = min_candidate_len + prompt_len, max_new_tokens = request->get_max_new_tokens(); size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0; - if (generated_len > 0 && result.removed_tokens_cnt > 0) { - request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1); + if (validate_length > 0) { + if (generated_len > 0) { + request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1 - (validate_length - 1)); + } + } else { // fast draft or main model for eagle speculative + if (generated_len > 0 && result.removed_tokens_cnt > 0) { + request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1); + } } - if (result.inserted_tokens_cnt > 0 && result.removed_tokens_cnt == 0) { + if (validate_length == 0 && result.inserted_tokens_cnt > 0 && result.removed_tokens_cnt == 0) { request->set_num_validated_tokens(result.inserted_tokens_cnt); + } else if (validate_length > 0) { + request->set_num_validated_tokens(validate_length - 1); // in generation stage } // to pause `draft_model` generation in case of `generated_len >= max_new_tokens - 1` to generate last token by `main_model` - if (!m_is_validation_mode_enabled) { + if (!m_is_validation_mode_enabled && result.inserted_tokens_cnt != 0) { bool pause_gen_status = false; generated_len -= result.removed_tokens_cnt; generated_len += result.inserted_tokens_cnt; @@ -323,6 +386,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m raw_perf_metrics.m_batch_sizes.emplace_back(num_generated_tokens); } + if (eagle_mode_enabled) + m_model_runner->set_hidden_state_import_needed(false); to_generate = false; for (auto& request : m_requests) { const auto& sampling_params = request->get_sampling_parameters(); @@ -346,5 +411,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m to_generate |= request->can_generate_tokens(); } } + if (eagle_mode_enabled) + m_model_runner->set_hidden_state_import_needed(true); } } diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index f81c7f2d37..69ee41c501 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ 
b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp
@@ -38,5 +38,50 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl :
 protected:
     void finish_request(SequenceGroup::Ptr request);
     void _pull_awaiting_requests() override {};
+    bool eagle_mode_enabled = false;
+};
+
+class ContinuousBatchingPipeline::ContinuousBatchingForEagleDecodingImpl
+    : public ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl {
+public:
+    ContinuousBatchingForEagleDecodingImpl() = default;
+
+    ContinuousBatchingForEagleDecodingImpl(const std::shared_ptr<ov::Model>& model,
+                                           const Tokenizer& tokenizer,
+                                           const GenerationConfig& generation_config,
+                                           const SchedulerConfig& scheduler_config,
+                                           const std::string& device,
+                                           const ov::AnyMap& plugin_config,
+                                           bool is_validation_mode_enabled) : ContinuousBatchingForSpeculativeDecodingImpl(
+        model, tokenizer, generation_config,
+        scheduler_config, device, plugin_config,
+        is_validation_mode_enabled) {
+        eagle_mode_enabled = true;
+    };
+
+    bool is_requests_empty();
+
+    void set_d2t_for_draft_decoding(std::shared_ptr<ov::op::v0::Constant>& d2t) {
+        if (m_sampler) {
+            m_sampler->set_d2t_for_decoding(d2t);
+        }
+    }
+    void set_hidden_state_export_needed(bool is_needed) {
+        if (m_model_runner) {
+            m_model_runner->set_hidden_state_export_needed(is_needed);
+        }
+    }
+
+    void set_hidden_state_import_needed(bool is_needed) {
+        if (m_model_runner) {
+            m_model_runner->set_hidden_state_import_needed(is_needed);
+        }
+    }
+
+    void set_hidden_state_internal_needed(bool is_needed) {
+        if (m_model_runner) {
+            m_model_runner->set_hidden_state_internal_needed(is_needed);
+        }
+    }
 };
 }
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index c98e0d7542..13e08e5d42 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -365,4 +365,538 @@ std::vector ContinuousBatchingPipeline::SpeculativeDecodingI
     OPENVINO_ASSERT(main_awaiting_requests.size() == draft_awaiting_requests.size());
     return main_awaiting_requests;
 }
+
+
+void extract_hidden_state_generic(std::shared_ptr<ov::Model>& model,
+                                  const std::string& model_type,
+                                  const std::string& custom_node_name = "") {
+    if (model_type == "draft") { // for the draft model we only need to extract the last hidden state
+        std::cout << model_type << " model - last hidden state extraction" << std::endl;
+        ov::pass::Manager pm;
+        std::vector<int> layers = {-1}; // -1 means last hidden layer
+        pm.register_pass<EagleModelTransform>(layers);
+        pm.run_passes(model);
+    } else {
+        std::cout << model_type << " model - Eagle 3 hidden state extraction" << std::endl;
+        ov::pass::Manager pm;
+        /*if idx==len(self.layers)-3 or idx==len(self.layers)//2 or idx==2:
+              all_hidden_states += (hidden_states,)*/
+        std::vector<int> layers = {2, 16, 29}; // TODO: add a check; only positive layer indices are supported here
+        pm.register_pass<EagleModelTransform>(layers);
+        pm.run_passes(model);
+    }
+}
+
+EagleModelTransform::EagleModelTransform(const std::vector<int>& layers) : m_layer_ids(layers) {
+}
+
+bool EagleModelTransform::run_on_model(const std::shared_ptr<ov::Model>& model) {
+    m_new_parameters.clear();
+    m_new_results.clear();
+    if (m_layer_ids.size() == 1 && m_layer_ids[0] == -1) {
+        ov::pass::Manager manager;
+        manager.set_per_pass_validation(false);
+        manager.register_pass<EagleBaseTransform>(m_new_parameters, m_new_results);
+        manager.run_passes(model);
+
+        if (!m_new_results.empty()) {
+            model->add_results(m_new_results);
+            std::cout << "EagleModelTransform - Added last hidden output" << std::endl;
+        }
+        // input transform for draft
+        // here we apply a trick for the fc layer in the draft model
+        {
+            ov::pass::Manager manager;
+            manager.set_per_pass_validation(false);
+            m_new_parameters = model->get_parameters();
+            manager.register_pass<EagleInputTransform>(m_new_parameters);
+            manager.run_passes(model);
+
+            model->add_parameters({m_new_parameters.back()});
+            std::cout << "EagleModelTransform - trick on draft model inputs" << std::endl;
+        }
+        return true;
+    } else {
+        ov::pass::Manager manager;
+        manager.set_per_pass_validation(false);
+        manager.register_pass<Eagle3Transform>(m_layer_ids, m_hidden_layer_outputs);
+        manager.run_passes(model);
+
+        if (!m_hidden_layer_outputs.empty()) {
+            std::cout << "EagleModelTransform - extracted intermediate hidden state outputs" << std::endl;
+            auto concat = std::make_shared<ov::op::v0::Concat>(m_hidden_layer_outputs, -1);
+            concat->set_friendly_name("eagle3_hidden_states_concat");
+
+            auto result = std::make_shared<ov::op::v0::Result>(concat);
+            std::string output_name = "last_hidden_state";
+            result->output(0).set_names({output_name});
+            result->set_friendly_name(output_name);
+            model->add_results({result});
+
+            std::cout << "EagleModelTransform - Added concatenated eagle3 hidden state output" << std::endl;
+            return true;
+        }
+    }
+
+    return false;
+}
+
+EagleInputTransform::EagleInputTransform(std::vector<std::shared_ptr<ov::op::v0::Parameter>>& params) {
+    register_matcher(
+        std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<ov::op::v0::MatMul>(), this->get_type_info().name),
+        ([&params, this](ov::pass::pattern::Matcher& m) {
+            auto node = m.get_match_root();
+            try {
+                if (apply(node, params)) {
+                    ++applied; // FIXME: For debugging purposes only
+                    return true;
+                }
+            } catch (...) {
+                OPENVINO_ASSERT(false, "EagleTransform failed to apply");
+            }
+            return false;
+        })
+    );
+}
+bool EagleInputTransform::apply(NodePtr node, std::vector<std::shared_ptr<ov::op::v0::Parameter>>& params) {
+    if (ov::is_type<ov::op::v0::MatMul>(node)) {
+        auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(node);
+        if (matmul_node->get_friendly_name().find("__module.model.fc/ov_ext::linear/MatMul") == std::string::npos) { // hardcode for now
+            return false;
+        }
+        auto shape = node->get_output_partial_shape(0);
+        auto internal_hidden_state = std::make_shared<ov::op::v0::Parameter>(node->get_element_type(), node->get_output_partial_shape(0));
+        internal_hidden_state->output(0).set_names({"internal_hidden_state_input"});
+        internal_hidden_state->set_friendly_name("internal_hidden_state_input");
+        // create a new eltwise node that adds the MatMul output and the internal hidden state input
+        auto new_eltwise = std::make_shared<ov::op::v1::Add>(internal_hidden_state, matmul_node->output(0));
+        ov::replace_node(matmul_node, new_eltwise);
+        params.push_back(internal_hidden_state);
+        std::cout << "EagleInputTransform - Added internal hidden state input parameter" << std::endl;
+        return true;
+    }
+    return false;
+}
+
+EagleBaseTransform::EagleBaseTransform(std::vector<std::shared_ptr<ov::op::v0::Parameter>>& params, std::vector<std::shared_ptr<ov::op::v0::Result>>& results) {
+    register_matcher(
+        std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type(), this->get_type_info().name),
+        ([&params, &results, this](ov::pass::pattern::Matcher& m) {
+            auto node = m.get_match_root();
+            try {
+                if (apply(node, params, results)) {
+                    ++applied; // FIXME: For debugging purposes only
+                    return true;
+                }
+            } catch (...) {
{ + OPENVINO_ASSERT(false, "EagleTransform failed to apply"); + } + return false; + }) + ); +} + +std::shared_ptr EagleBaseTransform::find_last_hidden_node(const std::shared_ptr& start_node, + std::set& visited_nodes) { + if (visited_nodes.count(start_node.get())) { + return nullptr; + } + + visited_nodes.insert(start_node.get()); + + if (ov::is_type(start_node)) { + // check the input nodes of MatMul, if found Gather node, return the gather node, otherwise ,retrun the matmul node + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + // rule out logit processing node + if (ov::as_type_ptr(input_node)) { + return input_node; + } + } + return start_node; + } + + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + + auto result = find_last_hidden_node(input_node, visited_nodes); + if (result) { + return result; + } + } + return nullptr; +} + +std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node, + std::set& visited_nodes) { + if (visited_nodes.count(start_node.get())) { + return nullptr; + } + + visited_nodes.insert(start_node.get()); + + if (ov::is_type(start_node)) { + // check the input nodes of MatMul, if found Gather node, return the gather node, otherwise ,retrun the matmul node + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + if (ov::as_type_ptr(input_node)) { + return start_node; // return the Add node itself + } + } + } + + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + + auto result = find_last_residual_node(input_node, visited_nodes); + if (result) { + return result; + } + } + return nullptr; +} + +std::shared_ptr EagleBaseTransform::find_last_hidden_node(const std::shared_ptr& start_node) { + std::set visited_nodes; + return find_last_hidden_node(start_node, visited_nodes); +} + +std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node) { + std::set visited_nodes; + return find_last_residual_node(start_node, visited_nodes); +} + +bool EagleBaseTransform::apply(NodePtr node, std::vector>& params, std::vector>& results) { + { + // 1. with normalization layer 2. 
add extra input + if (ov::is_type(node)) { + // we are applying transformation to the last hidden state, eagle2 mode + NodePtr input_node = node->get_input_node_shared_ptr(0); + if (!input_node) { + return false; + } + auto last_residual_node = find_last_residual_node(input_node); + if (!last_residual_node) { + return false; + } + // + auto result = std::make_shared(last_residual_node); + std::string output_name = "last_hidden_state"; + result->output(0).set_names({output_name}); + result->set_friendly_name(output_name); + results.push_back(result); + return true; + } + return false; + } +} +// end of fast draft + +// eagle related transformation +Eagle3Transform::Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs) : m_layers(layers) { + auto is_target_pattern = [&](const Output& output) { + auto add_node = ov::as_type_ptr(output.get_node_shared_ptr()); + auto add_node_name = add_node->get_friendly_name(); + if (add_node_name.find("self_attn") != std::string::npos) + return false; // Skip self-attention layers + bool layer_matched = false; + for (auto layer_idx : m_layers) { + if (add_node_name.find("layers." + std::to_string(layer_idx) + "/") != std::string::npos) { + layer_matched = true; + break; + } + } + + if (!layer_matched) { + return false; // Skip layers that are not in the specified layers + } + auto input0 = add_node->get_input_node_shared_ptr(1); + if (!input0 || !ov::is_type(input0)) { + return false; + } + auto matmul_node = input0; + auto matmul_input = matmul_node->get_input_node_shared_ptr(0); + if (!matmul_input) { + return false; + } + + bool has_multiply = ov::is_type(matmul_input); // ACT(up) dot gate + return has_multiply; + }; + + auto hidden_layer = ov::pass::pattern::wrap_type(is_target_pattern); + register_matcher(std::make_shared(hidden_layer, "Eagle3Transform::hidden_extraction"), + [&hidden_state_outputs, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + try { + if (apply(node, hidden_state_outputs)) { + ++applied; // FIXME: For debugging purposes only + return true; + } + } catch (...) { + OPENVINO_ASSERT(false, "Eagle3Transform failed to apply"); + } + return false; + } + ); +} + +bool Eagle3Transform::apply(NodePtr node, std::vector>& hidden_state_outputs) { + if (ov::is_type(node)) { + auto add_node = std::dynamic_pointer_cast(node); + if (!add_node) { + return false; + } + hidden_state_outputs.push_back(add_node->output(0)); + return true; + } + return false; +} + +ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc) { + auto main_model = main_model_desc.model; + auto draft_model = draft_model_desc.model; + + auto main_scheduler_config = main_model_desc.scheduler_config; + auto main_device = main_model_desc.device; + + utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction); + utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction); + + utils::apply_gather_before_matmul_transformation(main_model); + utils::apply_gather_before_matmul_transformation(draft_model); + + std::string draft_device = draft_model_desc.device.empty() ? 
main_model_desc.device : draft_model_desc.device; + bool is_draft_scheduler_undefined = draft_model_desc.scheduler_config == SchedulerConfig(); + + ov::genai::SchedulerConfig main_scheduler_config_updated = main_scheduler_config, + draft_scheduler_config = is_draft_scheduler_undefined + ? main_scheduler_config + : draft_model_desc.scheduler_config; + + if (is_draft_scheduler_undefined) { + // split KV cache to 2 caches for main and draft models + auto compute_total_hidden_size = [](const std::shared_ptr& model) -> size_t { + size_t total_hidden_size = 0; + for (const auto& param_ptr : model->get_parameters()) { + const auto& name = param_ptr->get_friendly_name(); + if (name.find("key_cache.") == 0) { + auto pa_op = param_ptr->get_output_target_inputs(0).begin()->get_node(); + const auto& rt_info = pa_op->get_rt_info(); + total_hidden_size += rt_info.at("num_k_heads").as() * rt_info.at("k_head_size").as() + + rt_info.at("num_v_heads").as() * rt_info.at("v_head_size").as(); + } + } + return total_hidden_size; + }; + float main_model_hidden_size = compute_total_hidden_size(main_model), + draft_model_hidden_size = compute_total_hidden_size(draft_model); + auto k = draft_model_hidden_size / (main_model_hidden_size + draft_model_hidden_size); + + // TODO: work with KV blocks as it will be more precise instead of GBs + size_t main_cache_size = std::ceil(main_scheduler_config.cache_size * (1.f - k)), + draft_cache_size = main_scheduler_config.cache_size - main_cache_size; + if (draft_cache_size == 0 && main_cache_size > 0) { + main_cache_size -= (main_cache_size > 1 ? 1 : 0); + draft_cache_size = 1; + } + + main_scheduler_config_updated.cache_size = main_cache_size; + draft_scheduler_config.cache_size = draft_cache_size; + } + + ov::AnyMap draft_properties = + draft_model_desc.properties.empty() ? 
main_model_desc.properties : draft_model_desc.properties; + + // main and draft model can have different tokenizers + // to do: support retokenization: 154103 + Tokenizer main_model_tokenizer = main_model_desc.tokenizer; + Tokenizer draft_model_tokenizer = draft_model_desc.tokenizer; + + // todo: remove this condition after support of CVS-154103 + OPENVINO_ASSERT(are_tokenizers_equal(main_model_tokenizer, draft_model_tokenizer), + "Tokenizers for draft and main models are different!"); + + m_tokenizer = main_model_tokenizer; + // for eagle model, we need to obtain hidden layer state as extra output + // apply transformations needed to run eagle model + // target model: hidden state extraction, draft model: hidden state import , hidden state extraction + // eagle3 specific : dt importing + extract_hidden_state_generic(main_model, "main", ""); + extract_hidden_state_generic(draft_model, "draft", ""); + + // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode + m_main_pipeline = std::make_shared(main_model, + main_model_tokenizer, + main_model_desc.generation_config, + main_scheduler_config_updated, + main_device, + main_model_desc.properties, + true); + m_draft_pipeline = std::make_shared(draft_model, + draft_model_tokenizer, + draft_model_desc.generation_config, + draft_scheduler_config, + draft_device, + draft_properties, + false); + m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + m_perf_metrics.raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; +} + +ov::Tensor ContinuousBatchingPipeline::EagleDecodingImpl::create_draft_input_ids(const ov::Tensor& original_input_ids) { + auto shape = original_input_ids.get_shape(); + if (shape.size() == 0 || shape.back() <= 1) { + return ov::Tensor(original_input_ids); + } + + size_t original_length = shape.back(); + size_t new_length = original_length - 1; + + ov::Tensor draft_input_ids(ov::element::i64, {1, new_length}); + + const int64_t* src_data = original_input_ids.data(); + int64_t* dst_data = draft_input_ids.data(); + + std::copy(src_data + 1, src_data + original_length, dst_data); + + return draft_input_ids; +} + +std::vector ContinuousBatchingPipeline::EagleDecodingImpl::generate( + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids) { + m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + + OPENVINO_ASSERT(!has_non_finished_requests(), + "Generate cannot be called while ContinuousBatchingPipeline is already in running state. 
Use " + "ContinuousBatchingPipeline::add_request"); + + OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); + + ManualTimer generate_timer("speculative_decoding: generate()"); + generate_timer.start(); + + // checks that all requests has the same LoRA adapters property value + for (size_t i = 1; i < sampling_params.size(); ++i) { + OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, + "LoRA adapters value must be the same for all requests"); + } + auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); + auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); + + m_main_eagle_pipeline->set_adapters(sampling_params[0].adapters); + m_main_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_adapters(sampling_params[0].adapters); + m_draft_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_import_needed(true); + m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); + + const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); + + OPENVINO_ASSERT( + !streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || + (sampling_params[0].is_multinomial() && + sampling_params[0].num_return_sequences == 1)), + "Currently eagle streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); + + std::vector main_generations; + ov::Tensor new_input_ids; + for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { + auto new_input_ids = input_ids[request_id]; + OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); + auto main_sampling_params = sampling_params[request_id]; + + main_generations.push_back( + m_main_pipeline->add_request(request_id, new_input_ids, main_sampling_params)); + + auto draft_sampling_params = sampling_params[request_id]; + // set the parameters do not stop draft generation without stopping of the same request for main pipeline + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + + // remove first token from input_ids to create draft_input_ids + ov::Tensor draft_input_ids = create_draft_input_ids(new_input_ids); + + std::lock_guard lock(m_draft_generations_mutex); + m_draft_generations.insert( + {request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); + } + auto all_requests = get_awaiting_requests(); + + GenerationHandle& generation = main_generations.at(0); + + streamer_ptr->start(); + + while (has_non_finished_requests()) { + try { + step(); + } catch (...) 
{ + drop_requests(); // remove all requests from pipeline state in case of exception + streamer_ptr->end(); + std::rethrow_exception(std::current_exception()); + } + stream_tokens(streamer_ptr, generation); + } + + // waiting for competion of streaming + streamer_ptr->end(); + + OPENVINO_ASSERT(is_requests_empty(), + "Internal error: current request is supposed to be dropped within step() function as completed"); + + std::vector results; + results.reserve(all_requests.size()); + + m_perf_metrics.draft_model_metrics.raw_metrics = m_draft_pipeline->raw_perf_metrics; + generate_timer.end(); + + for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { + const auto& request = all_requests[request_id]; + auto sampling_params = request->get_sampling_parameters(); + const auto& sequences = request->get_finished_sequences(); + size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size()); + + EncodedGenerationResult result; + result.m_request_id = request_id; + result.m_generation_ids.resize(num_outputs); + result.m_scores.resize(num_outputs); + result.m_status = request->get_generation_stream()->get_status(); + + for (size_t i = 0; i < num_outputs; ++i) { + const auto& sequence = sequences[i]; + const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) + : sequence->get_cumulative_log_prob(); + const auto& generated_ids = sequence->get_generated_ids(); + + if (sampling_params.echo) { + result.m_generation_ids[i] = request->get_prompt_ids(); + } + std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i])); + result.m_scores[i] = score; + } + + result.m_status = main_generations[request_id]->get_status(); + + // The same perf metrics for each sequence, only tokenization/detokenization will differ. 
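+        // Re-record the overall generate duration and this request's prompt length, then
+        // re-evaluate the aggregated statistics before copying them into the per-request result.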
+ m_perf_metrics.raw_metrics.generate_durations.clear(); + m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + m_perf_metrics.num_input_tokens = request->get_prompt_len(); + m_perf_metrics.evaluate_statistics(generate_timer.get_start_time()); + + result.perf_metrics = m_perf_metrics; + result.extended_perf_metrics = std::make_shared(m_perf_metrics); + results.push_back(std::move(result)); + } + + OPENVINO_ASSERT(results.size() == input_ids.size()); + return results; +} } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 026d592569..fcb18fb835 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -8,6 +8,18 @@ #include "speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp" #include "speculative_decoding/speculative_decoding_metrics.hpp" #include "openvino/genai/speculative_decoding/perf_metrics.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/result.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/manager.hpp" namespace ov::genai { @@ -51,6 +63,7 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat std::vector get_awaiting_requests(); public: + SpeculativeDecodingImpl() = default; SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); GenerationHandle add_request(uint64_t request_id, @@ -74,4 +87,86 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat SpeculativeDecodingMetrics get_speculative_decoding_metrics(); }; +class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingPipeline::SpeculativeDecodingImpl { +public: + EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); + + std::vector + generate(const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids = std::nullopt) override; + + void fill_hidden_states(const ov::Tensor& hidden_states) { + hiddenstates_tensor = hidden_states; + } + void set_d2t_for_draft_decoding(std::shared_ptr& d2t_tensor) { + auto eagle_impl = std::dynamic_pointer_cast(m_draft_pipeline); + eagle_impl->set_d2t_for_draft_decoding(d2t_tensor); + }; +protected: + //std::shared_ptr m_main_pipeline, m_draft_pipeline; + ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); + ov::Tensor hiddenstates_tensor; +}; + +using NodePtr = std::shared_ptr; +using namespace ov::op; + +class EagleBaseTransform : public ov::pass::MatcherPass { +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("EagleBaseTransform"); + EagleBaseTransform(std::vector>& params, std::vector>& results); + + ~EagleBaseTransform() = default; + +private: + bool apply(NodePtr node, std::vector>& params, std::vector>& results); + size_t applied = 0; + std::string m_eagle_version; + std::shared_ptr find_last_hidden_node(const std::shared_ptr& start_node); + std::shared_ptr find_last_hidden_node(const 
std::shared_ptr& start_node, + std::set& visited_nodes); + std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node); + std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node, + std::set& visited_nodes); +}; +class EagleInputTransform : public ov::pass::MatcherPass { // eagle3 specific for draft model +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("EagleInputTransform"); + EagleInputTransform(std::vector>& params); + + ~EagleInputTransform() = default; + +private: + bool apply(NodePtr node, std::vector>& params); + size_t applied = 0; +}; +class Eagle3Transform : public ov::pass::MatcherPass { +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("Eagle3Transform"); + Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs); + + ~Eagle3Transform() = default; + +private: + bool apply(NodePtr node, std::vector>& hidden_state_outputs); + size_t applied = 0; + std::vector m_layers; // layers to be abstracted +}; + +class EagleModelTransform : public ov::pass::ModelPass { +public: + EagleModelTransform(const std::vector& layer_ids); + bool run_on_model(const std::shared_ptr& model) override; + +private: + const std::vector m_layer_ids; + std::vector> m_new_results; + std::vector> m_new_parameters; + std::vector> m_hidden_layer_outputs; +}; } diff --git a/src/cpp/src/speculative_decoding/update_request_structs.hpp b/src/cpp/src/speculative_decoding/update_request_structs.hpp index 68f79268f5..58cc405078 100644 --- a/src/cpp/src/speculative_decoding/update_request_structs.hpp +++ b/src/cpp/src/speculative_decoding/update_request_structs.hpp @@ -10,11 +10,13 @@ namespace ov::genai { struct GeneratedSequence { std::vector token_ids; std::vector log_probs; - + ov::Tensor hidden_states; // reserved for eagle speculative GeneratedSequence(const std::vector& generated_token_ids, - const std::vector& generated_log_probs) : + const std::vector& generated_log_probs, + const ov::Tensor& generated_hidden_states) : token_ids(generated_token_ids), - log_probs(generated_log_probs) {}; + log_probs(generated_log_probs), + hidden_states(generated_hidden_states) {}; }; struct UpdateRequestResult { diff --git a/tools/continuous_batching/accuracy/CMakeLists.txt b/tools/continuous_batching/accuracy/CMakeLists.txt index 8223452b5c..3d1bb525c4 100644 --- a/tools/continuous_batching/accuracy/CMakeLists.txt +++ b/tools/continuous_batching/accuracy/CMakeLists.txt @@ -33,6 +33,10 @@ set(TARGET_NAME_CB continuous_batching_speculative_decoding) add_executable(${TARGET_NAME_CB} ${TARGET_NAME_CB}.cpp) target_link_libraries(${TARGET_NAME_CB} PRIVATE openvino::genai cxxopts::cxxopts) +set(TARGET_NAME_CB continuous_batching_eagle_decoding) +add_executable(${TARGET_NAME_CB} ${TARGET_NAME_CB}.cpp) +target_link_libraries(${TARGET_NAME_CB} PRIVATE openvino::genai cxxopts::cxxopts) + set_target_properties(${TARGET_NAME} ${TARGET_NAME_CB} PROPERTIES # Ensure out of box LC_RPATH on macOS with SIP INSTALL_RPATH_USE_LINK_PATH ON) diff --git a/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp b/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp new file mode 100644 index 0000000000..aa4a666c6c --- /dev/null +++ b/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp @@ -0,0 +1,150 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "openvino/genai/continuous_batching_pipeline.hpp" + +void 
print_cb_generation_result(const ov::genai::GenerationResult& generation_result) { + for (size_t output_id = 0; output_id < generation_result.m_generation_ids.size(); ++output_id) { + std::cout << "Answer " << output_id << " (" << generation_result.m_scores[output_id] << ") : " << generation_result.m_generation_ids[output_id] << std::endl; + } +} + +std::vector get_spec_decoding_generation_config_examples() { + + // sampling param for speulative decoding + ov::genai::GenerationConfig generation_config_greedy_constant = ov::genai::greedy(); + { + generation_config_greedy_constant.num_assistant_tokens = 5; + } + + ov::genai::GenerationConfig generation_config_multinomial_constant =ov::genai::greedy(); + { + generation_config_multinomial_constant.num_return_sequences = 1; + generation_config_multinomial_constant.num_assistant_tokens = 5; + } + + ov::genai::GenerationConfig generation_config_greedy_dynamic = ov::genai::greedy(); + { + generation_config_greedy_dynamic.num_assistant_tokens = 4; + } + + ov::genai::GenerationConfig generation_config_multinomial_dynamic = ov::genai::greedy(); + { + generation_config_multinomial_dynamic.num_return_sequences = 1; + generation_config_multinomial_dynamic.num_assistant_tokens = 4; + } + + return { + generation_config_greedy_constant, + generation_config_multinomial_constant, + generation_config_greedy_dynamic, + generation_config_multinomial_dynamic, + }; +} + +int main(int argc, char* argv[]) try { + // Command line options + + cxxopts::Options options("accuracy_sample", "Help command"); + + options.add_options() + ("n,num_prompts", "A number of prompts", cxxopts::value()->default_value("1")) + ("m,model", "Path to model and tokenizers base directory", cxxopts::value()->default_value(".")) + ("a,draft_model", "Path to assisting model base directory", cxxopts::value()->default_value(".")) + ("d,device", "Target device to run the model", cxxopts::value()->default_value("GPU")) + ("h,help", "Print usage"); + + cxxopts::ParseResult result; + try { + result = options.parse(argc, argv); + } catch (const cxxopts::exceptions::exception& e) { + std::cout << e.what() << "\n\n"; + std::cout << options.help() << std::endl; + return EXIT_FAILURE; + } + + if (result.count("help")) { + std::cout << options.help() << std::endl; + return EXIT_SUCCESS; + } + + const size_t num_prompts = result["num_prompts"].as(); + const std::string models_path = result["model"].as(); + const std::string draft_models_path = result["draft_model"].as(); + const std::string device = result["device"].as(); + + std::vector prompt_examples = { + "What is OpenVINO?", + "How are you?", + "What is your name?", + "Tell me something about Canada", + "What is OpenVINO?", + }; + + auto generation_config = get_spec_decoding_generation_config_examples(); + auto default_config_size = generation_config.size(); + std::vector cb_generation_config; + for (size_t i = 0; i < num_prompts; ++i) { + cb_generation_config.push_back(generation_config[i % default_config_size]); + } + + std::vector prompts(num_prompts); + for (size_t i = 0; i < num_prompts; ++i) { + prompts[i] = prompt_examples[i % prompt_examples.size()]; + } + + ov::genai::SchedulerConfig scheduler_config; + // batch size + scheduler_config.max_num_batched_tokens = 64; + // cache params + scheduler_config.num_kv_blocks = 364; + // mode - vLLM or dynamic_split_fuse + scheduler_config.dynamic_split_fuse = false; // does not support true in eagle speculative decoding + // vLLM specific params + scheduler_config.max_num_seqs = 3; + + 
ov::genai::ContinuousBatchingPipeline pipe(models_path, scheduler_config, device, {ov::genai::draft_model(draft_models_path, device), std::pair("eagle_mode", ov::Any("EAGLE3"))});
+    std::vector generation_results = pipe.generate(prompts, cb_generation_config);
+
+    for (size_t request_id = 0; request_id < generation_results.size(); ++request_id) {
+        const ov::genai::GenerationResult & generation_result = generation_results[request_id];
+        std::cout << "Question: " << prompts[request_id] << std::endl;
+        switch (generation_result.m_status)
+        {
+        case ov::genai::GenerationStatus::FINISHED:
+            print_cb_generation_result(generation_result);
+            break;
+        case ov::genai::GenerationStatus::IGNORED:
+            std::cout << "Request was ignored due to lack of memory." << std::endl;
+            if (generation_result.m_generation_ids.size() > 0) {
+                std::cout << "Partial result:" << std::endl;
+                print_cb_generation_result(generation_result);
+            }
+            break;
+        case ov::genai::GenerationStatus::STOP:
+        case ov::genai::GenerationStatus::CANCEL:
+            std::cout << "Request was aborted." << std::endl;
+            if (generation_result.m_generation_ids.size() > 0) {
+                std::cout << "Partial result:" << std::endl;
+                print_cb_generation_result(generation_result);
+            }
+            break;
+        default:
+            break;
+        }
+        std::cout << std::endl;
+    }
+} catch (const std::exception& error) {
+    try {
+        std::cerr << error.what() << '\n';
+    } catch (const std::ios_base::failure&) {}
+    return EXIT_FAILURE;
+} catch (...) {
+    try {
+        std::cerr << "Non-exception object thrown\n";
+    } catch (const std::ios_base::failure&) {}
+    return EXIT_FAILURE;
+}
diff --git a/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp b/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp
index 0b0797ddf9..3a36dcb2c6 100644
--- a/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp
+++ b/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp
@@ -20,10 +20,11 @@ std::vector get_spec_decoding_generation_config_exa
         generation_config_greedy_constant.num_assistant_tokens = 5;
     }
 
-    ov::genai::GenerationConfig generation_config_multinomial_constant = ov::genai::multinomial();
+    ov::genai::GenerationConfig generation_config_multinomial_constant =ov::genai::greedy();
     {
-        generation_config_multinomial_constant.num_assistant_tokens = 5;
+        //generation_config_multinomial_constant.num_assistant_tokens = 5;
         generation_config_multinomial_constant.num_return_sequences = 1;
+        generation_config_multinomial_constant.num_assistant_tokens = 5;
     }
 
     ov::genai::GenerationConfig generation_config_greedy_dynamic = ov::genai::greedy();
@@ -31,9 +32,11 @@ std::vector get_spec_decoding_generation_config_exa
         generation_config_greedy_dynamic.assistant_confidence_threshold = 0.8f;
     }
 
-    ov::genai::GenerationConfig generation_config_multinomial_dynamic = ov::genai::multinomial();
+    ov::genai::GenerationConfig generation_config_multinomial_dynamic = ov::genai::greedy();
     {
-        generation_config_multinomial_dynamic.assistant_confidence_threshold = 0.8f;
+        //generation_config_multinomial_dynamic.assistant_confidence_threshold = 0.8f;
+        generation_config_multinomial_dynamic.num_return_sequences = 1;
+        generation_config_multinomial_dynamic.num_assistant_tokens = 5;
     }
 
     return {

From 1d74c03087943b99d391fcb612f23490e61869b1 Mon Sep 17 00:00:00 2001
From: fishbell 
Date: Thu, 18 Sep 2025 01:11:56 +0800
Subject: [PATCH 02/43] enable cb benchmark for eagle3

Signed-off-by: fishbell 

---
 .../speculative_decoding_impl.cpp             | 46 +++++++++++++++++--
 .../speculative_decoding_impl.hpp             | 10 +++-
 .../continuous_batching_benchmark.cpp         |  9 +++-
 3 files changed, 59 insertions(+), 6 
deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 13e08e5d42..e8352ef9dd 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -746,6 +746,7 @@ ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai false); m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); m_perf_metrics.raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; + m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; } ov::Tensor ContinuousBatchingPipeline::EagleDecodingImpl::create_draft_input_ids(const ov::Tensor& original_input_ids) { @@ -767,6 +768,46 @@ ov::Tensor ContinuousBatchingPipeline::EagleDecodingImpl::create_draft_input_ids return draft_input_ids; } +void ContinuousBatchingPipeline::EagleDecodingImpl::update_eagle_pipeline_params() { + auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); + auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); + m_main_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_import_needed(true); + m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); +} + +GenerationHandle +ContinuousBatchingPipeline::EagleDecodingImpl::add_request(uint64_t request_id, + const ov::Tensor& input_ids, + ov::genai::GenerationConfig sampling_params, + std::optional token_type_ids) { + std::lock_guard lock(m_draft_generations_mutex); + auto draft_sampling_params = sampling_params; + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + update_eagle_pipeline_params(); + // remove first token from input_ids to create draft_input_ids + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params, token_type_ids)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params, token_type_ids); +} + +GenerationHandle +ContinuousBatchingPipeline::EagleDecodingImpl::add_request(uint64_t request_id, + const std::string& prompt, + ov::genai::GenerationConfig sampling_params) { + std::lock_guard lock(m_draft_generations_mutex); + auto draft_sampling_params = sampling_params; + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + update_eagle_pipeline_params(); + // remove first token from input_ids to create draft_input_ids + // to be fixed + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, prompt, draft_sampling_params)}); + return m_main_pipeline->add_request(request_id, prompt, sampling_params); +} + std::vector ContinuousBatchingPipeline::EagleDecodingImpl::generate( const std::vector& input_ids, const std::vector& sampling_params, @@ -793,11 +834,8 @@ std::vector ContinuousBatchingPipeline::EagleDecodingIm auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); m_main_eagle_pipeline->set_adapters(sampling_params[0].adapters); - m_main_eagle_pipeline->set_hidden_state_export_needed(true); m_draft_eagle_pipeline->set_adapters(sampling_params[0].adapters); - m_draft_eagle_pipeline->set_hidden_state_export_needed(true); - m_draft_eagle_pipeline->set_hidden_state_import_needed(true); - 
m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); + update_eagle_pipeline_params(); const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index fcb18fb835..571eb0d9b0 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -97,6 +97,14 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP const StreamerVariant& streamer, std::optional> token_type_ids = std::nullopt) override; + GenerationHandle add_request(uint64_t request_id, + const ov::Tensor& input_ids, + ov::genai::GenerationConfig sampling_params, + std::optional token_type_ids = std::nullopt) override; + + GenerationHandle add_request(uint64_t request_id, + const std::string& prompt, + ov::genai::GenerationConfig sampling_params) override; void fill_hidden_states(const ov::Tensor& hidden_states) { hiddenstates_tensor = hidden_states; } @@ -105,7 +113,7 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP eagle_impl->set_d2t_for_draft_decoding(d2t_tensor); }; protected: - //std::shared_ptr m_main_pipeline, m_draft_pipeline; + void update_eagle_pipeline_params(); ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); ov::Tensor hiddenstates_tensor; }; diff --git a/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp b/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp index afb6dd4265..6810709486 100644 --- a/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp +++ b/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp @@ -437,6 +437,7 @@ int main(int argc, char* argv[]) try { ("dynamic_split_fuse", "Whether to use dynamic split-fuse or vLLM scheduling", cxxopts::value()->default_value("true")) ("m,model", "Path to model and tokenizers base directory", cxxopts::value()->default_value(".")) ("draft_model", "Path to assistant model directory", cxxopts::value()->default_value("")) + ("eagle3_mode", "Whether to enable eagle3 mode for speculative decoding", cxxopts::value()->default_value("false")) ("dataset", "Path to dataset .json file", cxxopts::value()->default_value("./ShareGPT_V3_unfiltered_cleaned_split.json")) ("max_input_len", "Max input length take from dataset", cxxopts::value()->default_value("1024")) ("max_output_len", "Max output length", cxxopts::value()->default_value("2048")) @@ -476,7 +477,7 @@ int main(int argc, char* argv[]) try { const bool use_cache_eviction = result["use_cache_eviction"].as(); bool is_speculative_decoding_enabled = !draft_model_path.empty(); - + bool is_eagle3_mode = result["eagle3_mode"].as(); // Create requests for generation Dataset dataset = filtered_dataset(models_path, dataset_path, num_prompts, max_input_len, max_output_len); @@ -507,6 +508,12 @@ int main(int argc, char* argv[]) try { ov::AnyMap device_config_map = {}; if (is_speculative_decoding_enabled) { device_config_map.insert({ ov::genai::draft_model(draft_model_path) }); + if (is_eagle3_mode) { + // disable dynamic split fuse in eagle3 mode + scheduler_config.dynamic_split_fuse = false; + std::cout << "disable dynamic split fuse for eagle3 speculative decoding" << std::endl; + device_config_map.insert({ ov::genai::eagle3_mode(true) }); + } } if (!parse_plugin_config_string(device_config, device_config_map)) { std::cout << "ERROR: Wrong 
json parameter in device_config." << std::endl; From 0b09a65b2ccd6654b1f860eff40c27a9aa4e9429 Mon Sep 17 00:00:00 2001 From: fishbell Date: Thu, 18 Sep 2025 18:02:06 +0800 Subject: [PATCH 03/43] add benchmarking, apply copilot review comments Signed-off-by: fishbell --- .../src/continuous_batching/model_runner.hpp | 4 ++-- src/cpp/src/continuous_batching/pipeline.cpp | 8 ++++---- src/cpp/src/sequence_group.cpp | 3 --- ..._batching_for_speculative_decoding_impl.cpp | 6 ------ .../speculative_decoding_impl.cpp | 18 +++++++++++++++--- .../accuracy/CMakeLists.txt | 6 +++--- tools/llm_bench/benchmark.py | 2 ++ tools/llm_bench/llm_bench_utils/model_utils.py | 1 + tools/llm_bench/llm_bench_utils/ov_utils.py | 8 ++++++++ tools/llm_bench/task/text_generation.py | 12 ++++++++++-- 10 files changed, 45 insertions(+), 23 deletions(-) diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index a417b598e9..a6a11b3df7 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -486,7 +486,7 @@ class ModelRunner { try { m_request.set_tensor("target_hidden_state_input", hidden_state_input); auto shape = hidden_state_input.get_shape(); - shape[-1] = shape [-1]/3; + shape[shape.size() - 1] = shape [shape.size() - 1]/3; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); @@ -497,7 +497,7 @@ class ModelRunner { try { m_request.set_tensor("internal_hidden_state_input", hidden_state_input); auto shape = hidden_state_input.get_shape(); - shape[-1] = shape [-1] * 3; + shape[shape.size() - 1] = shape [shape.size() - 1] * 3; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index a73cceabc6..76cd3e4238 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -31,7 +31,7 @@ extract_draft_model_from_config(ov::AnyMap& config) { } bool -extact_eagle_mode_from_config(ov::AnyMap& config) { +extract_eagle_mode_from_config(ov::AnyMap& config) { bool eagle_mode = false; if (config.find(ov::genai::eagle3_mode.name()) != config.end()) { eagle_mode = config.at(ov::genai::eagle3_mode.name()).as(); @@ -66,7 +66,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_mode = extact_eagle_mode_from_config(properties_without_draft_model); + auto eagle_mode = extract_eagle_mode_from_config(properties_without_draft_model); auto model = utils::read_model(models_path, properties); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); @@ -119,7 +119,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_mode 
= extact_eagle_mode_from_config(properties_without_draft_model); + auto eagle_mode = extract_eagle_mode_from_config(properties_without_draft_model); auto model = utils::read_model(models_path, properties_without_draft_model); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); @@ -174,7 +174,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_mode = extact_eagle_mode_from_config(properties_without_draft_model); + auto eagle_mode = extract_eagle_mode_from_config(properties_without_draft_model); auto model = utils::singleton_core().read_model(model_str, weights_tensor); auto rt_info = model->get_rt_info(); diff --git a/src/cpp/src/sequence_group.cpp b/src/cpp/src/sequence_group.cpp index 3a3b27eec7..94dc9160d4 100644 --- a/src/cpp/src/sequence_group.cpp +++ b/src/cpp/src/sequence_group.cpp @@ -28,9 +28,6 @@ size_t Sequence::_make_hash(size_t content_length) { // get tokens corresponding to current block if (sequence_group->get_sequence_group_type() == SequenceGroupType::TOKENS) { const auto prompt_ids = sequence_group->get_prompt_ids(); - if (content_length > prompt_ids.size() + m_generated_ids.size()) { - std::cout << "break" << std::endl; - } OPENVINO_ASSERT(content_length <= prompt_ids.size() + m_generated_ids.size()); if (block_start_idx < prompt_ids.size()) { content.insert(content.end(), prompt_ids.begin() + block_start_idx, prompt_ids.begin() + std::min(prompt_ids.size(), content_length)); diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index b3c07e4c38..411814630c 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -296,12 +296,6 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update pruned_hidden_state); validate_length = pruned_hidden_state.get_shape().size() > 0 ? 
pruned_hidden_state.get_shape()[0] : 0; } - if (!m_is_validation_mode_enabled && result.inserted_tokens_cnt > 0) { - std::cout << "main update draft: request id: " << request_id << "removed tokens: " << result.removed_tokens_cnt << ", inserted tokens: " << result.inserted_tokens_cnt << std::endl; - } - if (m_is_validation_mode_enabled && result.inserted_tokens_cnt > 0) { - std::cout << "draft update main: request id: " << request_id << "removed tokens: " << result.removed_tokens_cnt << ", inserted tokens: " << result.inserted_tokens_cnt << std::endl; - } } // we should update a logit processor just for draft model to generate the same tokens // logit processors of main model will be updated in sampler while validation mode diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index e8352ef9dd..807cf8f5c0 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -803,9 +803,21 @@ ContinuousBatchingPipeline::EagleDecodingImpl::add_request(uint64_t request_id, draft_sampling_params.stop_strings = {}; update_eagle_pipeline_params(); // remove first token from input_ids to create draft_input_ids - // to be fixed - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, prompt, draft_sampling_params)}); - return m_main_pipeline->add_request(request_id, prompt, sampling_params); + if (m_model_input_type == ModelInputType::TOKENS) { + static ManualTimer timer("tokenize"); + timer.start(); + ChatHistory history({{{"role", "user"}, {"content", prompt}}}); + auto templated_prompt = m_tokenizer.apply_chat_template(history, true); + auto input_ids = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids; + timer.end(); + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params); + } else { + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, prompt, draft_sampling_params)}); + return m_main_pipeline->add_request(request_id, prompt, sampling_params); + } + } std::vector ContinuousBatchingPipeline::EagleDecodingImpl::generate( diff --git a/tools/continuous_batching/accuracy/CMakeLists.txt b/tools/continuous_batching/accuracy/CMakeLists.txt index 3d1bb525c4..b3a29697d3 100644 --- a/tools/continuous_batching/accuracy/CMakeLists.txt +++ b/tools/continuous_batching/accuracy/CMakeLists.txt @@ -33,9 +33,9 @@ set(TARGET_NAME_CB continuous_batching_speculative_decoding) add_executable(${TARGET_NAME_CB} ${TARGET_NAME_CB}.cpp) target_link_libraries(${TARGET_NAME_CB} PRIVATE openvino::genai cxxopts::cxxopts) -set(TARGET_NAME_CB continuous_batching_eagle_decoding) -add_executable(${TARGET_NAME_CB} ${TARGET_NAME_CB}.cpp) -target_link_libraries(${TARGET_NAME_CB} PRIVATE openvino::genai cxxopts::cxxopts) +set(TARGET_NAME_CB_EAGLE continuous_batching_eagle_decoding) +add_executable(${TARGET_NAME_CB_EAGLE} ${TARGET_NAME_CB_EAGLE}.cpp) +target_link_libraries(${TARGET_NAME_CB_EAGLE} PRIVATE openvino::genai cxxopts::cxxopts) set_target_properties(${TARGET_NAME} ${TARGET_NAME_CB} PROPERTIES # Ensure out of box LC_RPATH on macOS with SIP diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py index 36ab082558..2eb3c39c9f 100644 --- 
a/tools/llm_bench/benchmark.py +++ b/tools/llm_bench/benchmark.py @@ -166,6 +166,8 @@ def get_argprser(): help="Path to file with Continuous Batching Scheduler settings or dict for Speculative decoding of draft model") parser.add_argument("--num_assistant_tokens", required=False, default=None, help="Config option num_assistant_tokens for Speculative decoding and Prompt Lookup decoding", type=int) + parser.add_argument("--eagle3_mode", action="store_true", + help="flag to indicate whether to use eagle3 for speculative decoding") parser.add_argument("--assistant_confidence_threshold", required=False, default=None, help="Config option assistant_confidence_threshold for Speculative decoding", type=float) parser.add_argument("--max_ngram_size", required=False, default=None, diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 6658b7d8fc..490d68e9d7 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -212,6 +212,7 @@ def analyze_args(args): draft_cb_config = get_config(args.draft_cb_config) model_args["draft_cb_config"] = draft_cb_config model_args['num_assistant_tokens'] = args.num_assistant_tokens + model_args['eagle3_mode'] = args.eagle3_mode model_args['assistant_confidence_threshold'] = args.assistant_confidence_threshold model_args['max_ngram_size'] = args.max_ngram_size diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py index 9fdb82567e..8ec13fa6eb 100644 --- a/tools/llm_bench/llm_bench_utils/ov_utils.py +++ b/tools/llm_bench/llm_bench_utils/ov_utils.py @@ -240,6 +240,14 @@ def create_genai_text_gen_model(model_path, device, ov_config, memory_monitor, * draft_model_load_kwargs = {'scheduler_config': get_scheduler_config_genai(kwargs.get("draft_cb_config"), config_name="draft CB config")}\ if kwargs.get("draft_cb_config") is not None else {} config['draft_model'] = openvino_genai.draft_model(draft_model_path, draft_device.upper(), **draft_model_load_kwargs) + if (kwargs.get('eagle3_mode', None)): + config['eagle3_mode'] = True + if 'scheduler_config' in config.keys(): + config['scheduler_config'].dynamic_split_fuse = False # Eagle speculative decoding does not support dynamic_split_fuse mode + else: + config['scheduler_config'] = openvino_genai.SchedulerConfig() + config['scheduler_config'].dynamic_split_fuse = False + log.info("Eagle3 Speculative Decoding is activated, and dynamic_split_fuse is set to False in scheduler_config") if kwargs.get('max_ngram_size') and kwargs.get('num_assistant_tokens'): log.info("Prompt Lookup decoding is activated") diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index 611464b586..402a654f53 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -284,7 +284,11 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data ) tokenization_start = time.perf_counter() - input_data = tokenizer.encode(input_text_list) + if (args.get("eagle3_mode")): + #eagle3 needs to disable special tokens to ensure compress rate + input_data = tokenizer.encode(input_text_list, add_special_tokens = False) + else: + input_data = tokenizer(input_text_list) tokenization_end = time.perf_counter() tokenization_time = [(tokenization_end - tokenization_start) * 1000] @@ -451,7 +455,11 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg "If it is not expected, please 
specify --disable_prompt_permutation in your benchmarking command to disable this behavior" ) tok_encode_start = time.perf_counter() - input_data = pipe_tokenizer.encode(input_text_list) + if (args.get("eagle3_mode")): + #eagle3 needs to disable special tokens to ensure compress rate + input_data = pipe_tokenizer.encode(input_text_list, add_special_tokens = False) + else: + input_data = pipe_tokenizer.encode(input_text_list) tok_encode_end = time.perf_counter() input_token_size = input_data.input_ids.shape[1] tok_encode_time = (tok_encode_end - tok_encode_start) * 1000 From b36ecf70a7002c33fcb2bee8126da25a492b9747 Mon Sep 17 00:00:00 2001 From: fishbell Date: Thu, 18 Sep 2025 18:05:23 +0800 Subject: [PATCH 04/43] fix case build failure Signed-off-by: fishbell --- src/cpp/src/speculative_decoding/update_request_structs.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/src/speculative_decoding/update_request_structs.hpp b/src/cpp/src/speculative_decoding/update_request_structs.hpp index 58cc405078..ddd257adbb 100644 --- a/src/cpp/src/speculative_decoding/update_request_structs.hpp +++ b/src/cpp/src/speculative_decoding/update_request_structs.hpp @@ -13,7 +13,7 @@ struct GeneratedSequence { ov::Tensor hidden_states; // reserved for eagle speculative GeneratedSequence(const std::vector& generated_token_ids, const std::vector& generated_log_probs, - const ov::Tensor& generated_hidden_states) : + const ov::Tensor generated_hidden_states = {}) : token_ids(generated_token_ids), log_probs(generated_log_probs), hidden_states(generated_hidden_states) {}; From d86e5a7eab394a09749b9abf3338b816e4406f22 Mon Sep 17 00:00:00 2001 From: fishbell Date: Thu, 18 Sep 2025 18:11:04 +0800 Subject: [PATCH 05/43] fix SDL Signed-off-by: fishbell --- tools/llm_bench/task/text_generation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index 402a654f53..36c79ac4ff 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -285,8 +285,8 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data tokenization_start = time.perf_counter() if (args.get("eagle3_mode")): - #eagle3 needs to disable special tokens to ensure compress rate - input_data = tokenizer.encode(input_text_list, add_special_tokens = False) + # eagle3 needs to disable special tokens to ensure compress rate + input_data = tokenizer.encode(input_text_list, add_special_tokens=False) else: input_data = tokenizer(input_text_list) tokenization_end = time.perf_counter() @@ -456,8 +456,8 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg ) tok_encode_start = time.perf_counter() if (args.get("eagle3_mode")): - #eagle3 needs to disable special tokens to ensure compress rate - input_data = pipe_tokenizer.encode(input_text_list, add_special_tokens = False) + # eagle3 needs to disable special tokens to ensure compress rate + input_data = pipe_tokenizer.encode(input_text_list, add_special_tokens=False) else: input_data = pipe_tokenizer.encode(input_text_list) tok_encode_end = time.perf_counter() From f0aa2c73f0e9efae9227868cc16a82eae6eec39d Mon Sep 17 00:00:00 2001 From: fishbell Date: Thu, 18 Sep 2025 19:14:29 +0800 Subject: [PATCH 06/43] typo Signed-off-by: fishbell --- tools/llm_bench/task/text_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/llm_bench/task/text_generation.py 
b/tools/llm_bench/task/text_generation.py index 36c79ac4ff..d07bdfaf61 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -288,7 +288,7 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data # eagle3 needs to disable special tokens to ensure compress rate input_data = tokenizer.encode(input_text_list, add_special_tokens=False) else: - input_data = tokenizer(input_text_list) + input_data = tokenizer.encode(input_text_list) tokenization_end = time.perf_counter() tokenization_time = [(tokenization_end - tokenization_start) * 1000] From bbfa8ad9e77ea5131eae06682d39a757e92f18c3 Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 19 Sep 2025 00:47:28 +0800 Subject: [PATCH 07/43] opt hidden state transfer with ROI tensor Signed-off-by: fishbell --- .../src/continuous_batching/model_runner.hpp | 50 ++++++++++++------- .../continuous_batching_eagle_decoding.cpp | 6 +-- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index a6a11b3df7..8db3ab0adb 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -342,24 +342,39 @@ class ModelRunner { if (stored_hidden_size == hidden_size) { if (stored_seq_len == total_num_tokens) { - hidden_state_input = stored_hidden_state; // all tokens from eagle is accepted + hidden_state_input = stored_hidden_state; // all tokens from eagle are accepted } else { size_t copy_length = std::min(stored_seq_len, num_scheduled_tokens); size_t source_start_idx = stored_seq_len >= copy_length ? stored_seq_len - copy_length : 0; - - const float* source_data = stored_hidden_state.data(); - float* target_data = hidden_state_data + current_token_idx * hidden_size; - - for (size_t token_offset = 0; token_offset < copy_length; ++token_offset) { - size_t source_offset = (source_start_idx + token_offset) * hidden_size; - size_t target_offset = token_offset * hidden_size; - - std::copy_n(source_data + source_offset, - hidden_size, - target_data + target_offset); + // Create ROI (sub-tensor) on stored_hidden_state + auto stored_shape = stored_hidden_state.get_shape(); + ov::Coordinate src_start(stored_shape.size(), 0), + src_end(stored_shape.size(), 0); + src_start[0] = source_start_idx; + src_end[0] = source_start_idx + copy_length; + for (size_t d = 1; d < stored_shape.size(); ++d) { + src_start[d] = 0; + src_end[d] = stored_shape[d]; } + ov::Tensor src_roi(stored_hidden_state, src_start, src_end); + + // Create ROI on the destination hidden_state_input at current_token_idx + auto target_shape = + hidden_state_input.get_shape(); // {total_num_tokens, 1, hidden_size} + ov::Coordinate tgt_start(target_shape.size(), 0), + tgt_end(target_shape.size(), 0); + tgt_start[0] = current_token_idx; + tgt_end[0] = current_token_idx + copy_length; + for (size_t d = 1; d < target_shape.size(); ++d) { + tgt_start[d] = 0; + tgt_end[d] = target_shape[d]; + } + ov::Tensor tgt_roi(hidden_state_input, tgt_start, tgt_end); + + // Bulk copy ROI -> ROI + src_roi.copy_to(tgt_roi); } } } @@ -368,6 +383,7 @@ class ModelRunner { } else { // fill hidden_state_data with m_hidden_states if (hidden_state_data) { + OPENVINO_ASSERT(num_scheduled_tokens == 1, "unexpected num_scheduled_tokens in speculative drafting stage in eagle3 mode"); std::memset(hidden_state_data + current_token_idx * hidden_size, 0, num_scheduled_tokens * hidden_size * sizeof(float)); @@ -376,14 
+392,10 @@ class ModelRunner { auto shape = hidden_state.get_shape(); if (shape.size() >= 2 && shape[shape.size() - 1] == hidden_size) { size_t seq_len = shape[0]; - size_t copy_length = std::min(seq_len, num_scheduled_tokens); const float* source_data = hidden_state.data(); float* target_data = hidden_state_data + current_token_idx * hidden_size; - for (size_t token_offset = 0; token_offset < copy_length; ++token_offset) { - size_t source_offset = (seq_len - token_offset - 1) * hidden_size; - size_t target_offset = token_offset * hidden_size; - std::copy_n(source_data + source_offset, hidden_size, target_data + target_offset); - } + size_t source_offset = (seq_len - 1) * hidden_size; + std::copy_n(source_data + source_offset, hidden_size, target_data); } } } @@ -486,7 +498,7 @@ class ModelRunner { try { m_request.set_tensor("target_hidden_state_input", hidden_state_input); auto shape = hidden_state_input.get_shape(); - shape[shape.size() - 1] = shape [shape.size() - 1]/3; + shape[shape.size() - 1] = shape [shape.size() - 1] / 3; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); diff --git a/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp b/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp index aa4a666c6c..ccc72c2116 100644 --- a/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp +++ b/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp @@ -98,15 +98,15 @@ int main(int argc, char* argv[]) try { ov::genai::SchedulerConfig scheduler_config; // batch size - scheduler_config.max_num_batched_tokens = 64; + scheduler_config.max_num_batched_tokens = 128; // cache params scheduler_config.num_kv_blocks = 364; // mode - vLLM or dynamic_split_fuse scheduler_config.dynamic_split_fuse = false; // does not support true in eagle speculative decoding // vLLM specific params - scheduler_config.max_num_seqs = 3; + scheduler_config.max_num_seqs = 2; - ov::genai::ContinuousBatchingPipeline pipe(models_path, scheduler_config, device, {ov::genai::draft_model(draft_models_path, device), std::pair("eagle_mode", ov::Any("EAGLE3"))}); + ov::genai::ContinuousBatchingPipeline pipe(models_path, scheduler_config, device, {ov::genai::draft_model(draft_models_path, device), ov::genai::eagle3_mode(true)}); std::vector generation_results = pipe.generate(prompts, cb_generation_config); for (size_t request_id = 0; request_id < generation_results.size(); ++request_id) { From 35e63760b01f8238ce9ef2afed94b834a6619435 Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 19 Sep 2025 17:45:35 +0800 Subject: [PATCH 08/43] opt roi copy interface Signed-off-by: fishbell --- .../src/continuous_batching/model_runner.hpp | 83 ++++++++++++------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index 8db3ab0adb..1d85bb6b62 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -348,36 +348,12 @@ class ModelRunner { size_t source_start_idx = stored_seq_len >= copy_length ? 
stored_seq_len - copy_length : 0; - // Create ROI (sub-tensor) on stored_hidden_state - auto stored_shape = stored_hidden_state.get_shape(); - ov::Coordinate src_start(stored_shape.size(), 0), - src_end(stored_shape.size(), 0); - src_start[0] = source_start_idx; - src_end[0] = source_start_idx + copy_length; - for (size_t d = 1; d < stored_shape.size(); ++d) { - src_start[d] = 0; - src_end[d] = stored_shape[d]; - } - ov::Tensor src_roi(stored_hidden_state, src_start, src_end); - - // Create ROI on the destination hidden_state_input at current_token_idx - auto target_shape = - hidden_state_input.get_shape(); // {total_num_tokens, 1, hidden_size} - ov::Coordinate tgt_start(target_shape.size(), 0), - tgt_end(target_shape.size(), 0); - tgt_start[0] = current_token_idx; - tgt_end[0] = current_token_idx + copy_length; - for (size_t d = 1; d < target_shape.size(); ++d) { - tgt_start[d] = 0; - tgt_end[d] = target_shape[d]; - } - ov::Tensor tgt_roi(hidden_state_input, tgt_start, tgt_end); - - // Bulk copy ROI -> ROI - src_roi.copy_to(tgt_roi); + copy_roi_between_tensors(stored_hidden_state, source_start_idx, copy_length, hidden_state_input, current_token_idx); } } } + } else { + OPENVINO_ASSERT(false, "missing hidden state from target model to eagle draft model"); } } } else { @@ -392,10 +368,12 @@ class ModelRunner { auto shape = hidden_state.get_shape(); if (shape.size() >= 2 && shape[shape.size() - 1] == hidden_size) { size_t seq_len = shape[0]; - const float* source_data = hidden_state.data(); - float* target_data = hidden_state_data + current_token_idx * hidden_size; - size_t source_offset = (seq_len - 1) * hidden_size; - std::copy_n(source_data + source_offset, hidden_size, target_data); + size_t copy_length = std::min(seq_len, num_scheduled_tokens); + + size_t src_start_idx = seq_len >= copy_length ? seq_len - copy_length : 0; + auto target_shape = ov::Shape{num_scheduled_tokens, 1, hidden_size}; + ov::Tensor target_base(ov::element::f32, target_shape, hidden_state_data + current_token_idx * hidden_size); + copy_roi_between_tensors(hidden_state, src_start_idx, copy_length, target_base, 0); } } } @@ -640,6 +618,49 @@ class ModelRunner { private: ov::Tensor m_hidden_states; + + // Common helper to copy a contiguous slice (first-dim range) from src to dst using ROI tensors. 
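+    // Uses the ov::Tensor ROI constructor (begin/end coordinates) to wrap each slice as a view, so the
+    // transfer becomes a single src_roi.copy_to(tgt_roi) call instead of a per-token std::copy_n loop.
+    // Illustrative call, assuming a source of shape {6, 1, H} and a destination of shape {8, 1, H}
+    // (shapes and values are hypothetical, not taken from this change):
+    //   copy_roi_between_tensors(src, /*src_start_idx=*/2, /*copy_length=*/3, dst, /*dst_first_dim_start=*/5);
+    //   // copies src rows [2, 5) into dst rows [5, 8) in one bulk copy
+    // Parameters: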
+ // src_start_idx: start index along src first dimension + // copy_length: number of elements along first dim to copy + // dst_base: destination base tensor (may be full buffer or a wrapper around a raw pointer) + // dst_first_dim_start: start index in first dimension of dst_base where copy should be placed + static void copy_roi_between_tensors(const ov::Tensor& src, + size_t src_start_idx, + size_t copy_length, + const ov::Tensor& dst_base, + size_t dst_first_dim_start) { + if (copy_length == 0) { + return; + } + + // prepare source ROI coords + const auto src_shape = src.get_shape(); + OPENVINO_ASSERT(!src_shape.empty(), "source tensor rank is zero"); + ov::Coordinate src_start(src_shape.size(), 0), src_end(src_shape.size(), 0); + src_start[0] = src_start_idx; + src_end[0] = src_start_idx + copy_length; + for (size_t d = 1; d < src_shape.size(); ++d) { + src_start[d] = 0; + src_end[d] = src_shape[d]; + } + ov::Tensor src_roi(src, src_start, src_end); + + // prepare destination ROI coords + const auto dst_shape = dst_base.get_shape(); + OPENVINO_ASSERT(!dst_shape.empty(), "destination tensor rank is zero"); + ov::Coordinate tgt_start(dst_shape.size(), 0), tgt_end(dst_shape.size(), 0); + tgt_start[0] = dst_first_dim_start; + tgt_end[0] = dst_first_dim_start + copy_length; + for (size_t d = 1; d < dst_shape.size(); ++d) { + tgt_start[d] = 0; + tgt_end[d] = dst_shape[d]; + } + ov::Tensor tgt_roi(dst_base, tgt_start, tgt_end); + + // bulk copy + src_roi.copy_to(tgt_roi); + } + // Fills indices for sequences in the order defined by scheduler_output void _fill_indices_from_block_tables( const std::vector& dst_tensor_names, From 8a480a43269ca78a51653d11970b908c7bccc61b Mon Sep 17 00:00:00 2001 From: fishbell Date: Tue, 23 Sep 2025 19:33:52 +0800 Subject: [PATCH 09/43] parse eagle info from draft model Signed-off-by: fishbell --- .../text_generation/eagle_speculative_lm.cpp | 25 +---- .../cpp/text_generation/greedy_causal_lm.cpp | 2 +- src/cpp/src/continuous_batching/pipeline.cpp | 93 +++++++++++++------ src/cpp/src/llm/pipeline.cpp | 21 +++++ .../speculative_decoding_impl.cpp | 24 +++-- .../speculative_decoding_impl.hpp | 7 +- .../continuous_batching_eagle_decoding.cpp | 2 +- .../continuous_batching_benchmark.cpp | 9 +- .../llm_bench/llm_bench_utils/model_utils.py | 1 - tools/llm_bench/llm_bench_utils/ov_utils.py | 8 -- tools/llm_bench/task/text_generation.py | 4 +- 11 files changed, 106 insertions(+), 90 deletions(-) diff --git a/samples/cpp/text_generation/eagle_speculative_lm.cpp b/samples/cpp/text_generation/eagle_speculative_lm.cpp index df74aaf59c..bbddb473d6 100644 --- a/samples/cpp/text_generation/eagle_speculative_lm.cpp +++ b/samples/cpp/text_generation/eagle_speculative_lm.cpp @@ -32,12 +32,12 @@ int main(int argc, char* argv[]) try { throw std::runtime_error(std::string{"Usage: "} + argv[0] + " ''"); } - std::string main_model_path = argv[1]; - std::string eagle_model_path = argv[2]; + std::filesystem::path main_model_path = argv[1]; + std::filesystem::path eagle_model_path = argv[2]; std::string prompt = argv[3]; // Configure devices - can run main and eagle models on different devices - std::string main_device = "GPU", eagle_device = "GPU"; // CPU can bse used as well + std::string main_device = "CPU", eagle_device = "CPU"; // GPU can be used as well // Eagle Speculative settings ov::genai::GenerationConfig config = ov::genai::greedy(); @@ -51,8 +51,7 @@ int main(int argc, char* argv[]) try { main_model_path, main_device, ov::genai::draft_model(eagle_model_path, eagle_device), - 
ov::genai::scheduler_config(scheduler_config), - ov::genai::eagle3_mode(true) + ov::genai::scheduler_config(scheduler_config) ); // Setup performance measurement auto start_time = std::chrono::high_resolution_clock::now(); @@ -85,22 +84,6 @@ int main(int argc, char* argv[]) try { print_perf_metrics(sd_perf_metrics->draft_model_metrics, "DRAFT MODEL"); } std::cout << std::endl; - - // Run without Eagle for comparison - std::cout << "\n-----------------------------" << std::endl; - std::cout << "Generating without Eagle Speculative decoding:" << std::endl; - - // Disable Eagle mode - /*config.eagle_model = false; - - start_time = std::chrono::high_resolution_clock::now(); - pipe.generate(prompt, config, streamer); - std::cout << std::endl; - */ - end_time = std::chrono::high_resolution_clock::now(); - duration = std::chrono::duration_cast(end_time - start_time); - std::cout << "\nStandard generation completed in " << duration.count() << " ms" << std::endl; - } catch (const std::exception& error) { try { std::cerr << error.what() << '\n'; diff --git a/samples/cpp/text_generation/greedy_causal_lm.cpp b/samples/cpp/text_generation/greedy_causal_lm.cpp index 341ba26d93..ca5e193da1 100644 --- a/samples/cpp/text_generation/greedy_causal_lm.cpp +++ b/samples/cpp/text_generation/greedy_causal_lm.cpp @@ -9,7 +9,7 @@ int main(int argc, char* argv[]) try { std::string models_path = argv[1]; std::string prompt = argv[2]; - std::string device = "GPU"; // GPU can be used as well + std::string device = "CPU"; // GPU can be used as well ov::genai::LLMPipeline pipe(models_path, device); ov::genai::GenerationConfig config; diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index 76cd3e4238..f8049277f2 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -29,15 +29,29 @@ extract_draft_model_from_config(ov::AnyMap& config) { } return draft_model; } +struct Eagle3RTInfo { + bool eagle3_mode = false; + std::vector hidden_layers_list; + std::filesystem::path dt_mapping_table; +}; -bool +Eagle3RTInfo extract_eagle_mode_from_config(ov::AnyMap& config) { - bool eagle_mode = false; - if (config.find(ov::genai::eagle3_mode.name()) != config.end()) { - eagle_mode = config.at(ov::genai::eagle3_mode.name()).as(); - config.erase(ov::genai::eagle3_mode.name()); + Eagle3RTInfo eagle_rt_info; + if (config.find("eagle3_mode") != config.end()) { + eagle_rt_info.eagle3_mode = config.at("eagle3_mode").as(); + config.erase("eagle3_mode"); + if (config.find("hidden_layers_list") != config.end()) { + eagle_rt_info.hidden_layers_list = config.at("hidden_layers_list").as>(); + config.erase("hidden_layers_list"); + } + if (config.find("dt_mapping_path") != config.end()) { + eagle_rt_info.dt_mapping_table = config.at("dt_mapping_path").as(); + eagle_rt_info.dt_mapping_table = eagle_rt_info.dt_mapping_table / "eagle3.safetensor"; + config.erase("dt_mapping_path"); + } } - return eagle_mode; + return eagle_rt_info; } bool @@ -66,7 +80,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_mode = extract_eagle_mode_from_config(properties_without_draft_model); + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties); auto model 
= utils::read_model(models_path, properties); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); @@ -82,18 +96,25 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); - } else if (draft_model_desr.model != nullptr && eagle_mode) { + } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); // Eagle speculative decoding does not support dynamic_split_fuse mode // because it requires hidden state interaction from main model to draft model // to be implemented future - OPENVINO_ASSERT(scheduler_config.dynamic_split_fuse == false, "Eagle speculative decoding does not support dynamic_split_fuse mode"); - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); - m_impl = std::make_shared(main_model_descr, draft_model_desr); + SchedulerConfig scheduler_config_copy = scheduler_config; + if (scheduler_config.dynamic_split_fuse) { + std::cout << "WARNING: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; + scheduler_config_copy.dynamic_split_fuse = false; + // Use scheduler_config_copy in subsequent code if modification is needed + } + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config_copy, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); // parse d2t from safe tensors - ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(models_path / "eagle3.safetensor")); - if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional - std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { + ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); + if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional + std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + } } } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); @@ -119,7 +140,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_mode = extract_eagle_mode_from_config(properties_without_draft_model); + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties); auto model = utils::read_model(models_path, properties_without_draft_model); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = 
utils::extract_gguf_properties(properties_without_draft_model); @@ -133,17 +154,22 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); - } else if (draft_model_desr.model != nullptr && eagle_mode) { + } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); // Eagle speculative decoding does not support dynamic_split_fuse mode // because it requires hidden state interaction from main model to draft model // to be implemented future - OPENVINO_ASSERT(scheduler_config.dynamic_split_fuse == false, "Eagle speculative decoding does not support dynamic_split_fuse mode"); - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); - m_impl = std::make_shared(main_model_descr, draft_model_desr); + SchedulerConfig scheduler_config_copy = scheduler_config; + if (scheduler_config.dynamic_split_fuse) { + std::cout << "WARNING: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; + scheduler_config_copy.dynamic_split_fuse = false; + // Use scheduler_config_copy in subsequent code if modification is needed + } + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config_copy, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); // parse d2t from safe tensors - if (std::filesystem::exists(models_path / "eagle3.safetensor")) { - ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(models_path / "eagle3.safetensor")); + if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { + ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); } @@ -174,7 +200,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_mode = extract_eagle_mode_from_config(properties_without_draft_model); + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties); auto model = utils::singleton_core().read_model(model_str, weights_tensor); auto rt_info = model->get_rt_info(); @@ -192,18 +218,25 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model, generation_config); - } else if (draft_model_desr.model != nullptr && 
eagle_mode) { + } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); // Eagle speculative decoding does not support dynamic_split_fuse mode // because it requires hidden state interaction from main model to draft model // to be implemented future - OPENVINO_ASSERT(scheduler_config.dynamic_split_fuse == false, "Eagle speculative decoding does not support dynamic_split_fuse mode"); - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); - m_impl = std::make_shared(main_model_descr, draft_model_desr); - if (eagle_mode) { - // parse d2t from safe tensors - //ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(models_path / "eagle3.safetensor")); - //std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + SchedulerConfig scheduler_config_copy = scheduler_config; + if (scheduler_config.dynamic_split_fuse) { + std::cout << "WARNING: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; + scheduler_config_copy.dynamic_split_fuse = false; + // Use scheduler_config_copy in subsequent code if modification is needed + } + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config_copy, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); + // parse d2t from safe tensors + if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { + ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); + if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional + std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + } } } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index 76d1fe24dc..441194364b 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -34,6 +34,25 @@ std::pair generation_config(const GenerationConfig& config) { return {utils::CONFIG_ARG_NAME, Any::make(config)}; } +inline void apply_eagle_rt_info(std::shared_ptr& model, ov::AnyMap& properties, const std::filesystem::path& mapping_path) { + if (model->has_rt_info("eagle3_mode") && model->get_rt_info("eagle3_mode")) { + properties["eagle3_mode"] = true; + if (!mapping_path.empty()) { + properties["dt_mapping_path"] = mapping_path; // d2t mapping path + } + } + + if (model->has_rt_info("hidden_layers_list")) { + properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); + } +} + +inline void apply_eagle_rt_info(std::shared_ptr& model, + ov::AnyMap& properties, + const std::string& mapping_path) { + apply_eagle_rt_info(model, properties, std::filesystem::path(mapping_path)); +} + std::pair draft_model( const std::filesystem::path& models_path, const std::string& device, @@ -42,6 +61,7 @@ std::pair draft_model( std::filesystem::path openvino_model_name = "openvino_model.xml"; auto model = utils::singleton_core().read_model(models_path / openvino_model_name, {}, plugin_config); + apply_eagle_rt_info(model, plugin_config, models_path); auto generation_config = 
utils::from_config_json_if_exists(models_path); auto tokenizer = ov::genai::Tokenizer(models_path); return { utils::DRAFT_MODEL_ARG_NAME, Any::make(model, tokenizer, device, plugin_config, scheduler_config, generation_config) }; @@ -57,6 +77,7 @@ std::pair draft_model( auto [plugin_config, scheduler_config] = utils::extract_scheduler_config(properties); auto model = utils::singleton_core().read_model(model_str, weights_tensor); + apply_eagle_rt_info(model, plugin_config, model_str); return { utils::DRAFT_MODEL_ARG_NAME, Any::make(model, tokenizer, device, plugin_config, scheduler_config, generation_config) }; } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 807cf8f5c0..e96a5739f7 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -366,23 +366,20 @@ std::vector ContinuousBatchingPipeline::SpeculativeDecodingI return main_awaiting_requests; } - void extract_hidden_state_generic(std::shared_ptr& model, - const std::string& model_type, - const std::string& custom_node_name = "") { - if (model_type == "draft") { // for draft model, we always only need to extract last hidden state - std::cout << model_type << " model - last hidden state extraction" << std::endl; + const std::vector& hidden_layers_to_abstract) { + if (hidden_layers_to_abstract.size() == 1 && + hidden_layers_to_abstract[0] == -1) { // for draft model, we always only need to extract last hidden state + std::cout << "draft model - last hidden state extraction" << std::endl; ov::pass::Manager pm; - std::vector layers = {-1}; // -1 means last hidden layer - pm.register_pass(layers); + pm.register_pass(hidden_layers_to_abstract); pm.run_passes(model); } else { - std::cout << model_type << " model - Eagle 3 hidden state extraction" << std::endl; + std::cout << "main model - Eagle 3 hidden states extraction" << std::endl; ov::pass::Manager pm; /*if idx==len(self.layers)-3 or idx==len(self.layers)//2 or idx==2: all_hidden_states += (hidden_states,)*/ - std::vector layers = {2, 16, 29}; // need to add check, only support positive values - pm.register_pass(layers); + pm.register_pass(hidden_layers_to_abstract); pm.run_passes(model); } } @@ -657,7 +654,8 @@ bool Eagle3Transform::apply(NodePtr node, std::vector>& hidden_stat } ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, - const ov::genai::ModelDesc& draft_model_desc) { + const ov::genai::ModelDesc& draft_model_desc, + const std::vector& hidden_layers) { auto main_model = main_model_desc.model; auto draft_model = draft_model_desc.model; @@ -726,8 +724,8 @@ ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai // apply transformations needed to run eagle model // target model: hidden state extraction, draft model: hidden state import , hidden state extraction // eagle3 specific : dt importing - extract_hidden_state_generic(main_model, "main", ""); - extract_hidden_state_generic(draft_model, "draft", ""); + extract_hidden_state_generic(main_model, hidden_layers); + extract_hidden_state_generic(draft_model, { -1 }); // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode m_main_pipeline = std::make_shared(main_model, diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp 
index 571eb0d9b0..06606dd331 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -89,7 +89,7 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingPipeline::SpeculativeDecodingImpl { public: - EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); + EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc, const std::vector& hidden_layers_to_abstract); std::vector generate(const std::vector& input_ids, @@ -105,9 +105,7 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP GenerationHandle add_request(uint64_t request_id, const std::string& prompt, ov::genai::GenerationConfig sampling_params) override; - void fill_hidden_states(const ov::Tensor& hidden_states) { - hiddenstates_tensor = hidden_states; - } + void set_d2t_for_draft_decoding(std::shared_ptr& d2t_tensor) { auto eagle_impl = std::dynamic_pointer_cast(m_draft_pipeline); eagle_impl->set_d2t_for_draft_decoding(d2t_tensor); @@ -115,7 +113,6 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP protected: void update_eagle_pipeline_params(); ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); - ov::Tensor hiddenstates_tensor; }; using NodePtr = std::shared_ptr; diff --git a/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp b/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp index ccc72c2116..395099cc43 100644 --- a/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp +++ b/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp @@ -106,7 +106,7 @@ int main(int argc, char* argv[]) try { // vLLM specific params scheduler_config.max_num_seqs = 2; - ov::genai::ContinuousBatchingPipeline pipe(models_path, scheduler_config, device, {ov::genai::draft_model(draft_models_path, device), ov::genai::eagle3_mode(true)}); + ov::genai::ContinuousBatchingPipeline pipe(models_path, scheduler_config, device, {ov::genai::draft_model(draft_models_path, device)}); std::vector generation_results = pipe.generate(prompts, cb_generation_config); for (size_t request_id = 0; request_id < generation_results.size(); ++request_id) { diff --git a/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp b/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp index 6810709486..afb6dd4265 100644 --- a/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp +++ b/tools/continuous_batching/benchmark/continuous_batching_benchmark.cpp @@ -437,7 +437,6 @@ int main(int argc, char* argv[]) try { ("dynamic_split_fuse", "Whether to use dynamic split-fuse or vLLM scheduling", cxxopts::value()->default_value("true")) ("m,model", "Path to model and tokenizers base directory", cxxopts::value()->default_value(".")) ("draft_model", "Path to assistant model directory", cxxopts::value()->default_value("")) - ("eagle3_mode", "Whether to enable eagle3 mode for speculative decoding", cxxopts::value()->default_value("false")) ("dataset", "Path to dataset .json file", cxxopts::value()->default_value("./ShareGPT_V3_unfiltered_cleaned_split.json")) ("max_input_len", "Max input length take from dataset", cxxopts::value()->default_value("1024")) ("max_output_len", "Max output 
length", cxxopts::value()->default_value("2048")) @@ -477,7 +476,7 @@ int main(int argc, char* argv[]) try { const bool use_cache_eviction = result["use_cache_eviction"].as(); bool is_speculative_decoding_enabled = !draft_model_path.empty(); - bool is_eagle3_mode = result["eagle3_mode"].as(); + // Create requests for generation Dataset dataset = filtered_dataset(models_path, dataset_path, num_prompts, max_input_len, max_output_len); @@ -508,12 +507,6 @@ int main(int argc, char* argv[]) try { ov::AnyMap device_config_map = {}; if (is_speculative_decoding_enabled) { device_config_map.insert({ ov::genai::draft_model(draft_model_path) }); - if (is_eagle3_mode) { - // disable dynamic split fuse in eagle3 mode - scheduler_config.dynamic_split_fuse = false; - std::cout << "disable dynamic split fuse for eagle3 speculative decoding" << std::endl; - device_config_map.insert({ ov::genai::eagle3_mode(true) }); - } } if (!parse_plugin_config_string(device_config, device_config_map)) { std::cout << "ERROR: Wrong json parameter in device_config." << std::endl; diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 490d68e9d7..6658b7d8fc 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -212,7 +212,6 @@ def analyze_args(args): draft_cb_config = get_config(args.draft_cb_config) model_args["draft_cb_config"] = draft_cb_config model_args['num_assistant_tokens'] = args.num_assistant_tokens - model_args['eagle3_mode'] = args.eagle3_mode model_args['assistant_confidence_threshold'] = args.assistant_confidence_threshold model_args['max_ngram_size'] = args.max_ngram_size diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py index 8ec13fa6eb..9fdb82567e 100644 --- a/tools/llm_bench/llm_bench_utils/ov_utils.py +++ b/tools/llm_bench/llm_bench_utils/ov_utils.py @@ -240,14 +240,6 @@ def create_genai_text_gen_model(model_path, device, ov_config, memory_monitor, * draft_model_load_kwargs = {'scheduler_config': get_scheduler_config_genai(kwargs.get("draft_cb_config"), config_name="draft CB config")}\ if kwargs.get("draft_cb_config") is not None else {} config['draft_model'] = openvino_genai.draft_model(draft_model_path, draft_device.upper(), **draft_model_load_kwargs) - if (kwargs.get('eagle3_mode', None)): - config['eagle3_mode'] = True - if 'scheduler_config' in config.keys(): - config['scheduler_config'].dynamic_split_fuse = False # Eagle speculative decoding does not support dynamic_split_fuse mode - else: - config['scheduler_config'] = openvino_genai.SchedulerConfig() - config['scheduler_config'].dynamic_split_fuse = False - log.info("Eagle3 Speculative Decoding is activated, and dynamic_split_fuse is set to False in scheduler_config") if kwargs.get('max_ngram_size') and kwargs.get('num_assistant_tokens'): log.info("Prompt Lookup decoding is activated") diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index 55cb76bdba..bf9c9cfff8 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -284,7 +284,7 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data ) tokenization_start = time.perf_counter() - if (args.get("eagle3_mode")): + if args.get("eagle3_mode"): # eagle3 needs to disable special tokens to ensure compress rate input_data = tokenizer.encode(input_text_list, add_special_tokens=False) else: @@ -455,7 +455,7 @@ def 
run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg "If it is not expected, please specify --disable_prompt_permutation in your benchmarking command to disable this behavior" ) tok_encode_start = time.perf_counter() - if (args.get("eagle3_mode")): + if args.get("eagle3_mode"): # eagle3 needs to disable special tokens to ensure compress rate input_data = pipe_tokenizer.encode(input_text_list, add_special_tokens=False) else: From cd753f693faff13e125bca9b830ee0369fc87008 Mon Sep 17 00:00:00 2001 From: fishbell Date: Tue, 23 Sep 2025 19:40:50 +0800 Subject: [PATCH 10/43] do not need seperate eagle sample Signed-off-by: fishbell --- samples/cpp/text_generation/CMakeLists.txt | 3 +- .../text_generation/eagle_speculative_lm.cpp | 97 ------------------- .../text_generation/eagle_speculative_lm.py | 67 ------------- .../include/openvino/genai/llm_pipeline.hpp | 6 -- .../speculative_decoding_impl.cpp | 70 +++---------- .../speculative_decoding_impl.hpp | 4 - ...ntinuous_batching_speculative_decoding.cpp | 11 +-- 7 files changed, 16 insertions(+), 242 deletions(-) delete mode 100644 samples/cpp/text_generation/eagle_speculative_lm.cpp delete mode 100755 samples/python/text_generation/eagle_speculative_lm.py diff --git a/samples/cpp/text_generation/CMakeLists.txt b/samples/cpp/text_generation/CMakeLists.txt index 0fa9d9eb6d..ebaf32c7f4 100644 --- a/samples/cpp/text_generation/CMakeLists.txt +++ b/samples/cpp/text_generation/CMakeLists.txt @@ -29,8 +29,7 @@ set (SAMPLE_LIST lora_greedy_causal_lm multinomial_causal_lm prompt_lookup_decoding_lm - speculative_decoding_lm - eagle_speculative_lm) + speculative_decoding_lm) foreach(sample IN LISTS SAMPLE_LIST) add_sample_executable(${sample}) diff --git a/samples/cpp/text_generation/eagle_speculative_lm.cpp b/samples/cpp/text_generation/eagle_speculative_lm.cpp deleted file mode 100644 index bbddb473d6..0000000000 --- a/samples/cpp/text_generation/eagle_speculative_lm.cpp +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (C) 2023-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include -#include - -#include "openvino/genai/llm_pipeline.hpp" -#include "openvino/genai/speculative_decoding/perf_metrics.hpp" - -template -void print_perf_metrics(T& perf_metrics, std::string model_name) { - std::cout << "\n" << model_name << std::endl; - auto generation_duration = perf_metrics.get_generate_duration().mean; - std::cout << " Generate time: " << generation_duration << " ms" << std::endl; - std::cout << " TTFT: " << perf_metrics.get_ttft().mean << " ± " << perf_metrics.get_ttft().std << " ms" - << std::endl; - std::cout << " TPOT: " << perf_metrics.get_tpot().mean << " ± " << perf_metrics.get_tpot().std << " ms/token" - << std::endl; - std::cout << " Num generated token: " << perf_metrics.get_num_generated_tokens() << " tokens" << std::endl; - if (model_name == "Total") { - std::cout << " Total iteration number: " << perf_metrics.raw_metrics.m_new_token_times.size() << std::endl; - } else { - std::cout << " Total iteration number: " << perf_metrics.raw_metrics.m_durations.size() << std::endl; - } - if (perf_metrics.get_num_input_tokens() > 0) { - std::cout << " Input token size: " << perf_metrics.get_num_input_tokens() << std::endl; - } -} - -int main(int argc, char* argv[]) try { - if (4 != argc) { - throw std::runtime_error(std::string{"Usage: "} + argv[0] + " ''"); - } - - std::filesystem::path main_model_path = argv[1]; - std::filesystem::path eagle_model_path = argv[2]; - std::string prompt = argv[3]; - - // Configure devices 
- can run main and eagle models on different devices - std::string main_device = "CPU", eagle_device = "CPU"; // GPU can be used as well - - // Eagle Speculative settings - ov::genai::GenerationConfig config = ov::genai::greedy(); - config.max_new_tokens = 100; - config.num_assistant_tokens = 5; - - ov::genai::SchedulerConfig scheduler_config; - scheduler_config.dynamic_split_fuse = false; // Eagle speculative decoding does not support dynamic_split_fuse mode - // Create pipeline with eagle speculative enabled - ov::genai::LLMPipeline pipe( - main_model_path, - main_device, - ov::genai::draft_model(eagle_model_path, eagle_device), - ov::genai::scheduler_config(scheduler_config) - ); - // Setup performance measurement - auto start_time = std::chrono::high_resolution_clock::now(); - - // Optional: Create a streaming callback for real-time token display - auto streamer = [](std::string subword) { - std::cout << subword << std::flush; - return ov::genai::StreamingStatus::RUNNING; - }; - - // Run generation with eagle speculative decoding - std::cout << "Generating with Eagle Speculative decoding:" << std::endl; - auto result = pipe.generate(prompt, config, streamer); - std::cout << std::endl; - - // Calculate and display performance metrics - auto end_time = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end_time - start_time); - std::cout << "\nGeneration completed in " << duration.count() << " ms" << std::endl; - - auto sd_perf_metrics = std::dynamic_pointer_cast(result.extended_perf_metrics); - if (sd_perf_metrics) { - print_perf_metrics(result.perf_metrics, "Total"); - print_perf_metrics(sd_perf_metrics->main_model_metrics, "MAIN MODEL"); - std::cout << " accepted token: " << sd_perf_metrics->get_num_accepted_tokens() << " tokens" << std::endl; - std::cout << " compress rate: " - << sd_perf_metrics->main_model_metrics.get_num_generated_tokens() * 1.0f / - sd_perf_metrics->main_model_metrics.raw_metrics.m_durations.size() - << std::endl; - print_perf_metrics(sd_perf_metrics->draft_model_metrics, "DRAFT MODEL"); - } - std::cout << std::endl; -} catch (const std::exception& error) { - try { - std::cerr << error.what() << '\n'; - } catch (const std::ios_base::failure&) {} - return EXIT_FAILURE; -} catch (...) { - try { - std::cerr << "Non-exception object thrown\n"; - } catch (const std::ios_base::failure&) {} - return EXIT_FAILURE; -} \ No newline at end of file diff --git a/samples/python/text_generation/eagle_speculative_lm.py b/samples/python/text_generation/eagle_speculative_lm.py deleted file mode 100755 index 2ce64446be..0000000000 --- a/samples/python/text_generation/eagle_speculative_lm.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import argparse -import openvino_genai -import queue - -def streamer(subword): - print(subword, end='', flush=True) - # Return flag corresponds whether generation should be stopped. - return openvino_genai.StreamingStatus.RUNNING - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('model_dir') - parser.add_argument('draft_model_dir') - parser.add_argument('prompt') - args = parser.parse_args() - - # User can run main and draft model on different devices. - # Please, set device for main model in `openvino_genai.LLMPipeline` constructor and in openvino_genai.draft_model` for draft. 
- main_device = 'GPU' # CPU can be used as well - draft_device = 'GPU' - scheduler_config = openvino_genai.SchedulerConfig() - scheduler_config.dynamic_split_fuse = False # Eagle speculative decoding does not support dynamic_split_fuse mode - draft_model = openvino_genai.draft_model(args.draft_model_dir, draft_device) - - pipe = openvino_genai.LLMPipeline(args.model_dir, main_device, scheduler_config = scheduler_config, draft_model=draft_model, eagle3_mode = True) - - config = openvino_genai.GenerationConfig() - config.max_new_tokens = 100 - # Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded - # add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration - config.num_assistant_tokens = 5 - # add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold` - # config.assistant_confidence_threshold = 0.4 - - # Since the streamer is set, the results will be printed - # every time a new token is generated and put into the streamer queue. - res = pipe.generate([args.prompt], config, streamer) - print() - if (res.extended_perf_metrics): - main_model_metrics = res.extended_perf_metrics.main_model_metrics - print(f"MAIN MODEL") - print(f" Generate time: {main_model_metrics.get_generate_duration().mean:.2f} ms" ) - print(f" TTFT: {main_model_metrics.get_ttft().mean:.2f} ± {main_model_metrics.get_ttft().std:.2f} ms" ) - print(f" TTST: {main_model_metrics.get_ttst().mean:.2f} ± {main_model_metrics.get_ttst().std:.2f} ms/token") - print(f" TPOT: {main_model_metrics.get_tpot().mean:.2f} ± {main_model_metrics.get_tpot().std:.2f} ms/iteration") - print(f" AVG Latency: {main_model_metrics.get_latency().mean:.2f} ± {main_model_metrics.get_latency().std:.2f} ms/token") - print(f" Num generated token: {main_model_metrics.get_num_generated_tokens()} tokens") - print(f" Total iteration number: {len(main_model_metrics.raw_metrics.m_durations)}") - print(f" Num accepted token: {res.extended_perf_metrics.get_num_accepted_tokens()} tokens") - - draft_model_metrics = res.extended_perf_metrics.draft_model_metrics - print(f"DRAFT MODEL" ) - print(f" Generate time: {draft_model_metrics.get_generate_duration().mean:.2f} ms" ) - print(f" TTFT: {draft_model_metrics.get_ttft().mean:.2f} ms") - print(f" TTST: {draft_model_metrics.get_ttst().mean:.2f} ms/token") - print(f" TPOT: {draft_model_metrics.get_tpot().mean:.2f} ± {draft_model_metrics.get_tpot().std:.2f} ms/token") - print(f" AVG Latency: {draft_model_metrics.get_latency().mean:.2f} ± {draft_model_metrics.get_latency().std:.2f} ms/iteration") - print(f" Num generated token: {draft_model_metrics.get_num_generated_tokens()} tokens") - print(f" Total iteration number: {len(draft_model_metrics.raw_metrics.m_durations)}") - print() - -if '__main__' == __name__: - main() diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index f99c76b3c1..eea94591c3 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -345,12 +345,6 @@ static constexpr ov::Property prompt_lookup{"prompt_lookup"}; */ static constexpr ov::Property enable_save_ov_model{"enable_save_ov_model"}; -/** -* @brief enable eagle3_mode property serves to activate eagle3 speculative decoding. -* Set `true` to activate this mode. 
-* And create LLMPipeline instance with this config. -*/ -static constexpr ov::Property eagle3_mode{"eagle3_mode"}; } // namespace genai } // namespace ov diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index e96a5739f7..0ab7b35373 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -368,20 +368,9 @@ std::vector ContinuousBatchingPipeline::SpeculativeDecodingI void extract_hidden_state_generic(std::shared_ptr& model, const std::vector& hidden_layers_to_abstract) { - if (hidden_layers_to_abstract.size() == 1 && - hidden_layers_to_abstract[0] == -1) { // for draft model, we always only need to extract last hidden state - std::cout << "draft model - last hidden state extraction" << std::endl; - ov::pass::Manager pm; - pm.register_pass(hidden_layers_to_abstract); - pm.run_passes(model); - } else { - std::cout << "main model - Eagle 3 hidden states extraction" << std::endl; - ov::pass::Manager pm; - /*if idx==len(self.layers)-3 or idx==len(self.layers)//2 or idx==2: - all_hidden_states += (hidden_states,)*/ - pm.register_pass(hidden_layers_to_abstract); - pm.run_passes(model); - } + ov::pass::Manager pm; + pm.register_pass(hidden_layers_to_abstract); + pm.run_passes(model); } EagleModelTransform::EagleModelTransform(const std::vector& layers) : m_layer_ids(layers) { @@ -445,7 +434,7 @@ EagleInputTransform::EagleInputTransform(std::vector>& params) { if (ov::is_type(node)) { auto matmul_node = ov::as_type_ptr(node); - if (matmul_node->get_friendly_name().find("__module.model.fc/ov_ext::linear/MatMul") == std::string::npos) { // hardcode for now + // check the input of matmul node, if it is a node with name "target_hidden_state_input", then it's the node we want + auto input_node = matmul_node->get_input_node_shared_ptr(0); + if (!ov::as_type_ptr(input_node)) { return false; } + auto shape = node->get_output_partial_shape(0); auto internal_hidden_state = std::make_shared(node->get_element_type(), node->get_output_partial_shape(0)); internal_hidden_state->output(0).set_names({"internal_hidden_state_input"}); internal_hidden_state->set_friendly_name("internal_hidden_state_input"); - // create new eltwise node to add output of MatMul node and + // create new eltwise node to add output of MatMul node and internal hidden state input from last cycle of itself auto new_eltwise = std::make_shared(internal_hidden_state, matmul_node->output(0)); ov::replace_node(matmul_node, new_eltwise); params.push_back(internal_hidden_state); @@ -481,7 +473,7 @@ EagleBaseTransform::EagleBaseTransform(std::vector EagleBaseTransform::find_last_hidden_node(const std::shared_ptr& start_node, - std::set& visited_nodes) { - if (visited_nodes.count(start_node.get())) { - return nullptr; - } - - visited_nodes.insert(start_node.get()); - - if (ov::is_type(start_node)) { - // check the input nodes of MatMul, if found Gather node, return the gather node, otherwise ,retrun the matmul node - for (size_t i = 0; i < start_node->get_input_size(); ++i) { - auto input_node = start_node->get_input_node_shared_ptr(i); - if (!input_node) continue; - // rule out logit processing node - if (ov::as_type_ptr(input_node)) { - return input_node; - } - } - return start_node; - } - - for (size_t i = 0; i < start_node->get_input_size(); ++i) { - auto input_node = start_node->get_input_node_shared_ptr(i); - if (!input_node) continue; - - auto result = 
find_last_hidden_node(input_node, visited_nodes); - if (result) { - return result; - } - } - return nullptr; -} - std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node, std::set& visited_nodes) { if (visited_nodes.count(start_node.get())) { @@ -556,11 +515,6 @@ std::shared_ptr EagleBaseTransform::find_last_residual_node(const std: return nullptr; } -std::shared_ptr EagleBaseTransform::find_last_hidden_node(const std::shared_ptr& start_node) { - std::set visited_nodes; - return find_last_hidden_node(start_node, visited_nodes); -} - std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node) { std::set visited_nodes; return find_last_residual_node(start_node, visited_nodes); @@ -568,7 +522,7 @@ std::shared_ptr EagleBaseTransform::find_last_residual_node(const std: bool EagleBaseTransform::apply(NodePtr node, std::vector>& params, std::vector>& results) { { - // 1. with normalization layer 2. add extra input + // 1. without normalization layer 2. add extra input if (ov::is_type(node)) { // we are applying transformation to the last hidden state, eagle2 mode NodePtr input_node = node->get_input_node_shared_ptr(0); @@ -590,9 +544,7 @@ bool EagleBaseTransform::apply(NodePtr node, std::vector& layers, std::vector>& hidden_state_outputs) : m_layers(layers) { auto is_target_pattern = [&](const Output& output) { auto add_node = ov::as_type_ptr(output.get_node_shared_ptr()); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 06606dd331..79e4f8104c 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -129,10 +129,6 @@ class EagleBaseTransform : public ov::pass::MatcherPass { private: bool apply(NodePtr node, std::vector>& params, std::vector>& results); size_t applied = 0; - std::string m_eagle_version; - std::shared_ptr find_last_hidden_node(const std::shared_ptr& start_node); - std::shared_ptr find_last_hidden_node(const std::shared_ptr& start_node, - std::set& visited_nodes); std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node); std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node, std::set& visited_nodes); diff --git a/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp b/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp index 3a36dcb2c6..0b0797ddf9 100644 --- a/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp +++ b/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp @@ -20,11 +20,10 @@ std::vector get_spec_decoding_generation_config_exa generation_config_greedy_constant.num_assistant_tokens = 5; } - ov::genai::GenerationConfig generation_config_multinomial_constant =ov::genai::greedy(); + ov::genai::GenerationConfig generation_config_multinomial_constant = ov::genai::multinomial(); { - //generation_config_multinomial_constant.num_assistant_tokens = 5; - generation_config_multinomial_constant.num_return_sequences = 1; generation_config_multinomial_constant.num_assistant_tokens = 5; + generation_config_multinomial_constant.num_return_sequences = 1; } ov::genai::GenerationConfig generation_config_greedy_dynamic = ov::genai::greedy(); @@ -32,11 +31,9 @@ std::vector get_spec_decoding_generation_config_exa generation_config_greedy_dynamic.assistant_confidence_threshold = 0.8f; } - 
ov::genai::GenerationConfig generation_config_multinomial_dynamic = ov::genai::greedy(); + ov::genai::GenerationConfig generation_config_multinomial_dynamic = ov::genai::multinomial(); { - //generation_config_multinomial_dynamic.assistant_confidence_threshold = 0.8f; - generation_config_multinomial_dynamic.num_return_sequences = 1; - generation_config_multinomial_dynamic.num_assistant_tokens = 5; + generation_config_multinomial_dynamic.assistant_confidence_threshold = 0.8f; } return { From 6aa49659e483468e91db3b9c6320ff7698d32cd9 Mon Sep 17 00:00:00 2001 From: fishbell Date: Tue, 23 Sep 2025 23:20:20 +0800 Subject: [PATCH 11/43] opt constructor for eagledecodingimpl Signed-off-by: fishbell --- src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 79e4f8104c..3903c0abaa 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -61,7 +61,7 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat void drop_requests(); bool is_requests_empty(); std::vector get_awaiting_requests(); - + std::pair init_speculative_models(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); public: SpeculativeDecodingImpl() = default; SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); From 5efbada713c2e356279f466774929724489ba1c1 Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 24 Sep 2025 00:08:04 +0800 Subject: [PATCH 12/43] remove hardcoding of eagle layers Signed-off-by: fishbell --- .../src/continuous_batching/model_runner.hpp | 11 +- ...batching_for_speculative_decoding_impl.hpp | 6 + .../speculative_decoding_impl.cpp | 103 +++++------------- .../speculative_decoding_impl.hpp | 1 + 4 files changed, 43 insertions(+), 78 deletions(-) diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index 1d85bb6b62..35457e892d 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -55,6 +55,7 @@ class ModelRunner { std::map, std::pair> m_sequence_hidden_state_mapping; // pre-requisite: main/draft have same seq group and running seq grouped id // a container which use sequence group id and request id as key to store hidden states std::map m_initial_hidden_states; // shape: [N, seq_len, hidden_size] + size_t m_adjust_factor = 1; // to adjust the hidden size of draft model input public: /** * Constructs the ModelRunner. @@ -117,6 +118,10 @@ class ModelRunner { m_embedding = embedder; } + void set_adjust_factor(size_t adjust_factor) { + m_adjust_factor = adjust_factor; + } + /** * @return A map of sequence IDs to vectors of ov::Tensor per-token attention scores. Each vector element is associated with its own * decoder layer, in order of their execution in the model. 
Each ov::Tensor has a shape of {N_k}, where N_k is the length of @@ -235,7 +240,7 @@ class ModelRunner { if (shape.size() >= 2) { hidden_size = shape[shape.size() - 1]; if (!m_is_hidden_state_import_needed) - hidden_size /= 3; + hidden_size /= m_adjust_factor; break; } } @@ -476,7 +481,7 @@ class ModelRunner { try { m_request.set_tensor("target_hidden_state_input", hidden_state_input); auto shape = hidden_state_input.get_shape(); - shape[shape.size() - 1] = shape [shape.size() - 1] / 3; + shape[shape.size() - 1] = shape [shape.size() - 1] / m_adjust_factor; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); @@ -487,7 +492,7 @@ class ModelRunner { try { m_request.set_tensor("internal_hidden_state_input", hidden_state_input); auto shape = hidden_state_input.get_shape(); - shape[shape.size() - 1] = shape [shape.size() - 1] * 3; + shape[shape.size() - 1] = shape [shape.size() - 1] * m_adjust_factor; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index 69ee41c501..9f7e8ae652 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp @@ -83,5 +83,11 @@ class ContinuousBatchingPipeline::ContinuousBatchingForEagleDecodingImpl m_model_runner->set_hidden_state_internal_needed(is_needed); } } + + void set_adjust_factor(size_t adjust_factor) { + if (m_model_runner) { + m_model_runner->set_adjust_factor(adjust_factor); + } + } }; } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 0ab7b35373..86914886b8 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -25,23 +25,17 @@ bool are_tokenizers_equal(Tokenizer& lhs, Tokenizer& rhs) { lhs.get_bos_token_id() == rhs.get_bos_token_id() && lhs.get_pad_token_id() == rhs.get_pad_token_id(); } -ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, - const ov::genai::ModelDesc& draft_model_desc) { - auto main_model = main_model_desc.model; - auto draft_model = draft_model_desc.model; - - auto main_scheduler_config = main_model_desc.scheduler_config; - auto main_device = main_model_desc.device; - - utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction); - utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction); +std::pair +ContinuousBatchingPipeline::SpeculativeDecodingImpl::init_speculative_models(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc) { + utils::apply_paged_attention_transformations(main_model_desc.model, main_model_desc.scheduler_config.use_cache_eviction); + utils::apply_paged_attention_transformations(draft_model_desc.model, main_model_desc.scheduler_config.use_cache_eviction); - utils::apply_gather_before_matmul_transformation(main_model); - 
utils::apply_gather_before_matmul_transformation(draft_model); + utils::apply_gather_before_matmul_transformation(main_model_desc.model); + utils::apply_gather_before_matmul_transformation(draft_model_desc.model); - std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; bool is_draft_scheduler_undefined = draft_model_desc.scheduler_config == SchedulerConfig(); + auto main_scheduler_config = main_model_desc.scheduler_config; ov::genai::SchedulerConfig main_scheduler_config_updated = main_scheduler_config, draft_scheduler_config = is_draft_scheduler_undefined ? main_scheduler_config : draft_model_desc.scheduler_config; @@ -60,8 +54,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con } return total_hidden_size; }; - float main_model_hidden_size = compute_total_hidden_size(main_model), - draft_model_hidden_size = compute_total_hidden_size(draft_model); + float main_model_hidden_size = compute_total_hidden_size(main_model_desc.model), + draft_model_hidden_size = compute_total_hidden_size(draft_model_desc.model); auto k = draft_model_hidden_size / (main_model_hidden_size + draft_model_hidden_size); // TODO: work with KV blocks as it will be more precise instead of GBs @@ -79,8 +73,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con draft_scheduler_config.max_num_batched_tokens = main_scheduler_config_updated.max_num_batched_tokens; } - ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties; + return std::make_pair(main_scheduler_config_updated, draft_scheduler_config); +} +ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc) { + auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); + + auto main_device = main_model_desc.device; + std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; // main and draft model can have different tokenizers // to do: support retokenization: 154103 Tokenizer main_model_tokenizer = main_model_desc.tokenizer; @@ -88,16 +89,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con // todo: remove this condition after support of CVS-154103 OPENVINO_ASSERT(are_tokenizers_equal(main_model_tokenizer, draft_model_tokenizer), "Tokenizers for draft and main models are different!"); - m_tokenizer = main_model_tokenizer; - + ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? 
main_model_desc.properties : draft_model_desc.properties; // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode m_main_pipeline = std::make_shared( - main_model, main_model_tokenizer, main_model_desc.generation_config, - main_scheduler_config_updated, main_device, main_model_desc.properties, true); + main_model_desc.model, main_model_tokenizer, main_model_desc.generation_config, + scheduler_configs.first, main_device, main_model_desc.properties, true); m_draft_pipeline = std::make_shared( - draft_model, draft_model_tokenizer, draft_model_desc.generation_config, - draft_scheduler_config, draft_device, draft_properties, false); + draft_model_desc.model, draft_model_tokenizer, draft_model_desc.generation_config, + scheduler_configs.second, draft_device, draft_properties, false); m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; @@ -607,57 +607,14 @@ bool Eagle3Transform::apply(NodePtr node, std::vector>& hidden_stat ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc, - const std::vector& hidden_layers) { + const std::vector& hidden_layers) + : m_hidden_layers_to_abstract(hidden_layers) { + auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); auto main_model = main_model_desc.model; auto draft_model = draft_model_desc.model; - auto main_scheduler_config = main_model_desc.scheduler_config; auto main_device = main_model_desc.device; - - utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction); - utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction); - - utils::apply_gather_before_matmul_transformation(main_model); - utils::apply_gather_before_matmul_transformation(draft_model); - std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; - bool is_draft_scheduler_undefined = draft_model_desc.scheduler_config == SchedulerConfig(); - - ov::genai::SchedulerConfig main_scheduler_config_updated = main_scheduler_config, - draft_scheduler_config = is_draft_scheduler_undefined - ? 
main_scheduler_config - : draft_model_desc.scheduler_config; - - if (is_draft_scheduler_undefined) { - // split KV cache to 2 caches for main and draft models - auto compute_total_hidden_size = [](const std::shared_ptr& model) -> size_t { - size_t total_hidden_size = 0; - for (const auto& param_ptr : model->get_parameters()) { - const auto& name = param_ptr->get_friendly_name(); - if (name.find("key_cache.") == 0) { - auto pa_op = param_ptr->get_output_target_inputs(0).begin()->get_node(); - const auto& rt_info = pa_op->get_rt_info(); - total_hidden_size += rt_info.at("num_k_heads").as() * rt_info.at("k_head_size").as() + - rt_info.at("num_v_heads").as() * rt_info.at("v_head_size").as(); - } - } - return total_hidden_size; - }; - float main_model_hidden_size = compute_total_hidden_size(main_model), - draft_model_hidden_size = compute_total_hidden_size(draft_model); - auto k = draft_model_hidden_size / (main_model_hidden_size + draft_model_hidden_size); - - // TODO: work with KV blocks as it will be more precise instead of GBs - size_t main_cache_size = std::ceil(main_scheduler_config.cache_size * (1.f - k)), - draft_cache_size = main_scheduler_config.cache_size - main_cache_size; - if (draft_cache_size == 0 && main_cache_size > 0) { - main_cache_size -= (main_cache_size > 1 ? 1 : 0); - draft_cache_size = 1; - } - - main_scheduler_config_updated.cache_size = main_cache_size; - draft_scheduler_config.cache_size = draft_cache_size; - } ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties; @@ -666,11 +623,6 @@ ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai // to do: support retokenization: 154103 Tokenizer main_model_tokenizer = main_model_desc.tokenizer; Tokenizer draft_model_tokenizer = draft_model_desc.tokenizer; - - // todo: remove this condition after support of CVS-154103 - OPENVINO_ASSERT(are_tokenizers_equal(main_model_tokenizer, draft_model_tokenizer), - "Tokenizers for draft and main models are different!"); - m_tokenizer = main_model_tokenizer; // for eagle model, we need to obtain hidden layer state as extra output // apply transformations needed to run eagle model @@ -683,14 +635,14 @@ ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai m_main_pipeline = std::make_shared(main_model, main_model_tokenizer, main_model_desc.generation_config, - main_scheduler_config_updated, + scheduler_configs.first, main_device, main_model_desc.properties, true); m_draft_pipeline = std::make_shared(draft_model, draft_model_tokenizer, draft_model_desc.generation_config, - draft_scheduler_config, + scheduler_configs.second, draft_device, draft_properties, false); @@ -725,6 +677,7 @@ void ContinuousBatchingPipeline::EagleDecodingImpl::update_eagle_pipeline_params m_draft_eagle_pipeline->set_hidden_state_export_needed(true); m_draft_eagle_pipeline->set_hidden_state_import_needed(true); m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); + m_draft_eagle_pipeline->set_adjust_factor(m_hidden_layers_to_abstract.size() > 0 ? 
m_hidden_layers_to_abstract.size() : 1); } GenerationHandle diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 3903c0abaa..8a2afc0fa6 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -113,6 +113,7 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP protected: void update_eagle_pipeline_params(); ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); + std::vector m_hidden_layers_to_abstract; }; using NodePtr = std::shared_ptr; From 96276fec68c53a2c2a9a063e83823f45de8522da Mon Sep 17 00:00:00 2001 From: fishbell Date: Tue, 30 Sep 2025 19:44:46 +0800 Subject: [PATCH 13/43] apply copilot comment Signed-off-by: fishbell --- src/cpp/src/continuous_batching/model_runner.hpp | 6 ++---- src/cpp/src/continuous_batching/pipeline.cpp | 6 +++--- src/cpp/src/llm/pipeline.cpp | 7 +++---- .../src/speculative_decoding/speculative_decoding_impl.cpp | 6 ------ 4 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index 35457e892d..8baa4fb056 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -180,8 +180,6 @@ class ModelRunner { } void set_initial_hidden_state(size_t request_id, const ov::Tensor& hidden_state) { - // m_initial_hidden_states.clear(); - //auto key = std::make_pair(request_id, seq_grouped_id); m_initial_hidden_states[request_id] = hidden_state; } @@ -481,7 +479,7 @@ class ModelRunner { try { m_request.set_tensor("target_hidden_state_input", hidden_state_input); auto shape = hidden_state_input.get_shape(); - shape[shape.size() - 1] = shape [shape.size() - 1] / m_adjust_factor; + shape[shape.size() - 1] = shape[shape.size() - 1] / m_adjust_factor; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); @@ -492,7 +490,7 @@ class ModelRunner { try { m_request.set_tensor("internal_hidden_state_input", hidden_state_input); auto shape = hidden_state_input.get_shape(); - shape[shape.size() - 1] = shape [shape.size() - 1] * m_adjust_factor; + shape[shape.size() - 1] = shape[shape.size() - 1] * m_adjust_factor; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index f8049277f2..c871d074c8 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -103,7 +103,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p // to be implemented future SchedulerConfig scheduler_config_copy = scheduler_config; if (scheduler_config.dynamic_split_fuse) { - std::cout << "WARNING: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; + std::cout << "Note: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; scheduler_config_copy.dynamic_split_fuse = false; // Use scheduler_config_copy in subsequent code if modification is needed } @@ -161,7 +161,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( // to be implemented future 
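// A minimal shape sketch of what the adjust factor set above encodes (tensor names follow the
// model_runner hunks in this series; the concrete sizes here are assumptions, not values taken
// from the patch): with set_adjust_factor(3) the draft model reads a target hidden state that is
// three layers wide, while its own recurrent state keeps the single-layer width.
//
//   const size_t hidden_size = 2048;   // assumed draft-model width
//   const size_t adjust_factor = 3;    // number of target layers concatenated for Eagle3
//   ov::Tensor target_in(ov::element::f32,   {1, 1, hidden_size * adjust_factor}); // "target_hidden_state_input"
//   ov::Tensor internal_in(ov::element::f32, {1, 1, hidden_size});                 // "internal_hidden_state_input"
//
// Whichever of the two inputs is not supplied for a step is zero-filled with the complementary
// shape, which is why the last dimension is divided or multiplied by the adjust factor.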
SchedulerConfig scheduler_config_copy = scheduler_config; if (scheduler_config.dynamic_split_fuse) { - std::cout << "WARNING: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; + std::cout << "Note: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; scheduler_config_copy.dynamic_split_fuse = false; // Use scheduler_config_copy in subsequent code if modification is needed } @@ -225,7 +225,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( // to be implemented future SchedulerConfig scheduler_config_copy = scheduler_config; if (scheduler_config.dynamic_split_fuse) { - std::cout << "WARNING: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; + std::cout << "Note: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; scheduler_config_copy.dynamic_split_fuse = false; // Use scheduler_config_copy in subsequent code if modification is needed } diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index 441194364b..76f3c4a918 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -36,15 +36,14 @@ std::pair generation_config(const GenerationConfig& config) { inline void apply_eagle_rt_info(std::shared_ptr& model, ov::AnyMap& properties, const std::filesystem::path& mapping_path) { if (model->has_rt_info("eagle3_mode") && model->get_rt_info("eagle3_mode")) { + std::cout << "Eagle3 model is detected from rt_info. Applying eagle3_mode property." << std::endl; properties["eagle3_mode"] = true; + OPENVINO_ASSERT(model->has_rt_info("hidden_layers_list"), "hidden layers list is not found in eagle3 model rt_info"); + properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); if (!mapping_path.empty()) { properties["dt_mapping_path"] = mapping_path; // d2t mapping path } } - - if (model->has_rt_info("hidden_layers_list")) { - properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); - } } inline void apply_eagle_rt_info(std::shared_ptr& model, diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 86914886b8..83f6d906a4 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -387,7 +387,6 @@ bool EagleModelTransform::run_on_model(const std::shared_ptr& model) if (!m_new_results.empty()) { model->add_results(m_new_results); - std::cout << "EagleModelTransform - Added last hidden output " << std::endl; } // input transform for draft // here we apply a trick for the fc layer in draft model @@ -399,7 +398,6 @@ bool EagleModelTransform::run_on_model(const std::shared_ptr& model) manager.run_passes(model); model->add_parameters({m_new_parameters.back()}); - std::cout << "EagleModelTransform - trick on draft model inputs " << std::endl; } return true; } else { @@ -409,7 +407,6 @@ bool EagleModelTransform::run_on_model(const std::shared_ptr& model) manager.run_passes(model); if (!m_hidden_layer_outputs.empty()) { - std::cout << "EagleModelTransform - extracted intermediate hidden state outputs " << std::endl; auto concat = std::make_shared(m_hidden_layer_outputs, -1); concat->set_friendly_name("eagle3_hidden_states_concat"); @@ -418,8 +415,6 @@ bool EagleModelTransform::run_on_model(const std::shared_ptr& model) result->output(0).set_names({output_name}); result->set_friendly_name(output_name); model->add_results({result}); - - std::cout << 
"EagleModelTransform - Added concated eagle3 hidden state output" << std::endl; return true; } } @@ -461,7 +456,6 @@ bool EagleInputTransform::apply(NodePtr node, std::vector(internal_hidden_state, matmul_node->output(0)); ov::replace_node(matmul_node, new_eltwise); params.push_back(internal_hidden_state); - std::cout << "EagleInputTransform - Added internal hidden state input parameter" << std::endl; return true; } } From 6c3876ec20f53c255f908bf891d55458edba6df5 Mon Sep 17 00:00:00 2001 From: fishbell Date: Sat, 11 Oct 2025 22:54:26 +0800 Subject: [PATCH 14/43] share weights, rt info update Signed-off-by: fishbell --- .../src/continuous_batching/model_runner.hpp | 8 +- src/cpp/src/continuous_batching/pipeline.cpp | 24 +++- src/cpp/src/llm/pipeline.cpp | 4 +- .../speculative_decoding_impl.cpp | 104 +++++++++++------- .../speculative_decoding_impl.hpp | 6 +- 5 files changed, 90 insertions(+), 56 deletions(-) diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index 4c544e9836..f7f4e602fa 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -502,24 +502,24 @@ class ModelRunner { if (hidden_state_input && hidden_state_input.get_size() > 0) { if (m_is_hidden_state_import_needed) { try { - m_request.set_tensor("target_hidden_state_input", hidden_state_input); + m_request.set_tensor("hidden_states", hidden_state_input); auto shape = hidden_state_input.get_shape(); shape[shape.size() - 1] = shape[shape.size() - 1] / m_adjust_factor; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); - m_request.set_tensor("internal_hidden_state_input", fake_tensor); + m_request.set_tensor("internal_hidden_states", fake_tensor); } catch (const ov::Exception& e) { } } else { try { - m_request.set_tensor("internal_hidden_state_input", hidden_state_input); + m_request.set_tensor("internal_hidden_states", hidden_state_input); auto shape = hidden_state_input.get_shape(); shape[shape.size() - 1] = shape[shape.size() - 1] * m_adjust_factor; ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); auto fake_data = fake_tensor.data(); std::memset(fake_data, 0, fake_tensor.get_byte_size()); - m_request.set_tensor("target_hidden_state_input", fake_tensor); + m_request.set_tensor("hidden_states", fake_tensor); } catch (const ov::Exception& e) { } } diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index 08e5ac1239..d857434f43 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -16,6 +16,7 @@ #include "utils.hpp" #include "visual_language/inputs_embedder.hpp" #include "safe_tensor_wrapper.hpp" +#include "json_utils.hpp" using namespace ov::genai; @@ -36,7 +37,7 @@ struct Eagle3RTInfo { }; Eagle3RTInfo -extract_eagle_mode_from_config(ov::AnyMap& config) { +extract_eagle_mode_from_config(ov::AnyMap& config, const std::filesystem::path& models_path) { Eagle3RTInfo eagle_rt_info; if (config.find("eagle3_mode") != config.end()) { eagle_rt_info.eagle3_mode = config.at("eagle3_mode").as(); @@ -44,10 +45,23 @@ extract_eagle_mode_from_config(ov::AnyMap& config) { if (config.find("hidden_layers_list") != config.end()) { eagle_rt_info.hidden_layers_list = config.at("hidden_layers_list").as>(); config.erase("hidden_layers_list"); + } else { + // 
compute the layers from number of hidden layers + auto config_file_path = models_path / "config.json"; + if (!std::filesystem::exists(config_file_path)) + OPENVINO_THROW("cannot deduce layers for hidden layer extraction"); + std::ifstream file(config_file_path); + + nlohmann::json data = nlohmann::json::parse(file); + using ov::genai::utils::read_json_param; + size_t num_decoder_layers = 0; + read_json_param(data, "num_hidden_layers", num_decoder_layers); + OPENVINO_ASSERT(num_decoder_layers > 3, "num_decoder_layers is too small to deduce hidden layers for extraction"); + eagle_rt_info.hidden_layers_list = { 2, num_decoder_layers / 2, num_decoder_layers - 3 }; } if (config.find("dt_mapping_path") != config.end()) { eagle_rt_info.dt_mapping_table = config.at("dt_mapping_path").as(); - eagle_rt_info.dt_mapping_table = eagle_rt_info.dt_mapping_table / "eagle3.safetensor"; + eagle_rt_info.dt_mapping_table = eagle_rt_info.dt_mapping_table / "eagle3.safetensors"; config.erase("dt_mapping_path"); } } @@ -80,7 +94,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties); + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, models_path); auto model = utils::read_model(models_path, properties); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); @@ -141,7 +155,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties); + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, models_path); auto model = utils::read_model(models_path, properties_without_draft_model); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); properties_without_draft_model_without_gguf[ov::cache_model_path.name()] = models_path; @@ -202,7 +216,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties); + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, std::filesystem::path(model_str)); auto model = utils::singleton_core().read_model(model_str, weights_tensor); auto rt_info = model->get_rt_info(); diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index 76f3c4a918..d5e9c1d308 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -38,8 +38,8 @@ inline void apply_eagle_rt_info(std::shared_ptr& model, ov::AnyMap& p if (model->has_rt_info("eagle3_mode") && model->get_rt_info("eagle3_mode")) { std::cout << "Eagle3 model is detected from rt_info. 
Applying eagle3_mode property." << std::endl; properties["eagle3_mode"] = true; - OPENVINO_ASSERT(model->has_rt_info("hidden_layers_list"), "hidden layers list is not found in eagle3 model rt_info"); - properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); + if (model->has_rt_info("hidden_layers_list")) + properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); if (!mapping_path.empty()) { properties["dt_mapping_path"] = mapping_path; // d2t mapping path } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 83f6d906a4..a3a28cd6a9 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -366,6 +366,51 @@ std::vector ContinuousBatchingPipeline::SpeculativeDecodingI return main_awaiting_requests; } +void share_embedding_weights(std::shared_ptr& main_model, std::shared_ptr& draft_model) { + // extract embedding weight from main model + auto find_embedding_gather = [](const std::shared_ptr& model) + -> std::shared_ptr { + for (const auto& node : model->get_ordered_ops()) { + auto gather = std::dynamic_pointer_cast(node); + if (!gather) continue; + // [vocab, hidden_size] * [batch, seq_len] -> [batch, seq_len, hidden_size] + auto data_node = gather->input_value(0).get_node_shared_ptr(); + auto indices_node = gather->input_value(1).get_node_shared_ptr(); + if (!data_node || !indices_node) continue; + // indices_node should be on parameter path, maybe this is better rule + ov::PartialShape ps = data_node->get_output_partial_shape(0); + if (ps.rank().is_static() && ps.rank().get_length() >= 2) { + if (ps[0].is_static() && ps[0].get_length() > 1000) { // Heuristic: vocab size > 1000 + return gather; + } + } + std::string fname = data_node->get_friendly_name(); + if (fname.find("embed_tokens") != std::string::npos || + fname.find("embedding") != std::string::npos) { + return gather; + } + } + return nullptr; + }; + auto main_gather = find_embedding_gather(main_model); + auto draft_gather = find_embedding_gather(draft_model); + if (!main_gather || !draft_gather) { + return; + } + auto main_weight_node = main_gather->input_value(0).get_node_shared_ptr(); + auto draft_weight_node = draft_gather->input_value(0).get_node_shared_ptr(); + + if (main_weight_node.get() == draft_weight_node.get()) { + return; + } + + try { + draft_weight_node->output(0).replace(main_weight_node->output(0)); + } catch (...) 
{ + std::cout << "fail to import embedding weights from main model to draft model" << std::endl; + } +} + void extract_hidden_state_generic(std::shared_ptr& model, const std::vector& hidden_layers_to_abstract) { ov::pass::Manager pm; @@ -377,28 +422,20 @@ EagleModelTransform::EagleModelTransform(const std::vector& layers) : m_lay } bool EagleModelTransform::run_on_model(const std::shared_ptr& model) { + // share the embedding weights from main model to draft model m_new_parameters.clear(); m_new_results.clear(); if (m_layer_ids.size() == 1 && m_layer_ids[0] == -1) { ov::pass::Manager manager; manager.set_per_pass_validation(false); - manager.register_pass(m_new_parameters, m_new_results); - manager.run_passes(model); - - if (!m_new_results.empty()) { - model->add_results(m_new_results); - } + manager.register_pass(m_new_results); // input transform for draft // here we apply a trick for the fc layer in draft model - { - ov::pass::Manager manager; - manager.set_per_pass_validation(false); - m_new_parameters = model->get_parameters(); - manager.register_pass(m_new_parameters); - manager.run_passes(model); - - model->add_parameters({m_new_parameters.back()}); - } + manager.register_pass(m_new_parameters); + manager.run_passes(model); + + model->add_parameters(m_new_parameters); + model->add_results(m_new_results); return true; } else { ov::pass::Manager manager; @@ -442,7 +479,7 @@ EagleInputTransform::EagleInputTransform(std::vector>& params) { if (ov::is_type(node)) { auto matmul_node = ov::as_type_ptr(node); - // check the input of matmul node, if it is a node with name "target_hidden_state_input", then it's the node we want + // check the input of matmul node, if it is a node with name "hidden_states", then it's the node we want auto input_node = matmul_node->get_input_node_shared_ptr(0); if (!ov::as_type_ptr(input_node)) { return false; @@ -450,8 +487,8 @@ bool EagleInputTransform::apply(NodePtr node, std::vectorget_output_partial_shape(0); auto internal_hidden_state = std::make_shared(node->get_element_type(), node->get_output_partial_shape(0)); - internal_hidden_state->output(0).set_names({"internal_hidden_state_input"}); - internal_hidden_state->set_friendly_name("internal_hidden_state_input"); + internal_hidden_state->output(0).set_names({"internal_hidden_states"}); + internal_hidden_state->set_friendly_name("internal_hidden_states"); // create new eltwise node to add output of MatMul node and internal hidden state input from last cycle of itself auto new_eltwise = std::make_shared(internal_hidden_state, matmul_node->output(0)); ov::replace_node(matmul_node, new_eltwise); @@ -460,13 +497,13 @@ bool EagleInputTransform::apply(NodePtr node, std::vector>& params, std::vector>& results) { +EagleBaseTransform::EagleBaseTransform(std::vector>& results) { register_matcher( std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), - ([¶ms, &results, this](ov::pass::pattern::Matcher& m) { + ([&results, this](ov::pass::pattern::Matcher& m) { auto node = m.get_match_root(); try { - if (apply(node, params, results)) { + if (apply(node, results)) { ++applied; return true; } @@ -514,7 +551,7 @@ std::shared_ptr EagleBaseTransform::find_last_residual_node(const std: return find_last_residual_node(start_node, visited_nodes); } -bool EagleBaseTransform::apply(NodePtr node, std::vector>& params, std::vector>& results) { +bool EagleBaseTransform::apply(NodePtr node, std::vector>& results) { { // 1. without normalization layer 2. 
add extra input if (ov::is_type(node)) { @@ -574,31 +611,15 @@ Eagle3Transform::Eagle3Transform(const std::vector& layers, std::vector(hidden_layer, "Eagle3Transform::hidden_extraction"), [&hidden_state_outputs, this](ov::pass::pattern::Matcher& m) { auto node = m.get_match_root(); - try { - if (apply(node, hidden_state_outputs)) { - ++applied; // FIXME: For debugging purposes only - return true; - } - } catch (...) { - OPENVINO_ASSERT(false, "Eagle3Transform failed to apply"); + if (ov::is_type(node)) { + hidden_state_outputs.push_back(node->output(0)); + return true; } return false; } ); } -bool Eagle3Transform::apply(NodePtr node, std::vector>& hidden_state_outputs) { - if (ov::is_type(node)) { - auto add_node = std::dynamic_pointer_cast(node); - if (!add_node) { - return false; - } - hidden_state_outputs.push_back(add_node->output(0)); - return true; - } - return false; -} - ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc, const std::vector& hidden_layers) @@ -622,6 +643,7 @@ ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai // apply transformations needed to run eagle model // target model: hidden state extraction, draft model: hidden state import , hidden state extraction // eagle3 specific : dt importing + share_embedding_weights(main_model, draft_model); extract_hidden_state_generic(main_model, hidden_layers); extract_hidden_state_generic(draft_model, { -1 }); diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 8a2afc0fa6..55796519dd 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -123,12 +123,12 @@ class EagleBaseTransform : public ov::pass::MatcherPass { public: using NodePtr = std::shared_ptr; OPENVINO_MATCHER_PASS_RTTI("EagleBaseTransform"); - EagleBaseTransform(std::vector>& params, std::vector>& results); + EagleBaseTransform(std::vector>& results); ~EagleBaseTransform() = default; private: - bool apply(NodePtr node, std::vector>& params, std::vector>& results); + bool apply(NodePtr node, std::vector>& results); size_t applied = 0; std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node); std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node, @@ -155,8 +155,6 @@ class Eagle3Transform : public ov::pass::MatcherPass { ~Eagle3Transform() = default; private: - bool apply(NodePtr node, std::vector>& hidden_state_outputs); - size_t applied = 0; std::vector m_layers; // layers to be abstracted }; From f33491f1b6a096682dfb34a69400db0c38d936da Mon Sep 17 00:00:00 2001 From: fishbell Date: Sat, 11 Oct 2025 23:00:26 +0800 Subject: [PATCH 15/43] reuse spec app for eagle Signed-off-by: fishbell --- .../accuracy/CMakeLists.txt | 4 - .../continuous_batching_eagle_decoding.cpp | 150 ------------------ 2 files changed, 154 deletions(-) delete mode 100644 tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp diff --git a/tools/continuous_batching/accuracy/CMakeLists.txt b/tools/continuous_batching/accuracy/CMakeLists.txt index b3a29697d3..8223452b5c 100644 --- a/tools/continuous_batching/accuracy/CMakeLists.txt +++ b/tools/continuous_batching/accuracy/CMakeLists.txt @@ -33,10 +33,6 @@ set(TARGET_NAME_CB continuous_batching_speculative_decoding) add_executable(${TARGET_NAME_CB} ${TARGET_NAME_CB}.cpp) 
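# With the dedicated eagle accuracy tool removed below, the same accuracy check is meant to run
# through this reused speculative-decoding app. A hedged usage sketch (the -n/-m/-a/-d option
# names are assumed to match the removed eagle tool's CLI and are not confirmed by this patch):
#   ./continuous_batching_speculative_decoding -n 1 -m <main_model_dir> -a <eagle3_draft_dir> -d GPU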
target_link_libraries(${TARGET_NAME_CB} PRIVATE openvino::genai cxxopts::cxxopts) -set(TARGET_NAME_CB_EAGLE continuous_batching_eagle_decoding) -add_executable(${TARGET_NAME_CB_EAGLE} ${TARGET_NAME_CB_EAGLE}.cpp) -target_link_libraries(${TARGET_NAME_CB_EAGLE} PRIVATE openvino::genai cxxopts::cxxopts) - set_target_properties(${TARGET_NAME} ${TARGET_NAME_CB} PROPERTIES # Ensure out of box LC_RPATH on macOS with SIP INSTALL_RPATH_USE_LINK_PATH ON) diff --git a/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp b/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp deleted file mode 100644 index 395099cc43..0000000000 --- a/tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (C) 2023-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 - -#include -#include - -#include "openvino/genai/continuous_batching_pipeline.hpp" - -void print_cb_generation_result(const ov::genai::GenerationResult& generation_result) { - for (size_t output_id = 0; output_id < generation_result.m_generation_ids.size(); ++output_id) { - std::cout << "Answer " << output_id << " (" << generation_result.m_scores[output_id] << ") : " << generation_result.m_generation_ids[output_id] << std::endl; - } -} - -std::vector get_spec_decoding_generation_config_examples() { - - // sampling param for speulative decoding - ov::genai::GenerationConfig generation_config_greedy_constant = ov::genai::greedy(); - { - generation_config_greedy_constant.num_assistant_tokens = 5; - } - - ov::genai::GenerationConfig generation_config_multinomial_constant =ov::genai::greedy(); - { - generation_config_multinomial_constant.num_return_sequences = 1; - generation_config_multinomial_constant.num_assistant_tokens = 5; - } - - ov::genai::GenerationConfig generation_config_greedy_dynamic = ov::genai::greedy(); - { - generation_config_greedy_dynamic.num_assistant_tokens = 4; - } - - ov::genai::GenerationConfig generation_config_multinomial_dynamic = ov::genai::greedy(); - { - generation_config_multinomial_dynamic.num_return_sequences = 1; - generation_config_multinomial_dynamic.num_assistant_tokens = 4; - } - - return { - generation_config_greedy_constant, - generation_config_multinomial_constant, - generation_config_greedy_dynamic, - generation_config_multinomial_dynamic, - }; -} - -int main(int argc, char* argv[]) try { - // Command line options - - cxxopts::Options options("accuracy_sample", "Help command"); - - options.add_options() - ("n,num_prompts", "A number of prompts", cxxopts::value()->default_value("1")) - ("m,model", "Path to model and tokenizers base directory", cxxopts::value()->default_value(".")) - ("a,draft_model", "Path to assisting model base directory", cxxopts::value()->default_value(".")) - ("d,device", "Target device to run the model", cxxopts::value()->default_value("GPU")) - ("h,help", "Print usage"); - - cxxopts::ParseResult result; - try { - result = options.parse(argc, argv); - } catch (const cxxopts::exceptions::exception& e) { - std::cout << e.what() << "\n\n"; - std::cout << options.help() << std::endl; - return EXIT_FAILURE; - } - - if (result.count("help")) { - std::cout << options.help() << std::endl; - return EXIT_SUCCESS; - } - - const size_t num_prompts = result["num_prompts"].as(); - const std::string models_path = result["model"].as(); - const std::string draft_models_path = result["draft_model"].as(); - const std::string device = result["device"].as(); - - std::vector prompt_examples = { - "What is 
OpenVINO?", - "How are you?", - "What is your name?", - "Tell me something about Canada", - "What is OpenVINO?", - }; - - auto generation_config = get_spec_decoding_generation_config_examples(); - auto default_config_size = generation_config.size(); - std::vector cb_generation_config; - for (size_t i = 0; i < num_prompts; ++i) { - cb_generation_config.push_back(generation_config[i % default_config_size]); - } - - std::vector prompts(num_prompts); - for (size_t i = 0; i < num_prompts; ++i) { - prompts[i] = prompt_examples[i % prompt_examples.size()]; - } - - ov::genai::SchedulerConfig scheduler_config; - // batch size - scheduler_config.max_num_batched_tokens = 128; - // cache params - scheduler_config.num_kv_blocks = 364; - // mode - vLLM or dynamic_split_fuse - scheduler_config.dynamic_split_fuse = false; // does not support true in eagle speculative decoding - // vLLM specific params - scheduler_config.max_num_seqs = 2; - - ov::genai::ContinuousBatchingPipeline pipe(models_path, scheduler_config, device, {ov::genai::draft_model(draft_models_path, device)}); - std::vector generation_results = pipe.generate(prompts, cb_generation_config); - - for (size_t request_id = 0; request_id < generation_results.size(); ++request_id) { - const ov::genai::GenerationResult & generation_result = generation_results[request_id]; - std::cout << "Question: " << prompts[request_id] << std::endl; - switch (generation_result.m_status) - { - case ov::genai::GenerationStatus::FINISHED: - print_cb_generation_result(generation_result); - break; - case ov::genai::GenerationStatus::IGNORED: - std::cout << "Request was ignored due to lack of memory." < 0) { - std::cout << "Partial result:" << std::endl; - print_cb_generation_result(generation_result); - } - break; - case ov::genai::GenerationStatus::STOP: - case ov::genai::GenerationStatus::CANCEL: - std::cout << "Request was aborted." < 0) { - std::cout << "Partial result:" << std::endl; - print_cb_generation_result(generation_result); - } - break; - default: - break; - } - std::cout << std::endl; - } -} catch (const std::exception& error) { - try { - std::cerr << error.what() << '\n'; - } catch (const std::ios_base::failure&) {} - return EXIT_FAILURE; -} catch (...) 
{ - try { - std::cerr << "Non-exception object thrown\n"; - } catch (const std::ios_base::failure&) {} - return EXIT_FAILURE; -} From 264540feabed901721c5c50c721ec9f9beb5cf96 Mon Sep 17 00:00:00 2001 From: fishbell Date: Sun, 12 Oct 2025 00:42:26 +0800 Subject: [PATCH 16/43] fix build warning Signed-off-by: fishbell --- src/cpp/src/continuous_batching/pipeline.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index d857434f43..9bd1bbb026 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -54,7 +54,7 @@ extract_eagle_mode_from_config(ov::AnyMap& config, const std::filesystem::path& nlohmann::json data = nlohmann::json::parse(file); using ov::genai::utils::read_json_param; - size_t num_decoder_layers = 0; + int num_decoder_layers = 0; read_json_param(data, "num_hidden_layers", num_decoder_layers); OPENVINO_ASSERT(num_decoder_layers > 3, "num_decoder_layers is too small to deduce hidden layers for extraction"); eagle_rt_info.hidden_layers_list = { 2, num_decoder_layers / 2, num_decoder_layers - 3 }; From cfe77f84ef1d6940c96274a522260806c159e3b0 Mon Sep 17 00:00:00 2001 From: fishbell Date: Tue, 14 Oct 2025 00:34:26 +0800 Subject: [PATCH 17/43] enable test Signed-off-by: fishbell --- tests/python_tests/samples/conftest.py | 8 ++++ .../samples/test_speculative_decoding_lm.py | 40 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/tests/python_tests/samples/conftest.py b/tests/python_tests/samples/conftest.py index 454d7a1cb7..70b76b4665 100644 --- a/tests/python_tests/samples/conftest.py +++ b/tests/python_tests/samples/conftest.py @@ -143,6 +143,14 @@ "tiny-random-SpeechT5ForTextToSpeech": { "name": "hf-internal-testing/tiny-random-SpeechT5ForTextToSpeech", "convert_args": ["--model-kwargs", json.dumps({"vocoder": "fxmarty/speecht5-hifigan-tiny"})] + }, + "EAGLE-LLaMA3.1-Instruct-8B": { + "name": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", + "convert_args": ['--trust-remote-code', "--eagle3"] + }, + "qwen3_8b_eagle3": { + "name": "Tengyunw/qwen3_8b_eagle3", + "convert_args": ['--trust-remote-code', "--eagle3"] } } diff --git a/tests/python_tests/samples/test_speculative_decoding_lm.py b/tests/python_tests/samples/test_speculative_decoding_lm.py index 35a1bb285a..fa5034ffe1 100644 --- a/tests/python_tests/samples/test_speculative_decoding_lm.py +++ b/tests/python_tests/samples/test_speculative_decoding_lm.py @@ -45,3 +45,43 @@ def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match" assert cpp_result_ref.stdout.strip() in cpp_result.stdout.strip(), "Greedy and speculative decoding results should match" +test_prompt = """Code: +def add(a, b): + return a + b + +Question: Can you please add 2 and 3 +A:""" +class TestEagle3SpeculativeDecodingLM: + @pytest.mark.llm + @pytest.mark.samples + @pytest.mark.parametrize( + "convert_model, convert_draft_model, sample_args", + [ + pytest.param("Qwen2-0.5B-Instruct", "EAGLE-LLaMA3.1-Instruct-8B", test_prompt), + pytest.param("Qwen2-0.5B-Instruct", "qwen3_8b_eagle3", test_prompt), + ], + indirect=["convert_model", "convert_draft_model"], + ) + def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model, sample_args): + if sys.platform == 'darwin': + pytest.xfail("Ticket 173586") + env = os.environ.copy() + env["OPENVINO_LOG_LEVEL"] = 
"0" + # Test CPP sample + cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'speculative_decoding_lm') + cpp_command =[cpp_sample, convert_model, convert_draft_model, sample_args] + cpp_result = run_sample(cpp_command, env=env) + + # Test Python sample + py_script = os.path.join(SAMPLES_PY_DIR, "text_generation/speculative_decoding_lm.py") + py_command = [sys.executable, py_script, convert_model, convert_draft_model, sample_args] + py_result = run_sample(py_command, env=env) + + # Greedy decoding + cpp_sample_ref = os.path.join(SAMPLES_CPP_DIR, 'greedy_causal_lm') + cpp_command_ref = [cpp_sample_ref, convert_model, sample_args] + cpp_result_ref = run_sample(cpp_command_ref, env=env) + + # Compare results + assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match" + assert cpp_result_ref.stdout.strip() in cpp_result.stdout.strip(), "Greedy and speculative decoding results should match" \ No newline at end of file From 2f5d080171da707b300a5538934fe3e9477eb278 Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 15 Oct 2025 00:21:36 +0800 Subject: [PATCH 18/43] skip eagle3 test for now Signed-off-by: fishbell --- .../_components/llm-models-table/models.ts | 2 + tests/python_tests/samples/conftest.py | 12 ++-- .../samples/test_speculative_decoding_lm.py | 3 +- .../python_tests/test_continuous_batching.py | 55 ++++++++++++++++++- 4 files changed, 62 insertions(+), 10 deletions(-) diff --git a/site/docs/supported-models/_components/llm-models-table/models.ts b/site/docs/supported-models/_components/llm-models-table/models.ts index 6208459f10..0676b188b7 100644 --- a/site/docs/supported-models/_components/llm-models-table/models.ts +++ b/site/docs/supported-models/_components/llm-models-table/models.ts @@ -482,6 +482,8 @@ export const LLM_MODELS: LLMModelType[] = [ 'https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct', 'https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B', 'https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B', + 'https://huggingface.co/yuhuili/EAGLE3-LLaMA3.1-Instruct-8B', + 'https://huggingface.co/Tengyunw/qwen3_8b_eagle3' ], }, { diff --git a/tests/python_tests/samples/conftest.py b/tests/python_tests/samples/conftest.py index 70b76b4665..bf5c3adbb6 100644 --- a/tests/python_tests/samples/conftest.py +++ b/tests/python_tests/samples/conftest.py @@ -144,13 +144,13 @@ "name": "hf-internal-testing/tiny-random-SpeechT5ForTextToSpeech", "convert_args": ["--model-kwargs", json.dumps({"vocoder": "fxmarty/speecht5-hifigan-tiny"})] }, - "EAGLE-LLaMA3.1-Instruct-8B": { - "name": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", - "convert_args": ['--trust-remote-code', "--eagle3"] + "Qwen3-1.7B": { + "name": "Qwen/Qwen3-1.7B", + "convert_args": ["--task", "text-generation-with-past", '--trust-remote-code'] }, - "qwen3_8b_eagle3": { - "name": "Tengyunw/qwen3_8b_eagle3", - "convert_args": ['--trust-remote-code', "--eagle3"] + "qwen3_1.7b_eagle3": { + "name": "AngelSlim/Qwen3-1.7B_eagle3", + "convert_args": ["--task", "text-generation-with-past", "--trust-remote-code", "--eagle3"] } } diff --git a/tests/python_tests/samples/test_speculative_decoding_lm.py b/tests/python_tests/samples/test_speculative_decoding_lm.py index fa5034ffe1..602745cc51 100644 --- a/tests/python_tests/samples/test_speculative_decoding_lm.py +++ b/tests/python_tests/samples/test_speculative_decoding_lm.py @@ -57,8 +57,7 @@ class TestEagle3SpeculativeDecodingLM: @pytest.mark.parametrize( "convert_model, convert_draft_model, sample_args", [ - 
pytest.param("Qwen2-0.5B-Instruct", "EAGLE-LLaMA3.1-Instruct-8B", test_prompt), - pytest.param("Qwen2-0.5B-Instruct", "qwen3_8b_eagle3", test_prompt), + pytest.param("Qwen3-1.7B", "qwen3_1.7b_eagle3", test_prompt, marks=pytest.mark.skip(reason = 'CVS-171947, CVS-171943')), ], indirect=["convert_model", "convert_draft_model"], ) diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index a2a9efc1a2..83a54bf645 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -475,9 +475,12 @@ def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_c return pipe, prompt, generation_config -def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType): +def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType, draft_model_id: str = None): _, _, model_path = download_and_convert_model(model_id) - ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type) + draft_model_path = None + if draft_model_id: + _, _, draft_model_path = download_and_convert_model(draft_model_id) + ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type, draft_model_path = draft_model_path) return ov_pipe.generate([prompt], generation_config).extended_perf_metrics @@ -528,3 +531,51 @@ def test_speculative_decoding_extended_perf_metrics(pipeline_type): assert std_gen_duration == 0 else: assert extended_perf_metrics is None + +@pytest.mark.parametrize("pipeline_type", [PipelineType.SPECULATIVE_DECODING]) +@pytest.mark.precommit +@pytest.mark.skip(reason="CVS-171943 enable model conversion for eagle3 and enable the test") +def test_eagle3_decoding_extended_perf_metrics(pipeline_type): + import time + extended_perf_metrics = None + start_time = time.perf_counter() + model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + draft_model_id : str = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + extended_perf_metrics = run_extended_perf_metrics_collection(model_id, draft_model_id, generation_config, "Why is the Sun yellow?", pipeline_type) + total_time = (time.perf_counter() - start_time) * 1000 + + assert not extended_perf_metrics is None + assert not extended_perf_metrics.main_model_metrics is None + assert not extended_perf_metrics.draft_model_metrics is None + + assert extended_perf_metrics.get_num_accepted_tokens() > 0 + + num_generated_tokens_main = extended_perf_metrics.main_model_metrics.get_num_generated_tokens() + assert num_generated_tokens_main > 0 and num_generated_tokens_main <= generation_config.max_new_tokens + + # max num_generated_tokens for draft model will be reached if it will generate num_assistant_tokens at each step + # plus fist token, which was generated by main model + num_generated_tokens_draft = extended_perf_metrics.draft_model_metrics.get_num_generated_tokens() + assert num_generated_tokens_draft > 0 and num_generated_tokens_draft < ((generation_config.max_new_tokens - 1) * generation_config.num_assistant_tokens + 1) + + total_iteration_number_main = len(extended_perf_metrics.main_model_metrics.raw_metrics.m_durations) + assert total_iteration_number_main > 0 and total_iteration_number_main <= generation_config.max_new_tokens + + total_iteration_number_draft = 
len(extended_perf_metrics.draft_model_metrics.raw_metrics.m_durations) + assert total_iteration_number_draft > 0 and total_iteration_number_draft < ((generation_config.max_new_tokens - 1) * generation_config.num_assistant_tokens + 1) + + for model_metrics in [extended_perf_metrics.main_model_metrics, extended_perf_metrics.draft_model_metrics]: + mean_ttst, std_ttst = model_metrics.get_ttst() + assert (mean_ttst, std_ttst) == (model_metrics.get_ttst().mean, model_metrics.get_ttst().std) + assert mean_ttst > 0 and mean_ttst < model_metrics.get_ttft().mean + assert std_ttst == 0 + + mean_latency, std_latency = model_metrics.get_latency() + assert (mean_latency, std_latency) == (model_metrics.get_latency().mean, model_metrics.get_latency().std) + assert mean_latency > 0 and mean_latency < 1000.0 + + mean_gen_duration, std_gen_duration = model_metrics.get_generate_duration() + assert (mean_gen_duration, std_gen_duration) == (model_metrics.get_generate_duration().mean, model_metrics.get_generate_duration().std) + assert mean_gen_duration > 0 and mean_gen_duration < total_time + assert std_gen_duration == 0 \ No newline at end of file From ff9b50e01bbafb1f607ecaad0c860b2d0a1b5444 Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 15 Oct 2025 18:16:56 +0800 Subject: [PATCH 19/43] align default num assistant tokens Signed-off-by: fishbell --- .../speculative_decoding_impl.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 9b4a027981..179f9a8e82 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -788,11 +788,18 @@ std::vector ContinuousBatchingPipeline::EagleDecodingIm auto new_input_ids = input_ids[request_id]; OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); auto main_sampling_params = sampling_params[request_id]; - + OPENVINO_ASSERT( + main_sampling_params.assistant_confidence_threshold == 0.f, + "Eagle3 Speculative Decoding pipeline only supports `num_assistant_tokens` " + "as parameter in GenerationConfig and doesn't work with `assistant_confidence_threshold`.\nPlease " + "remove its specification or set it to 0.f."); + if (main_sampling_params.num_assistant_tokens == 0) { + main_sampling_params.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; + } main_generations.push_back( m_main_pipeline->add_request(request_id, new_input_ids, main_sampling_params)); - auto draft_sampling_params = sampling_params[request_id]; + auto draft_sampling_params = main_sampling_params; // set the parameters do not stop draft generation without stopping of the same request for main pipeline draft_sampling_params.ignore_eos = true; draft_sampling_params.stop_strings = {}; From 723e3f8d6ced30de68a9e4c6d3e2d28ef2a3a41a Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 15 Oct 2025 18:24:48 +0800 Subject: [PATCH 20/43] udpate test ticket Signed-off-by: fishbell --- tests/python_tests/samples/test_speculative_decoding_lm.py | 2 +- tests/python_tests/test_continuous_batching.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python_tests/samples/test_speculative_decoding_lm.py b/tests/python_tests/samples/test_speculative_decoding_lm.py index 602745cc51..09087b12e4 100644 --- a/tests/python_tests/samples/test_speculative_decoding_lm.py +++ 
b/tests/python_tests/samples/test_speculative_decoding_lm.py @@ -57,7 +57,7 @@ class TestEagle3SpeculativeDecodingLM: @pytest.mark.parametrize( "convert_model, convert_draft_model, sample_args", [ - pytest.param("Qwen3-1.7B", "qwen3_1.7b_eagle3", test_prompt, marks=pytest.mark.skip(reason = 'CVS-171947, CVS-171943')), + pytest.param("Qwen3-1.7B", "qwen3_1.7b_eagle3", test_prompt, marks=pytest.mark.skip(reason = 'CVS-171947, CVS-171943, CVS-174959')), ], indirect=["convert_model", "convert_draft_model"], ) diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index 83a54bf645..5665cd5eaf 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -534,13 +534,13 @@ def test_speculative_decoding_extended_perf_metrics(pipeline_type): @pytest.mark.parametrize("pipeline_type", [PipelineType.SPECULATIVE_DECODING]) @pytest.mark.precommit -@pytest.mark.skip(reason="CVS-171943 enable model conversion for eagle3 and enable the test") +@pytest.mark.skip(reason="CVS-174959 enable model conversion for eagle3 and enable the test") def test_eagle3_decoding_extended_perf_metrics(pipeline_type): import time extended_perf_metrics = None start_time = time.perf_counter() - model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - draft_model_id : str = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + model_id : str = "Qwen/Qwen3-1.7B" + draft_model_id : str = "AngelSlim/Qwen3-1.7B_eagle3" generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) extended_perf_metrics = run_extended_perf_metrics_collection(model_id, draft_model_id, generation_config, "Why is the Sun yellow?", pipeline_type) total_time = (time.perf_counter() - start_time) * 1000 From 8b72711ba0d367aab00c0dd0651eaabc971a7ce6 Mon Sep 17 00:00:00 2001 From: fishbell Date: Thu, 16 Oct 2025 17:41:38 +0800 Subject: [PATCH 21/43] apply review comment Signed-off-by: fishbell --- .../src/continuous_batching/model_runner.hpp | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index f7f4e602fa..77d9215324 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -51,7 +51,7 @@ class ModelRunner { bool m_is_hidden_state_export_needed = false; // need to export hidden state after inference bool m_is_hidden_state_import_needed = false; // need to import hidden state from another model runner - bool m_is_hidden_state_internal_needed = false; // need to use internal hidden state, e.g, eagle2 + bool m_is_hidden_state_internal_needed = false; // need to use internal hidden state, e.g, eagle SD std::map, std::pair> m_sequence_hidden_state_mapping; // pre-requisite: main/draft have same seq group and running seq grouped id // a container which use sequence group id and request id as key to store hidden states std::map m_initial_hidden_states; // shape: [N, seq_len, hidden_size] @@ -154,13 +154,13 @@ class ModelRunner { m_cache_rotation_deltas_for_each_layer = std::move(rotation_deltas_for_each_layer); } - ov::Tensor get_hidden_state(size_t request_id, size_t seq_grouped_id) const { + ov::Tensor get_hidden_state(uint64_t request_id, uint64_t seq_grouped_id) const { if (m_hidden_states.get_size() == 0) { return ov::Tensor(); } - auto key = std::make_pair(request_id, seq_grouped_id); - auto it = 
m_sequence_hidden_state_mapping.find(key); + const auto key = std::make_pair(request_id, seq_grouped_id); + const auto it = m_sequence_hidden_state_mapping.find(key); if (it == m_sequence_hidden_state_mapping.end()) { return ov::Tensor(); } @@ -173,8 +173,6 @@ class ModelRunner { return ov::Tensor(); } - size_t hidden_size = shape[shape.size() - 1]; - ov::Coordinate start_coord(shape.size(), 0); ov::Coordinate end_coord(shape.size(), 0); @@ -189,7 +187,7 @@ class ModelRunner { return ov::Tensor(m_hidden_states, start_coord, end_coord); } - void set_initial_hidden_state(size_t request_id, const ov::Tensor& hidden_state) { + void set_initial_hidden_state(uint64_t& request_id, const ov::Tensor& hidden_state) { m_initial_hidden_states[request_id] = hidden_state; } @@ -349,7 +347,6 @@ class ModelRunner { m_sequence_hidden_state_mapping[key] = std::make_pair(start_token_idx, sequence_length); } if (m_is_hidden_state_import_needed && hidden_state_data && hidden_size > 0) { - //auto key = std::make_pair(sequence_group->get_request_id(), sequence->get_grouped_id()); auto it = m_initial_hidden_states.find(sequence_group->get_request_id()); if (it != m_initial_hidden_states.end()) { @@ -370,7 +367,7 @@ class ModelRunner { size_t source_start_idx = stored_seq_len >= copy_length ? stored_seq_len - copy_length : 0; - copy_roi_between_tensors(stored_hidden_state, source_start_idx, copy_length, hidden_state_input, current_token_idx); + _copy_roi_between_tensors(stored_hidden_state, source_start_idx, copy_length, hidden_state_input, current_token_idx); } } } @@ -395,7 +392,7 @@ class ModelRunner { size_t src_start_idx = seq_len >= copy_length ? seq_len - copy_length : 0; auto target_shape = ov::Shape{num_scheduled_tokens, 1, hidden_size}; ov::Tensor target_base(ov::element::f32, target_shape, hidden_state_data + current_token_idx * hidden_size); - copy_roi_between_tensors(hidden_state, src_start_idx, copy_length, target_base, 0); + _copy_roi_between_tensors(hidden_state, src_start_idx, copy_length, target_base, 0); } } } @@ -663,7 +660,7 @@ class ModelRunner { // copy_length: number of elements along first dim to copy // dst_base: destination base tensor (may be full buffer or a wrapper around a raw pointer) // dst_first_dim_start: start index in first dimension of dst_base where copy should be placed - static void copy_roi_between_tensors(const ov::Tensor& src, + static void _copy_roi_between_tensors(const ov::Tensor& src, size_t src_start_idx, size_t copy_length, const ov::Tensor& dst_base, From b31411ad693283030c73b49fff495f54fa4aea72 Mon Sep 17 00:00:00 2001 From: fishbell Date: Thu, 16 Oct 2025 21:15:22 +0800 Subject: [PATCH 22/43] move eagle3 tests to seperate file Signed-off-by: fishbell --- .github/workflows/linux.yml | 2 +- .github/workflows/windows.yml | 2 +- .../python_tests/test_continuous_batching.py | 58 +---------- tests/python_tests/test_eagle3.py | 97 +++++++++++++++++++ 4 files changed, 102 insertions(+), 57 deletions(-) create mode 100644 tests/python_tests/test_eagle3.py diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 528a6825af..99145966b3 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -515,7 +515,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }} timeout: 360 - name: 'LLM & VLM' - cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py 
tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/' + cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_eagle3.py --override-ini cache_dir=/mount/caches/pytest/' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }} timeout: 180 - name: 'GGUF Reader tests' diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index ab4a8957a9..5d639af220 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -616,7 +616,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }} timeout: 360 - name: 'LLM & VLM' - cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/' + cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_eagle3.py --override-ini cache_dir=/mount/caches/pytest/' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }} timeout: 180 - name: 'GGUF Reader tests' diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index 5665cd5eaf..9dacdfde51 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -474,13 +474,9 @@ def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_c raise RuntimeError(f"{pipeline_type} is unknown pipeline type!") return pipe, prompt, generation_config - -def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType, draft_model_id: str = None): +def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType): _, _, model_path = download_and_convert_model(model_id) - draft_model_path = None - if draft_model_id: - _, _, draft_model_path = download_and_convert_model(draft_model_id) - ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type, draft_model_path = draft_model_path) + ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type) return ov_pipe.generate([prompt], generation_config).extended_perf_metrics @@ -530,52 +526,4 @@ def test_speculative_decoding_extended_perf_metrics(pipeline_type): assert mean_gen_duration > 0 and mean_gen_duration < total_time assert std_gen_duration == 0 else: - assert extended_perf_metrics is None - -@pytest.mark.parametrize("pipeline_type", [PipelineType.SPECULATIVE_DECODING]) -@pytest.mark.precommit -@pytest.mark.skip(reason="CVS-174959 enable model conversion for eagle3 and enable the test") -def test_eagle3_decoding_extended_perf_metrics(pipeline_type): - import time - extended_perf_metrics = None - start_time = time.perf_counter() - model_id : str = "Qwen/Qwen3-1.7B" - draft_model_id : str = "AngelSlim/Qwen3-1.7B_eagle3" - generation_config = 
GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) - extended_perf_metrics = run_extended_perf_metrics_collection(model_id, draft_model_id, generation_config, "Why is the Sun yellow?", pipeline_type) - total_time = (time.perf_counter() - start_time) * 1000 - - assert not extended_perf_metrics is None - assert not extended_perf_metrics.main_model_metrics is None - assert not extended_perf_metrics.draft_model_metrics is None - - assert extended_perf_metrics.get_num_accepted_tokens() > 0 - - num_generated_tokens_main = extended_perf_metrics.main_model_metrics.get_num_generated_tokens() - assert num_generated_tokens_main > 0 and num_generated_tokens_main <= generation_config.max_new_tokens - - # max num_generated_tokens for draft model will be reached if it will generate num_assistant_tokens at each step - # plus fist token, which was generated by main model - num_generated_tokens_draft = extended_perf_metrics.draft_model_metrics.get_num_generated_tokens() - assert num_generated_tokens_draft > 0 and num_generated_tokens_draft < ((generation_config.max_new_tokens - 1) * generation_config.num_assistant_tokens + 1) - - total_iteration_number_main = len(extended_perf_metrics.main_model_metrics.raw_metrics.m_durations) - assert total_iteration_number_main > 0 and total_iteration_number_main <= generation_config.max_new_tokens - - total_iteration_number_draft = len(extended_perf_metrics.draft_model_metrics.raw_metrics.m_durations) - assert total_iteration_number_draft > 0 and total_iteration_number_draft < ((generation_config.max_new_tokens - 1) * generation_config.num_assistant_tokens + 1) - - for model_metrics in [extended_perf_metrics.main_model_metrics, extended_perf_metrics.draft_model_metrics]: - mean_ttst, std_ttst = model_metrics.get_ttst() - assert (mean_ttst, std_ttst) == (model_metrics.get_ttst().mean, model_metrics.get_ttst().std) - assert mean_ttst > 0 and mean_ttst < model_metrics.get_ttft().mean - assert std_ttst == 0 - - mean_latency, std_latency = model_metrics.get_latency() - assert (mean_latency, std_latency) == (model_metrics.get_latency().mean, model_metrics.get_latency().std) - assert mean_latency > 0 and mean_latency < 1000.0 - - mean_gen_duration, std_gen_duration = model_metrics.get_generate_duration() - assert (mean_gen_duration, std_gen_duration) == (model_metrics.get_generate_duration().mean, model_metrics.get_generate_duration().std) - assert mean_gen_duration > 0 and mean_gen_duration < total_time - assert std_gen_duration == 0 \ No newline at end of file + assert extended_perf_metrics is None \ No newline at end of file diff --git a/tests/python_tests/test_eagle3.py b/tests/python_tests/test_eagle3.py new file mode 100644 index 0000000000..2c525a66f6 --- /dev/null +++ b/tests/python_tests/test_eagle3.py @@ -0,0 +1,97 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import pytest +import openvino_genai as ov_genai +from utils.hugging_face import download_and_convert_model, run_hugging_face +from utils.comparation import compare_generation_results +from utils.ov_genai_pipelines import create_ov_pipeline, PipelineType, convert_decoded_results_to_generation_result + +eagle_models_and_input = [ + ("Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3", "Why is the Sun yellow?")] + +devices = [ + ('CPU', 'CPU'), + ('GPU', 'GPU') +] +@pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) +@pytest.mark.parametrize("main_device,draft_device", devices) +@pytest.mark.precommit 
+@pytest.mark.skip(reason="CVS-174959 enable model conversion for eagle3 and enable the test") +def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_device, prompt): + # Download and convert model: + main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) + print("finised") + __, __, draft_model_path = download_and_convert_model(draft_model) + + # Create OpenVINO GenAI pipeline: + + ov_pipe = create_ov_pipeline(main_model_path, pipeline_type = PipelineType.SPECULATIVE_DECODING, draft_model_path = draft_model_path) + + # Run reference HF model: + ov_generation_config = ov_genai.GenerationConfig(max_new_tokens=20) + main_hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + ref_gen_results = run_hugging_face(main_opt_model, main_hf_tokenizer, [prompt], ov_generation_config) + + # Run OpenVINO GenAI pipeline: + ov_decoded_results = ov_pipe.generate([prompt], ov_generation_config) + ov_gen_results = convert_decoded_results_to_generation_result(ov_decoded_results, 1, 1, False) + + del ov_pipe + + # Compare results: + compare_generation_results([prompt], ref_gen_results, ov_gen_results, ov_generation_config) + +@pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) +@pytest.mark.parametrize("main_device,draft_device", devices) +@pytest.mark.precommit +@pytest.mark.skip(reason="CVS-174959 enable model conversion for eagle3 and enable the test") +def test_eagle3_sd_extended_perf_metrics(main_model, main_device, draft_model, draft_device, prompt): + import time + extended_perf_metrics = None + start_time = time.perf_counter() + generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + # Download and convert model: + main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) + __, __, draft_model_path = download_and_convert_model(draft_model) + + # Create OpenVINO GenAI pipeline: + ov_pipe = create_ov_pipeline(main_model_path, pipeline_type = PipelineType.SPECULATIVE_DECODING, draft_model_path = draft_model_path) + extended_perf_metrics = ov_pipe.generate([prompt], generation_config).extended_perf_metrics + total_time = (time.perf_counter() - start_time) * 1000 + + assert not extended_perf_metrics is None + assert not extended_perf_metrics.main_model_metrics is None + assert not extended_perf_metrics.draft_model_metrics is None + + assert extended_perf_metrics.get_num_accepted_tokens() > 0 + + num_generated_tokens_main = extended_perf_metrics.main_model_metrics.get_num_generated_tokens() + assert num_generated_tokens_main > 0 and num_generated_tokens_main <= generation_config.max_new_tokens + + # max num_generated_tokens for draft model will be reached if it will generate num_assistant_tokens at each step + # plus fist token, which was generated by main model + num_generated_tokens_draft = extended_perf_metrics.draft_model_metrics.get_num_generated_tokens() + assert num_generated_tokens_draft > 0 and num_generated_tokens_draft < ((generation_config.max_new_tokens - 1) * generation_config.num_assistant_tokens + 1) + + total_iteration_number_main = len(extended_perf_metrics.main_model_metrics.raw_metrics.m_durations) + assert total_iteration_number_main > 0 and total_iteration_number_main <= generation_config.max_new_tokens + + total_iteration_number_draft = len(extended_perf_metrics.draft_model_metrics.raw_metrics.m_durations) + assert total_iteration_number_draft > 0 and total_iteration_number_draft < 
((generation_config.max_new_tokens - 1) * generation_config.num_assistant_tokens + 1) + + for model_metrics in [extended_perf_metrics.main_model_metrics, extended_perf_metrics.draft_model_metrics]: + mean_ttst, std_ttst = model_metrics.get_ttst() + assert (mean_ttst, std_ttst) == (model_metrics.get_ttst().mean, model_metrics.get_ttst().std) + assert mean_ttst > 0 and mean_ttst < model_metrics.get_ttft().mean + assert std_ttst == 0 + + mean_latency, std_latency = model_metrics.get_latency() + assert (mean_latency, std_latency) == (model_metrics.get_latency().mean, model_metrics.get_latency().std) + assert mean_latency > 0 and mean_latency < 1000.0 + + mean_gen_duration, std_gen_duration = model_metrics.get_generate_duration() + assert (mean_gen_duration, std_gen_duration) == (model_metrics.get_generate_duration().mean, model_metrics.get_generate_duration().std) + assert mean_gen_duration > 0 and mean_gen_duration < total_time + assert std_gen_duration == 0 \ No newline at end of file From 3a3ad17e0451d16a1e5d94e3d060f7d76a8b5cb6 Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 17 Oct 2025 00:35:38 +0800 Subject: [PATCH 23/43] refine hs state management Signed-off-by: fishbell --- .../src/continuous_batching/model_runner.hpp | 160 ++++++++++-------- src/cpp/src/sampling/sampler.hpp | 2 +- ...batching_for_speculative_decoding_impl.cpp | 4 +- ...batching_for_speculative_decoding_impl.hpp | 6 +- 4 files changed, 92 insertions(+), 80 deletions(-) diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index 77d9215324..dc271778b3 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -24,6 +24,13 @@ inline std::string get_paged_attention_score_output_for_decoder_layer(size_t dec return ss.str(); } +enum HiddenStateFlags : uint8_t { + HS_NONE = 0, + HS_EXPORT = 1 << 0, + HS_IMPORT = 1 << 1, + HS_INTERNAL = 1 << 2 +}; + /** * @brief Runs the LLM infer request, parsing the continuous batching scheduler output into proper inputs in terms of OV API (e.g. token input IDs, * KV cache block indices etc.) and returning the logit scores for the next token to be generated for each of the currently scheduled sequences. @@ -48,10 +55,7 @@ class ModelRunner { // Input shape: [N, conversation length]. // Output shape: [1, conversation length, hidden_size]. EmbeddingsModel::Ptr m_embedding; - - bool m_is_hidden_state_export_needed = false; // need to export hidden state after inference - bool m_is_hidden_state_import_needed = false; // need to import hidden state from another model runner - bool m_is_hidden_state_internal_needed = false; // need to use internal hidden state, e.g, eagle SD + uint8_t m_hidden_state_flags = HS_NONE; std::map, std::pair> m_sequence_hidden_state_mapping; // pre-requisite: main/draft have same seq group and running seq grouped id // a container which use sequence group id and request id as key to store hidden states std::map m_initial_hidden_states; // shape: [N, seq_len, hidden_size] @@ -112,17 +116,9 @@ class ModelRunner { return m_request; } - void set_hidden_state_export_needed(bool is_needed) { - m_is_hidden_state_export_needed = is_needed; - } - - void set_hidden_state_import_needed(bool is_needed) { - m_is_hidden_state_import_needed = is_needed; - } - - void set_hidden_state_internal_needed(bool is_needed) { - m_is_hidden_state_internal_needed = is_needed; - } + void enable_hidden_state_export(bool on) { on ? 
m_hidden_state_flags |= HS_EXPORT : m_hidden_state_flags &= ~HS_EXPORT; } + void enable_hidden_state_import(bool on) { on ? m_hidden_state_flags |= HS_IMPORT : m_hidden_state_flags &= ~HS_IMPORT; } + void enable_hidden_state_internal(bool on) { on ? m_hidden_state_flags |= HS_INTERNAL : m_hidden_state_flags &= ~HS_INTERNAL; } void set_embedding_model(const EmbeddingsModel::Ptr& embedder) { m_embedding = embedder; @@ -154,39 +150,6 @@ class ModelRunner { m_cache_rotation_deltas_for_each_layer = std::move(rotation_deltas_for_each_layer); } - ov::Tensor get_hidden_state(uint64_t request_id, uint64_t seq_grouped_id) const { - if (m_hidden_states.get_size() == 0) { - return ov::Tensor(); - } - - const auto key = std::make_pair(request_id, seq_grouped_id); - const auto it = m_sequence_hidden_state_mapping.find(key); - if (it == m_sequence_hidden_state_mapping.end()) { - return ov::Tensor(); - } - - size_t start_idx = it->second.first; - size_t length = it->second.second; - - auto shape = m_hidden_states.get_shape(); - if (shape.size() < 2) { - return ov::Tensor(); - } - - ov::Coordinate start_coord(shape.size(), 0); - ov::Coordinate end_coord(shape.size(), 0); - - start_coord[0] = start_idx; - end_coord[0] = start_idx + length; - - for (size_t i = 1; i < shape.size(); ++i) { - start_coord[i] = 0; - end_coord[i] = shape[i]; - } - - return ov::Tensor(m_hidden_states, start_coord, end_coord); - } - void set_initial_hidden_state(uint64_t& request_id, const ov::Tensor& hidden_state) { m_initial_hidden_states[request_id] = hidden_state; } @@ -247,26 +210,9 @@ class ModelRunner { {batch_size_in_sequences}, ov::element::i32); ov::Tensor hidden_state_input; float* hidden_state_data = nullptr; - if (m_is_hidden_state_import_needed || m_is_hidden_state_internal_needed) { - if (hidden_size == 0) { - for (const auto& entry : m_initial_hidden_states) { - const auto& stored_hidden_state = entry.second; - if (stored_hidden_state.get_size() > 0) { - auto shape = stored_hidden_state.get_shape(); - if (shape.size() >= 2) { - hidden_size = shape[shape.size() - 1]; - if (!m_is_hidden_state_import_needed) - hidden_size /= m_adjust_factor; - break; - } - } - } - } - if (hidden_size > 0) { - hidden_state_input = ov::Tensor(ov::element::f32, {total_num_tokens, 1, hidden_size}); - hidden_state_data = hidden_state_input.data(); - std::memset(hidden_state_data, 0, total_num_tokens * hidden_size * sizeof(float)); - } + hidden_state_input = _prepare_hidden_state_input(total_num_tokens, hidden_size); + if (hidden_state_input) { + hidden_state_data = hidden_state_input.data(); } ov::Tensor generated_ids_embeds; @@ -339,14 +285,14 @@ class ModelRunner { output_seq_len = 0; Sequence::CPtr sequence = running_sequences[seq_idx]; - if (m_is_hidden_state_export_needed) { + if (_is_hs_export()) { size_t start_token_idx = current_token_idx; size_t sequence_length = num_scheduled_tokens; auto key = std::make_pair(sequence_group->get_request_id(), sequence->get_grouped_id()); m_sequence_hidden_state_mapping[key] = std::make_pair(start_token_idx, sequence_length); } - if (m_is_hidden_state_import_needed && hidden_state_data && hidden_size > 0) { + if (_is_hs_import()) { auto it = m_initial_hidden_states.find(sequence_group->get_request_id()); if (it != m_initial_hidden_states.end()) { @@ -375,7 +321,7 @@ class ModelRunner { OPENVINO_ASSERT(false, "missing hidden state from target model to eagle draft model"); } } - } else { + } else if (_is_hs_internal()) { // fill hidden_state_data with m_hidden_states if (hidden_state_data) { 
OPENVINO_ASSERT(num_scheduled_tokens == 1, "unexpected num_scheduled_tokens in speculative drafting stage in eagle3 mode"); @@ -497,7 +443,7 @@ class ModelRunner { } } if (hidden_state_input && hidden_state_input.get_size() > 0) { - if (m_is_hidden_state_import_needed) { + if (_is_hs_import()) { try { m_request.set_tensor("hidden_states", hidden_state_input); auto shape = hidden_state_input.get_shape(); @@ -574,7 +520,7 @@ class ModelRunner { _reset_cache_rotation_coefficients(); - if (m_is_hidden_state_export_needed) { + if (_is_hs_export()) { try { m_hidden_states = m_request.get_tensor("last_hidden_state"); for (size_t i = 0; i < num_sequence_groups; ++i) { @@ -584,7 +530,7 @@ class ModelRunner { for (size_t seq_idx = 0; seq_idx < running_sequences.size(); ++seq_idx) { Sequence::Ptr sequence = running_sequences[seq_idx]; sequence->update_hidden_state( - get_hidden_state(sequence_group->get_request_id(), sequence->get_grouped_id())); + _get_hidden_state(sequence_group->get_request_id(), sequence->get_grouped_id())); } } } catch (const ov::Exception&) { @@ -655,6 +601,72 @@ class ModelRunner { private: ov::Tensor m_hidden_states; + // Hidden state flags and helpers + bool _is_hs_export() const { return m_hidden_state_flags & HS_EXPORT; } + bool _is_hs_import() const { return m_hidden_state_flags & HS_IMPORT; } + bool _is_hs_internal() const { return m_hidden_state_flags & HS_INTERNAL; } + + ov::Tensor _get_hidden_state(uint64_t request_id, uint64_t seq_grouped_id) const { + if (m_hidden_states.get_size() == 0) { + return ov::Tensor(); + } + + const auto key = std::make_pair(request_id, seq_grouped_id); + const auto it = m_sequence_hidden_state_mapping.find(key); + if (it == m_sequence_hidden_state_mapping.end()) { + return ov::Tensor(); + } + + size_t start_idx = it->second.first; + size_t length = it->second.second; + + auto shape = m_hidden_states.get_shape(); + if (shape.size() < 2) { + return ov::Tensor(); + } + + ov::Coordinate start_coord(shape.size(), 0); + ov::Coordinate end_coord(shape.size(), 0); + + start_coord[0] = start_idx; + end_coord[0] = start_idx + length; + + for (size_t i = 1; i < shape.size(); ++i) { + start_coord[i] = 0; + end_coord[i] = shape[i]; + } + + return ov::Tensor(m_hidden_states, start_coord, end_coord); + } + + ov::Tensor _prepare_hidden_state_input(size_t total_num_tokens, + size_t& hidden_size /*in/out*/) { + if (!(m_hidden_state_flags & (HS_IMPORT | HS_INTERNAL))) { + return {}; + } + + if (hidden_size == 0) { + for (const auto& kv : m_initial_hidden_states) { + const auto& t = kv.second; + if (t && t.get_shape().size() >= 2) { + auto sh = t.get_shape(); + hidden_size = sh.back(); + if (!(m_hidden_state_flags & HS_IMPORT)) { + hidden_size /= m_adjust_factor; + } + break; + } + } + } + if (hidden_size == 0) { + return {}; + } + + ov::Tensor hs(ov::element::f32, {total_num_tokens, 1, hidden_size}); + std::memset(hs.data(), 0, hs.get_byte_size()); + return hs; + } + // Common helper to copy a contiguous slice (first-dim range) from src to dst using ROI tensors. 
// src_start_idx: start index along src first dimension // copy_length: number of elements along first dim to copy diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index 96cf497793..6ba02c3b06 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -99,7 +99,7 @@ class Sampler { Tokenizer m_tokenizer; ThreadPool m_thread_pool; - std::shared_ptr m_d2t; // Tensor to store d2t mapping for eagle model + std::shared_ptr m_d2t; // Tensor to store draft2target mapping for eagle model public: Sampler(const Sampler& rhs) = delete; Sampler(Sampler&& rhs) = delete; diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index 3dd2ef2d6d..a3a3169023 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -386,7 +386,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m } if (eagle_mode_enabled) - m_model_runner->set_hidden_state_import_needed(false); + m_model_runner->enable_hidden_state_import(false); to_generate = false; for (auto& request : m_requests) { const auto& sampling_params = request->get_sampling_parameters(); @@ -411,6 +411,6 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m } } if (eagle_mode_enabled) - m_model_runner->set_hidden_state_import_needed(true); + m_model_runner->enable_hidden_state_import(true); } } diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index b048eafe2f..a4db89a437 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp @@ -70,19 +70,19 @@ class ContinuousBatchingPipeline::ContinuousBatchingForEagleDecodingImpl } void set_hidden_state_export_needed(bool is_needed) { if (m_model_runner) { - m_model_runner->set_hidden_state_export_needed(is_needed); + m_model_runner->enable_hidden_state_export(is_needed); } } void set_hidden_state_import_needed(bool is_needed) { if (m_model_runner) { - m_model_runner->set_hidden_state_import_needed(is_needed); + m_model_runner->enable_hidden_state_import(is_needed); } } void set_hidden_state_internal_needed(bool is_needed) { if (m_model_runner) { - m_model_runner->set_hidden_state_internal_needed(is_needed); + m_model_runner->enable_hidden_state_internal(is_needed); } } From b3330455d89decc06a3ece2c2726fde9703a7aaa Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 17 Oct 2025 16:43:23 +0800 Subject: [PATCH 24/43] apply copilot comment Signed-off-by: fishbell --- src/cpp/src/sequence_group.hpp | 2 +- src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp | 2 +- tests/python_tests/test_eagle3.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index 5ae6767111..d9d04e20b7 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -140,7 +140,7 @@ class Sequence { m_hidden_state = tensor; } - ov::Tensor& get_hidden_state() { + ov::Tensor& get_hidden_state() const { return m_hidden_state; } diff --git 
a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 179f9a8e82..7869476118 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -413,7 +413,7 @@ void share_embedding_weights(std::shared_ptr& main_model, std::shared try { draft_weight_node->output(0).replace(main_weight_node->output(0)); } catch (...) { - std::cout << "fail to import embedding weights from main model to draft model" << std::endl; + std::cout << "failed to import embedding weights from main model to draft model" << std::endl; } } diff --git a/tests/python_tests/test_eagle3.py b/tests/python_tests/test_eagle3.py index 2c525a66f6..1c5bcff7d8 100644 --- a/tests/python_tests/test_eagle3.py +++ b/tests/python_tests/test_eagle3.py @@ -22,7 +22,6 @@ def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_device, prompt): # Download and convert model: main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) - print("finised") __, __, draft_model_path = download_and_convert_model(draft_model) # Create OpenVINO GenAI pipeline: From 62320a40ccdfcdec521c707410bab5b415b223f9 Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 17 Oct 2025 17:54:57 +0800 Subject: [PATCH 25/43] fix build failure Signed-off-by: fishbell --- src/cpp/src/sequence_group.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index d9d04e20b7..5680ada251 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -140,7 +140,7 @@ class Sequence { m_hidden_state = tensor; } - ov::Tensor& get_hidden_state() const { + ov::Tensor get_hidden_state() const { return m_hidden_state; } From 84b589e7ad7347300134f72248b9f1aad34a169c Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 17 Oct 2025 21:36:11 +0800 Subject: [PATCH 26/43] fallback unchanged file Signed-off-by: fishbell --- tests/python_tests/test_continuous_batching.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index 9dacdfde51..2c5f0c35dc 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -474,6 +474,7 @@ def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_c raise RuntimeError(f"{pipeline_type} is unknown pipeline type!") return pipe, prompt, generation_config + def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType): _, _, model_path = download_and_convert_model(model_id) ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type) From baa890b7fe799e232d932d5ac7fed97038f802a4 Mon Sep 17 00:00:00 2001 From: yanlan song Date: Mon, 20 Oct 2025 15:26:55 +0800 Subject: [PATCH 27/43] Update site/docs/supported-models/_components/llm-models-table/models.ts Co-authored-by: Roman Kazantsev --- .../supported-models/_components/llm-models-table/models.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/site/docs/supported-models/_components/llm-models-table/models.ts b/site/docs/supported-models/_components/llm-models-table/models.ts index 2fce43c9b2..57619b6f3c 100644 --- a/site/docs/supported-models/_components/llm-models-table/models.ts +++ b/site/docs/supported-models/_components/llm-models-table/models.ts @@ 
-493,8 +493,6 @@ export const LLM_MODELS: LLMModelType[] = [ 'https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct', 'https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B', 'https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B', - 'https://huggingface.co/yuhuili/EAGLE3-LLaMA3.1-Instruct-8B', - 'https://huggingface.co/Tengyunw/qwen3_8b_eagle3' ], }, { From c5e474b9f762772652d63ca32667a575b27d5496 Mon Sep 17 00:00:00 2001 From: fishbell Date: Mon, 20 Oct 2025 17:00:17 +0800 Subject: [PATCH 28/43] apply copilot Signed-off-by: fishbell --- .../speculative_decoding_lm.cpp | 28 +++++++++++++++---- .../src/continuous_batching/model_runner.hpp | 2 +- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index 093c62973f..803cc91c03 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -5,20 +5,25 @@ #include "openvino/genai/llm_pipeline.hpp" #include "openvino/genai/speculative_decoding/perf_metrics.hpp" +#include "read_prompt_from_file.h" int main(int argc, char* argv[]) try { if (4 != argc) { throw std::runtime_error(std::string{"Usage: "} + argv[0] + " ''"); } - + std::string prompt = argv[3]; + if (std::filesystem::is_regular_file(prompt)) { + std::string prompt_file = prompt; + prompt = utils::read_prompt(prompt_file); + } ov::genai::GenerationConfig config; - config.max_new_tokens = 100; + config.max_new_tokens = 129; // Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded. // Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration. // NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial // value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both // backends. - config.num_assistant_tokens = 4; + config.num_assistant_tokens = 5; // Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than // `assistant_confidence_threshold`. // NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend. @@ -26,13 +31,12 @@ int main(int argc, char* argv[]) try { std::string main_model_path = argv[1]; std::string draft_model_path = argv[2]; - std::string prompt = argv[3]; // User can run main and draft model on different devices. // Please, set device for main model in `LLMPipeline` constructor and in `ov::genai::draft_model` for draft. // CPU, GPU and NPU can be used. For NPU, the preferred configuration is when both the main and draft models // use NPU. - std::string main_device = "CPU", draft_device = "CPU"; + std::string main_device = "GPU", draft_device = "GPU"; ov::genai::LLMPipeline pipe( main_model_path, @@ -48,7 +52,19 @@ int main(int argc, char* argv[]) try { // be printed each time a new token is generated. 
auto result = pipe.generate(prompt, config, streamer); auto sd_perf_metrics = std::dynamic_pointer_cast(result.extended_perf_metrics); - + auto perf_metrics = result.perf_metrics; + std::cout << "\n\nPERF METRICS: " << std::endl; + auto generation_duration = perf_metrics.get_generate_duration().mean; + std::cout << " Generate time: " << generation_duration << " ms" << std::endl; + std::cout << " TTFT: " << perf_metrics.get_ttft().mean << " ± " << perf_metrics.get_ttft().std << " ms" + << std::endl; + std::cout << " TPOT: " << perf_metrics.get_tpot().mean << " ± " << perf_metrics.get_tpot().std << " ms/token" + << std::endl; + std::cout << " Num generated token: " << perf_metrics.get_num_generated_tokens() << " tokens" << std::endl; + std::cout << " Total iteration number: " << perf_metrics.raw_metrics.m_new_token_times.size() << std::endl; + if (perf_metrics.get_num_input_tokens() > 0) { + std::cout << " Input token size: " << perf_metrics.get_num_input_tokens() << std::endl; + } if (sd_perf_metrics) { auto main_model_metrics = sd_perf_metrics->main_model_metrics; std::cout << "\nMAIN MODEL " << std::endl; diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index dc271778b3..b4f2374a21 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -150,7 +150,7 @@ class ModelRunner { m_cache_rotation_deltas_for_each_layer = std::move(rotation_deltas_for_each_layer); } - void set_initial_hidden_state(uint64_t& request_id, const ov::Tensor& hidden_state) { + void set_initial_hidden_state(uint64_t request_id, const ov::Tensor& hidden_state) { m_initial_hidden_states[request_id] = hidden_state; } From 75ea28e90b88f170d7833cb79856dd7ef07b2573 Mon Sep 17 00:00:00 2001 From: fishbell Date: Mon, 20 Oct 2025 17:35:13 +0800 Subject: [PATCH 29/43] fix typo for samples Signed-off-by: fishbell --- .../speculative_decoding_lm.cpp | 28 ++++--------------- .../speculative_decoding_impl.cpp | 4 ++- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp index 803cc91c03..093c62973f 100644 --- a/samples/cpp/text_generation/speculative_decoding_lm.cpp +++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp @@ -5,25 +5,20 @@ #include "openvino/genai/llm_pipeline.hpp" #include "openvino/genai/speculative_decoding/perf_metrics.hpp" -#include "read_prompt_from_file.h" int main(int argc, char* argv[]) try { if (4 != argc) { throw std::runtime_error(std::string{"Usage: "} + argv[0] + " ''"); } - std::string prompt = argv[3]; - if (std::filesystem::is_regular_file(prompt)) { - std::string prompt_file = prompt; - prompt = utils::read_prompt(prompt_file); - } + ov::genai::GenerationConfig config; - config.max_new_tokens = 129; + config.max_new_tokens = 100; // Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded. // Add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration. // NOTE: ContinuousBatching backend uses `num_assistant_tokens` as is. Stateful backend uses `num_assistant_tokens`'s copy as initial // value and adjusts it based on recent number of accepted tokens. If `num_assistant_tokens` is not set, it defaults to `5` for both // backends. 
- config.num_assistant_tokens = 5; + config.num_assistant_tokens = 4; // Add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than // `assistant_confidence_threshold`. // NOTE: `assistant_confidence_threshold` is supported only by ContinuousBatching backend. @@ -31,12 +26,13 @@ int main(int argc, char* argv[]) try { std::string main_model_path = argv[1]; std::string draft_model_path = argv[2]; + std::string prompt = argv[3]; // User can run main and draft model on different devices. // Please, set device for main model in `LLMPipeline` constructor and in `ov::genai::draft_model` for draft. // CPU, GPU and NPU can be used. For NPU, the preferred configuration is when both the main and draft models // use NPU. - std::string main_device = "GPU", draft_device = "GPU"; + std::string main_device = "CPU", draft_device = "CPU"; ov::genai::LLMPipeline pipe( main_model_path, @@ -52,19 +48,7 @@ int main(int argc, char* argv[]) try { // be printed each time a new token is generated. auto result = pipe.generate(prompt, config, streamer); auto sd_perf_metrics = std::dynamic_pointer_cast(result.extended_perf_metrics); - auto perf_metrics = result.perf_metrics; - std::cout << "\n\nPERF METRICS: " << std::endl; - auto generation_duration = perf_metrics.get_generate_duration().mean; - std::cout << " Generate time: " << generation_duration << " ms" << std::endl; - std::cout << " TTFT: " << perf_metrics.get_ttft().mean << " ± " << perf_metrics.get_ttft().std << " ms" - << std::endl; - std::cout << " TPOT: " << perf_metrics.get_tpot().mean << " ± " << perf_metrics.get_tpot().std << " ms/token" - << std::endl; - std::cout << " Num generated token: " << perf_metrics.get_num_generated_tokens() << " tokens" << std::endl; - std::cout << " Total iteration number: " << perf_metrics.raw_metrics.m_new_token_times.size() << std::endl; - if (perf_metrics.get_num_input_tokens() > 0) { - std::cout << " Input token size: " << perf_metrics.get_num_input_tokens() << std::endl; - } + if (sd_perf_metrics) { auto main_model_metrics = sd_perf_metrics->main_model_metrics; std::cout << "\nMAIN MODEL " << std::endl; diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 7869476118..bd51267287 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -412,8 +412,10 @@ void share_embedding_weights(std::shared_ptr& main_model, std::shared try { draft_weight_node->output(0).replace(main_weight_node->output(0)); + } catch (const std::exception& e) { + std::cerr << "Error: failed to import embedding weights from main model to draft model. Exception: " << e.what() << std::endl; } catch (...) { - std::cout << "failed to import embedding weights from main model to draft model" << std::endl; + std::cerr << "Error: failed to import embedding weights from main model to draft model due to unknown exception." 
<< std::endl; } } From fdeb3f0cb6099c41b65f5df61b71ccb102c79560 Mon Sep 17 00:00:00 2001 From: fishbell Date: Tue, 21 Oct 2025 18:23:48 +0800 Subject: [PATCH 30/43] apply review comments part-1, mainly on code Signed-off-by: fishbell --- .../genai/continuous_batching_pipeline.hpp | 8 +- .../src/continuous_batching/model_runner.hpp | 12 +- src/cpp/src/continuous_batching/pipeline.cpp | 13 +- src/cpp/src/llm/pipeline.cpp | 1 - src/cpp/src/lora/adapter.cpp | 8 + src/cpp/src/safe_tensor_wrapper.cpp | 8 +- src/cpp/src/safe_tensor_wrapper.hpp | 8 +- src/cpp/src/sampling/sampler.cpp | 4 +- src/cpp/src/sampling/sampler.hpp | 4 +- src/cpp/src/sequence_group.hpp | 2 +- ...batching_for_speculative_decoding_impl.hpp | 6 +- .../speculative_decoding_eagle3_impl.cpp | 517 ++++++++++++++++++ .../speculative_decoding_eagle3_impl.hpp | 103 ++++ .../speculative_decoding_impl.cpp | 511 ----------------- .../speculative_decoding_impl.hpp | 96 +--- .../update_request_structs.hpp | 2 +- 16 files changed, 658 insertions(+), 645 deletions(-) create mode 100644 src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp create mode 100644 src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp diff --git a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp index 7ab19c07aa..61bab2b6b5 100644 --- a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp +++ b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp @@ -65,18 +65,18 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline { class ContinuousBatchingImpl; class ContinuousBatchingForSpeculativeDecodingImpl; - class ContinuousBatchingForEagleDecodingImpl; + class ContinuousBatchingForEagle3DecodingImpl; class ContinuousBatchingForPromptLookupImpl; class SpeculativeDecodingImpl; - class EagleDecodingImpl; + class Eagle3DecodingImpl; class PromptLookupImpl; friend class ContinuousBatchingForSpeculativeDecodingImpl; friend class ContinuousBatchingForPromptLookupImpl; - friend class ContinuousBatchingForEagleDecodingImpl; + friend class ContinuousBatchingForEagle3DecodingImpl; friend class SpeculativeDecodingImpl; - friend class EagleDecodingImpl; + friend class Eagle3DecodingImpl; friend class PromptLookupImpl; std::shared_ptr m_impl; diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index b4f2374a21..f89600c5be 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -208,9 +208,9 @@ class ModelRunner { {1, total_num_tokens}, ov::element::i64); ov::Tensor score_aggregation_window = _get_or_resize_tensor(m_cached_score_aggregation_window, "score_aggregation_window", {batch_size_in_sequences}, ov::element::i32); - ov::Tensor hidden_state_input; + + ov::Tensor hidden_state_input = _prepare_hidden_state_input(total_num_tokens, hidden_size); float* hidden_state_data = nullptr; - hidden_state_input = _prepare_hidden_state_input(total_num_tokens, hidden_size); if (hidden_state_input) { hidden_state_data = hidden_state_input.data(); } @@ -647,10 +647,10 @@ class ModelRunner { if (hidden_size == 0) { for (const auto& kv : m_initial_hidden_states) { - const auto& t = kv.second; - if (t && t.get_shape().size() >= 2) { - auto sh = t.get_shape(); - hidden_size = sh.back(); + const auto& initial_hidden_states = kv.second; + if (initial_hidden_states && initial_hidden_states.get_shape().size() >= 2) { + auto hidden_states_shape = 
initial_hidden_states.get_shape(); + hidden_size = hidden_states_shape.back(); if (!(m_hidden_state_flags & HS_IMPORT)) { hidden_size /= m_adjust_factor; } diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index 961be331ca..686771e5c0 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -11,6 +11,7 @@ #include "openvino/genai/tokenizer.hpp" #include "continuous_batching/pipeline_impl.hpp" #include "speculative_decoding/speculative_decoding_impl.hpp" +#include "speculative_decoding/speculative_decoding_eagle3_impl.hpp" #include "prompt_lookup/prompt_lookup_impl.hpp" #include "continuous_batching/timer.hpp" #include "utils.hpp" @@ -114,12 +115,12 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p // Use scheduler_config_copy in subsequent code if modification is needed } auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config_copy, generation_config); - m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); // parse d2t from safe tensors if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional - std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); } } } else if (draft_model_desr.model != nullptr) { @@ -173,12 +174,12 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( // Use scheduler_config_copy in subsequent code if modification is needed } auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config_copy, generation_config); - m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); // parse d2t from safe tensors if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional - std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); } } } else if (draft_model_desr.model != nullptr) { @@ -237,12 +238,12 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( // Use scheduler_config_copy in subsequent code if modification is needed } auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config_copy, generation_config); - m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); // parse d2t from safe tensors if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { ConstantMap constant_tensors = 
safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional - std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); + std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); } } } else if (draft_model_desr.model != nullptr) { diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index b1667a1628..b3bfddbc05 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -36,7 +36,6 @@ std::pair generation_config(const GenerationConfig& config) { inline void apply_eagle_rt_info(std::shared_ptr& model, ov::AnyMap& properties, const std::filesystem::path& mapping_path) { if (model->has_rt_info("eagle3_mode") && model->get_rt_info("eagle3_mode")) { - std::cout << "Eagle3 model is detected from rt_info. Applying eagle3_mode property." << std::endl; properties["eagle3_mode"] = true; if (model->has_rt_info("hidden_layers_list")) properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); diff --git a/src/cpp/src/lora/adapter.cpp b/src/cpp/src/lora/adapter.cpp index 1225912d46..fb4e427b2f 100644 --- a/src/cpp/src/lora/adapter.cpp +++ b/src/cpp/src/lora/adapter.cpp @@ -66,6 +66,14 @@ using ConstantVector = std::vector>; using LoRANode = LoRAParts>; using LoRAPartsParser = LoRAParts(const std::string& name)>>; +// Reads a file with a given filename expecting Safetensors file format. +// The file data is mmaped to tensor. +ConstantMap read_safetensors(const std::filesystem::path& filename) { + auto safetensor = ov::read_tensor_data(filename); + + return safetensor_to_constant_map(safetensor); +} + // Default LoRA tensor name patterns observed in the existing LoRA adapters, captures the prefix that should correspond // to a layer name in the base model LoRAPartsParser default_lora_patterns () { diff --git a/src/cpp/src/safe_tensor_wrapper.cpp b/src/cpp/src/safe_tensor_wrapper.cpp index c46adc863a..734a08a90e 100644 --- a/src/cpp/src/safe_tensor_wrapper.cpp +++ b/src/cpp/src/safe_tensor_wrapper.cpp @@ -35,15 +35,9 @@ ConstantMap safetensor_to_constant_map(const ov::Tensor& safetensor) { auto type = safetensors_to_ov_element_type(tensor.dtype); auto constant = - std::make_shared(type, shape, ptr, nullptr); // wraps existing memory, no ownership + std::make_shared(type, shape, ptr, nullptr); // wraps existing memory, no ownership constant->get_rt_info()["__safetensors_buffer_holder"] = safetensor; // to automatically deallocate underlying memory buffer when last constant that holds it is destroyed tensors[name] = constant; } return tensors; -} - -ConstantMap read_safetensors(const std::filesystem::path& filename) { - auto safetensor = ov::read_tensor_data(filename); - - return safetensor_to_constant_map(safetensor); } \ No newline at end of file diff --git a/src/cpp/src/safe_tensor_wrapper.hpp b/src/cpp/src/safe_tensor_wrapper.hpp index 074ce3f1dd..dbb5f2c292 100644 --- a/src/cpp/src/safe_tensor_wrapper.hpp +++ b/src/cpp/src/safe_tensor_wrapper.hpp @@ -2,11 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 #include "openvino/runtime/core.hpp" #include "openvino/op/constant.hpp" + extern "C" { #include "safetensors.h" } -using namespace ov::op; // Converts Safetensors element type to OV element type. Only part of the types are supported. 
ov::element::Type safetensors_to_ov_element_type (int dtype); @@ -24,8 +24,4 @@ struct AutoSafetensor: public safetensors_File { // The key in the map is a tensor name and the Constant uses a region of memory from the memory block. // Each Constant holds a shared pointer to the block in the runtime info. // The memory block will be deallocated when the last Constant is destroyed. -ConstantMap safetensor_to_constant_map(const ov::Tensor& safetensor); - -// Reads a file with a given filename expecting Safetensors file format. -// The file data is mmaped to tensor. -ConstantMap read_safetensors(const std::filesystem::path& filename); \ No newline at end of file +ConstantMap safetensor_to_constant_map(const ov::Tensor& safetensor); \ No newline at end of file diff --git a/src/cpp/src/sampling/sampler.cpp b/src/cpp/src/sampling/sampler.cpp index e8363fa792..2fd21bf0d0 100644 --- a/src/cpp/src/sampling/sampler.cpp +++ b/src/cpp/src/sampling/sampler.cpp @@ -853,8 +853,8 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr } } } - if (!is_validation_mode_enabled && m_d2t) { // compute token offset for draft model in speculative sampling - ov::Tensor d2t_tensor = m_d2t->get_tensor_view(); + if (!is_validation_mode_enabled && m_draft2target_mapping) { // compute token offset for draft model in speculative sampling + ov::Tensor d2t_tensor = m_draft2target_mapping->get_tensor_view(); auto d2t = d2t_tensor.data(); sampled_token.m_index = sampled_token.m_index + (d2t? d2t[sampled_token.m_index] : 0); } diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index 6ba02c3b06..c4def2f871 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -99,7 +99,7 @@ class Sampler { Tokenizer m_tokenizer; ThreadPool m_thread_pool; - std::shared_ptr m_d2t; // Tensor to store draft2target mapping for eagle model + std::shared_ptr m_draft2target_mapping; // Tensor to store draft2target mapping for eagle model public: Sampler(const Sampler& rhs) = delete; Sampler(Sampler&& rhs) = delete; @@ -128,7 +128,7 @@ class Sampler { void clear_structured_output_compile_times(); void set_d2t_for_decoding(std::shared_ptr& d2t) { - m_d2t = d2t; + m_draft2target_mapping = d2t; }; }; diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index 5680ada251..9888ba3382 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -136,7 +136,7 @@ class Sequence { m_generated_ids.push_back(token_id); } - void update_hidden_state(ov::Tensor tensor) { + void update_hidden_state(const ov::Tensor& tensor) { m_hidden_state = tensor; } diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index a4db89a437..5d6d220028 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp @@ -43,12 +43,12 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl : bool eagle_mode_enabled = false; }; -class ContinuousBatchingPipeline::ContinuousBatchingForEagleDecodingImpl +class ContinuousBatchingPipeline::ContinuousBatchingForEagle3DecodingImpl : public ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl { public: - ContinuousBatchingForEagleDecodingImpl() = default; + ContinuousBatchingForEagle3DecodingImpl() = 
default; - ContinuousBatchingForEagleDecodingImpl(const std::shared_ptr& model, + ContinuousBatchingForEagle3DecodingImpl(const std::shared_ptr& model, const Tokenizer& tokenizer, const GenerationConfig& generation_config, const SchedulerConfig& scheduler_config, diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp new file mode 100644 index 0000000000..2a6f26b134 --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp @@ -0,0 +1,517 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +#include "speculative_decoding_eagle3_impl.hpp" + +namespace ov::genai { +void share_embedding_weights(std::shared_ptr& main_model, std::shared_ptr& draft_model) { + // extract embedding weight from main model + auto find_embedding_gather = [](const std::shared_ptr& model) + -> std::shared_ptr { + for (const auto& node : model->get_ordered_ops()) { + auto gather = std::dynamic_pointer_cast(node); + if (!gather) continue; + // [vocab, hidden_size] * [batch, seq_len] -> [batch, seq_len, hidden_size] + auto data_node = gather->input_value(0).get_node_shared_ptr(); + auto indices_node = gather->input_value(1).get_node_shared_ptr(); + if (!data_node || !indices_node) continue; + // indices_node should be on parameter path, maybe this is better rule + ov::PartialShape ps = data_node->get_output_partial_shape(0); + if (ps.rank().is_static() && ps.rank().get_length() >= 2) { + if (ps[0].is_static() && ps[0].get_length() > 1000) { // Heuristic: vocab size > 1000 + return gather; + } + } + std::string fname = data_node->get_friendly_name(); + if (fname.find("embed_tokens") != std::string::npos || + fname.find("embedding") != std::string::npos) { + return gather; + } + } + return nullptr; + }; + auto main_gather = find_embedding_gather(main_model); + auto draft_gather = find_embedding_gather(draft_model); + if (!main_gather || !draft_gather) { + return; + } + auto main_weight_node = main_gather->input_value(0).get_node_shared_ptr(); + auto draft_weight_node = draft_gather->input_value(0).get_node_shared_ptr(); + + if (main_weight_node.get() == draft_weight_node.get()) { + return; + } + + try { + draft_weight_node->output(0).replace(main_weight_node->output(0)); + } catch (const std::exception& e) { + std::cerr << "Error: failed to import embedding weights from main model to draft model. Exception: " << e.what() << std::endl; + } catch (...) { + std::cerr << "Error: failed to import embedding weights from main model to draft model due to unknown exception." 
<< std::endl; + } +} + +void extract_hidden_state_generic(std::shared_ptr& model, + const std::vector& hidden_layers_to_abstract) { + ov::pass::Manager pm; + pm.register_pass(hidden_layers_to_abstract); + pm.run_passes(model); +} + +EagleModelTransform::EagleModelTransform(const std::vector& layers) : m_layer_ids(layers) { +} + +bool EagleModelTransform::run_on_model(const std::shared_ptr& model) { + // expose the hidden-state inputs/outputs required by the eagle3 main and draft models + m_new_parameters.clear(); + m_new_results.clear(); + if (m_layer_ids.size() == 1 && m_layer_ids[0] == -1) { + ov::pass::Manager manager; + manager.set_per_pass_validation(false); + manager.register_pass(m_new_results); + // input transform for draft + // here we apply a trick for the fc layer in draft model + manager.register_pass(m_new_parameters); + manager.run_passes(model); + + model->add_parameters(m_new_parameters); + model->add_results(m_new_results); + return true; + } else { + ov::pass::Manager manager; + manager.set_per_pass_validation(false); + manager.register_pass(m_layer_ids, m_hidden_layer_outputs); + manager.run_passes(model); + + if (!m_hidden_layer_outputs.empty()) { + auto concat = std::make_shared(m_hidden_layer_outputs, -1); + concat->set_friendly_name("eagle3_hidden_states_concat"); + + auto result = std::make_shared(concat); + std::string output_name = "last_hidden_state"; + result->output(0).set_names({output_name}); + result->set_friendly_name(output_name); + model->add_results({result}); + return true; + } + } + + return false; +} + +EagleInputTransform::EagleInputTransform(std::vector>& params) { + register_matcher( + std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), + ([&params, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + try { + if (apply(node, params)) { + ++applied; + return true; + } + } catch (...) { + OPENVINO_ASSERT(false, "EagleTransform failed to apply"); + } + return false; + }) + ); +} +bool EagleInputTransform::apply(NodePtr node, std::vector>& params) { + if (ov::is_type(node)) { + auto matmul_node = ov::as_type_ptr(node); + // check the input of the MatMul node; if it is the node named "hidden_states", then it's the node we want + auto input_node = matmul_node->get_input_node_shared_ptr(0); + if (!ov::as_type_ptr(input_node)) { + return false; + } + + auto internal_hidden_state = std::make_shared(node->get_element_type(), node->get_output_partial_shape(0)); + internal_hidden_state->output(0).set_names({"internal_hidden_states"}); + internal_hidden_state->set_friendly_name("internal_hidden_states"); + // create a new eltwise node that adds the MatMul output and the internal hidden state input from the last cycle + auto new_eltwise = std::make_shared(internal_hidden_state, matmul_node->output(0)); + ov::replace_node(matmul_node, new_eltwise); + params.push_back(internal_hidden_state); + return true; + } + return false; +} + +EagleBaseTransform::EagleBaseTransform(std::vector>& results) { + register_matcher( + std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), + ([&results, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + try { + if (apply(node, results)) { + ++applied; + return true; + } + } catch (...)
{ + OPENVINO_ASSERT(false, "EagleTransform failed to apply"); + } + return false; + }) + ); +} + +std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node, + std::set& visited_nodes) { + if (visited_nodes.count(start_node.get())) { + return nullptr; + } + + visited_nodes.insert(start_node.get()); + + if (ov::is_type(start_node)) { + // check the input nodes; if a Gather node is found among them, this is the last residual node, so return it + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + if (ov::as_type_ptr(input_node)) { + return start_node; // return the Add node itself + } + } + } + + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + + auto result = find_last_residual_node(input_node, visited_nodes); + if (result) { + return result; + } + } + return nullptr; +} + +std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node) { + std::set visited_nodes; + return find_last_residual_node(start_node, visited_nodes); +} + +bool EagleBaseTransform::apply(NodePtr node, std::vector>& results) { + { + // 1. without normalization layer 2. add extra input + if (ov::is_type(node)) { + // we are applying transformation to the last hidden state, eagle2 mode + NodePtr input_node = node->get_input_node_shared_ptr(0); + if (!input_node) { + return false; + } + auto last_residual_node = find_last_residual_node(input_node); + if (!last_residual_node) { + return false; + } + // + auto result = std::make_shared(last_residual_node); + std::string output_name = "last_hidden_state"; + result->output(0).set_names({output_name}); + result->set_friendly_name(output_name); + results.push_back(result); + return true; + } + return false; + } +} + +Eagle3Transform::Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs) : m_layers(layers) { + auto is_target_pattern = [&](const Output& output) { + auto add_node = ov::as_type_ptr(output.get_node_shared_ptr()); + auto add_node_name = add_node->get_friendly_name(); + if (add_node_name.find("self_attn") != std::string::npos) + return false; // Skip self-attention layers + bool layer_matched = false; + for (auto layer_idx : m_layers) { + if (add_node_name.find("layers."
+ std::to_string(layer_idx) + "/") != std::string::npos) { + layer_matched = true; + break; + } + } + + if (!layer_matched) { + return false; // Skip layers that are not in the specified layers + } + auto input0 = add_node->get_input_node_shared_ptr(1); + if (!input0 || !ov::is_type(input0)) { + return false; + } + auto matmul_node = input0; + auto matmul_input = matmul_node->get_input_node_shared_ptr(0); + if (!matmul_input) { + return false; + } + + bool has_multiply = ov::is_type(matmul_input); // ACT(up) dot gate + return has_multiply; + }; + + auto hidden_layer = ov::pass::pattern::wrap_type(is_target_pattern); + register_matcher(std::make_shared(hidden_layer, "Eagle3Transform::hidden_extraction"), + [&hidden_state_outputs, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + if (ov::is_type(node)) { + hidden_state_outputs.push_back(node->output(0)); + return true; + } + return false; + } + ); +} + +ContinuousBatchingPipeline::Eagle3DecodingImpl::Eagle3DecodingImpl(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc, + const std::vector& hidden_layers) + : m_hidden_layers_to_abstract(hidden_layers) { + auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); + auto main_model = main_model_desc.model; + auto draft_model = draft_model_desc.model; + + auto main_device = main_model_desc.device; + std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; + + ov::AnyMap draft_properties = + draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties; + + // main and draft model can have different tokenizers + // to do: support retokenization: 154103 + Tokenizer main_model_tokenizer = main_model_desc.tokenizer; + Tokenizer draft_model_tokenizer = draft_model_desc.tokenizer; + m_tokenizer = main_model_tokenizer; + // for eagle model, we need to obtain hidden layer state as extra output + // apply transformations needed to run eagle model + // target model: hidden state extraction, draft model: hidden state import , hidden state extraction + // eagle3 specific : dt importing + share_embedding_weights(main_model, draft_model); + extract_hidden_state_generic(main_model, hidden_layers); + extract_hidden_state_generic(draft_model, { -1 }); + + // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode + m_main_pipeline = std::make_shared(main_model, + main_model_tokenizer, + main_model_desc.generation_config, + scheduler_configs.first, + main_device, + main_model_desc.properties, + true); + m_draft_pipeline = std::make_shared(draft_model, + draft_model_tokenizer, + draft_model_desc.generation_config, + scheduler_configs.second, + draft_device, + draft_properties, + false); + m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + m_perf_metrics.raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; + m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; +} + +ov::Tensor ContinuousBatchingPipeline::Eagle3DecodingImpl::create_draft_input_ids(const ov::Tensor& original_input_ids) { + auto shape = original_input_ids.get_shape(); + if (shape.size() == 0 || shape.back() <= 1) { + return ov::Tensor(original_input_ids); + } + + size_t original_length = shape.back(); + size_t new_length = original_length - 1; + + ov::Tensor draft_input_ids(ov::element::i64, {1, new_length}); + + const int64_t* src_data = original_input_ids.data(); + int64_t* 
dst_data = draft_input_ids.data(); + + std::copy(src_data + 1, src_data + original_length, dst_data); + + return draft_input_ids; +} + +void ContinuousBatchingPipeline::Eagle3DecodingImpl::update_eagle_pipeline_params() { + auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); + auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); + m_main_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_import_needed(true); + m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); + m_draft_eagle_pipeline->set_adjust_factor(m_hidden_layers_to_abstract.size() > 0 ? m_hidden_layers_to_abstract.size() : 1); +} + +GenerationHandle +ContinuousBatchingPipeline::Eagle3DecodingImpl::add_request(uint64_t request_id, + const ov::Tensor& input_ids, + ov::genai::GenerationConfig sampling_params, + std::optional token_type_ids) { + std::lock_guard lock(m_draft_generations_mutex); + auto draft_sampling_params = sampling_params; + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + update_eagle_pipeline_params(); + // remove first token from input_ids to create draft_input_ids + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params, token_type_ids)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params, token_type_ids); +} + +GenerationHandle +ContinuousBatchingPipeline::Eagle3DecodingImpl::add_request(uint64_t request_id, + const std::string& prompt, + ov::genai::GenerationConfig sampling_params) { + std::lock_guard lock(m_draft_generations_mutex); + auto draft_sampling_params = sampling_params; + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + update_eagle_pipeline_params(); + // remove first token from input_ids to create draft_input_ids + if (m_model_input_type == ModelInputType::TOKENS) { + static ManualTimer timer("tokenize"); + timer.start(); + ChatHistory history({{{"role", "user"}, {"content", prompt}}}); + auto templated_prompt = m_tokenizer.apply_chat_template(history, true); + auto input_ids = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids; + timer.end(); + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params); + } else { + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, prompt, draft_sampling_params)}); + return m_main_pipeline->add_request(request_id, prompt, sampling_params); + } + +} + +std::vector ContinuousBatchingPipeline::Eagle3DecodingImpl::generate( + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids) { + m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + + OPENVINO_ASSERT(!has_non_finished_requests(), + "Generate cannot be called while ContinuousBatchingPipeline is already in running state. 
Use " + "ContinuousBatchingPipeline::add_request"); + + OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); + + ManualTimer generate_timer("speculative_decoding: generate()"); + generate_timer.start(); + + // checks that all requests has the same LoRA adapters property value + for (size_t i = 1; i < sampling_params.size(); ++i) { + OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, + "LoRA adapters value must be the same for all requests"); + } + auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); + auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); + + m_main_eagle_pipeline->set_adapters(sampling_params[0].adapters); + m_draft_eagle_pipeline->set_adapters(sampling_params[0].adapters); + update_eagle_pipeline_params(); + + const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); + + OPENVINO_ASSERT( + !streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || + (sampling_params[0].is_multinomial() && + sampling_params[0].num_return_sequences == 1)), + "Currently eagle streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); + + std::vector main_generations; + ov::Tensor new_input_ids; + for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { + auto new_input_ids = input_ids[request_id]; + OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); + auto main_sampling_params = sampling_params[request_id]; + OPENVINO_ASSERT( + main_sampling_params.assistant_confidence_threshold == 0.f, + "Eagle3 Speculative Decoding pipeline only supports `num_assistant_tokens` " + "as parameter in GenerationConfig and doesn't work with `assistant_confidence_threshold`.\nPlease " + "remove its specification or set it to 0.f."); + if (main_sampling_params.num_assistant_tokens == 0) { + main_sampling_params.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; + } + main_generations.push_back( + m_main_pipeline->add_request(request_id, new_input_ids, main_sampling_params)); + + auto draft_sampling_params = main_sampling_params; + // set the parameters do not stop draft generation without stopping of the same request for main pipeline + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + + // remove first token from input_ids to create draft_input_ids + ov::Tensor draft_input_ids = create_draft_input_ids(new_input_ids); + + std::lock_guard lock(m_draft_generations_mutex); + m_draft_generations.insert( + {request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); + } + auto all_requests = get_awaiting_requests(); + + GenerationHandle& generation = main_generations.at(0); + + streamer_ptr->start(); + + while (has_non_finished_requests()) { + try { + step(); + } catch (...) 
{ + drop_requests(); // remove all requests from pipeline state in case of exception + streamer_ptr->end(); + std::rethrow_exception(std::current_exception()); + } + stream_tokens(streamer_ptr, generation); + } + + // waiting for competion of streaming + streamer_ptr->end(); + + OPENVINO_ASSERT(is_requests_empty(), + "Internal error: current request is supposed to be dropped within step() function as completed"); + + std::vector results; + results.reserve(all_requests.size()); + + m_perf_metrics.draft_model_metrics.raw_metrics = m_draft_pipeline->raw_perf_metrics; + generate_timer.end(); + + for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { + const auto& request = all_requests[request_id]; + auto sampling_params = request->get_sampling_parameters(); + const auto& sequences = request->get_finished_sequences(); + size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size()); + + EncodedGenerationResult result; + result.m_request_id = request_id; + result.m_generation_ids.resize(num_outputs); + result.m_scores.resize(num_outputs); + result.m_status = request->get_generation_stream()->get_status(); + + for (size_t i = 0; i < num_outputs; ++i) { + const auto& sequence = sequences[i]; + const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) + : sequence->get_cumulative_log_prob(); + const auto& generated_ids = sequence->get_generated_ids(); + + if (sampling_params.echo) { + result.m_generation_ids[i] = request->get_prompt_ids(); + } + std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i])); + result.m_scores[i] = score; + } + + result.m_status = main_generations[request_id]->get_status(); + + // The same perf metrics for each sequence, only tokenization/detokenization will differ. 
+ m_perf_metrics.raw_metrics.generate_durations.clear(); + m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_timer.get_duration_microsec()); + m_perf_metrics.num_input_tokens = request->get_prompt_len(); + m_perf_metrics.evaluate_statistics(generate_timer.get_start_time()); + + result.perf_metrics = m_perf_metrics; + result.extended_perf_metrics = std::make_shared(m_perf_metrics); + results.push_back(std::move(result)); + } + + OPENVINO_ASSERT(results.size() == input_ids.size()); + return results; +} +} // namespace ov::genai \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp new file mode 100644 index 0000000000..4626fb767f --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp @@ -0,0 +1,103 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "speculative_decoding_impl.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/result.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/manager.hpp" + +namespace ov::genai { +class ContinuousBatchingPipeline::Eagle3DecodingImpl : public ContinuousBatchingPipeline::SpeculativeDecodingImpl { +public: + Eagle3DecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc, const std::vector& hidden_layers_to_abstract); + + std::vector + generate(const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids = std::nullopt) override; + + GenerationHandle add_request(uint64_t request_id, + const ov::Tensor& input_ids, + ov::genai::GenerationConfig sampling_params, + std::optional token_type_ids = std::nullopt) override; + + GenerationHandle add_request(uint64_t request_id, + const std::string& prompt, + ov::genai::GenerationConfig sampling_params) override; + + void set_d2t_for_draft_decoding(std::shared_ptr& d2t_tensor) { + auto eagle_impl = std::static_pointer_cast(m_draft_pipeline); + eagle_impl->set_d2t_for_draft_decoding(d2t_tensor); + }; +protected: + void update_eagle_pipeline_params(); + ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); + std::vector m_hidden_layers_to_abstract; +}; + +using NodePtr = std::shared_ptr; +using namespace ov::op; + +class EagleBaseTransform : public ov::pass::MatcherPass { +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("EagleBaseTransform"); + EagleBaseTransform(std::vector>& results); + + ~EagleBaseTransform() = default; + +private: + bool apply(NodePtr node, std::vector>& results); + size_t applied = 0; + std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node); + std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node, + std::set& visited_nodes); +}; +class EagleInputTransform : public ov::pass::MatcherPass { // eagle3 specific for draft model +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("EagleInputTransform"); + EagleInputTransform(std::vector>& params); + + ~EagleInputTransform() = default; + +private: + bool 
apply(NodePtr node, std::vector>& params); + size_t applied = 0; +}; +class Eagle3Transform : public ov::pass::MatcherPass { +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("Eagle3Transform"); + Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs); + + ~Eagle3Transform() = default; + +private: + std::vector m_layers; // layers to be abstracted +}; + +class EagleModelTransform : public ov::pass::ModelPass { +public: + EagleModelTransform(const std::vector& layer_ids); + bool run_on_model(const std::shared_ptr& model) override; + +private: + const std::vector m_layer_ids; + std::vector> m_new_results; + std::vector> m_new_parameters; + std::vector> m_hidden_layer_outputs; +}; +} diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index bd51267287..955cffe3eb 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -372,515 +372,4 @@ std::vector ContinuousBatchingPipeline::SpeculativeDecodingI return main_awaiting_requests; } -void share_embedding_weights(std::shared_ptr& main_model, std::shared_ptr& draft_model) { - // extract embedding weight from main model - auto find_embedding_gather = [](const std::shared_ptr& model) - -> std::shared_ptr { - for (const auto& node : model->get_ordered_ops()) { - auto gather = std::dynamic_pointer_cast(node); - if (!gather) continue; - // [vocab, hidden_size] * [batch, seq_len] -> [batch, seq_len, hidden_size] - auto data_node = gather->input_value(0).get_node_shared_ptr(); - auto indices_node = gather->input_value(1).get_node_shared_ptr(); - if (!data_node || !indices_node) continue; - // indices_node should be on parameter path, maybe this is better rule - ov::PartialShape ps = data_node->get_output_partial_shape(0); - if (ps.rank().is_static() && ps.rank().get_length() >= 2) { - if (ps[0].is_static() && ps[0].get_length() > 1000) { // Heuristic: vocab size > 1000 - return gather; - } - } - std::string fname = data_node->get_friendly_name(); - if (fname.find("embed_tokens") != std::string::npos || - fname.find("embedding") != std::string::npos) { - return gather; - } - } - return nullptr; - }; - auto main_gather = find_embedding_gather(main_model); - auto draft_gather = find_embedding_gather(draft_model); - if (!main_gather || !draft_gather) { - return; - } - auto main_weight_node = main_gather->input_value(0).get_node_shared_ptr(); - auto draft_weight_node = draft_gather->input_value(0).get_node_shared_ptr(); - - if (main_weight_node.get() == draft_weight_node.get()) { - return; - } - - try { - draft_weight_node->output(0).replace(main_weight_node->output(0)); - } catch (const std::exception& e) { - std::cerr << "Error: failed to import embedding weights from main model to draft model. Exception: " << e.what() << std::endl; - } catch (...) { - std::cerr << "Error: failed to import embedding weights from main model to draft model due to unknown exception." 
<< std::endl; - } -} - -void extract_hidden_state_generic(std::shared_ptr& model, - const std::vector& hidden_layers_to_abstract) { - ov::pass::Manager pm; - pm.register_pass(hidden_layers_to_abstract); - pm.run_passes(model); -} - -EagleModelTransform::EagleModelTransform(const std::vector& layers) : m_layer_ids(layers) { -} - -bool EagleModelTransform::run_on_model(const std::shared_ptr& model) { - // share the embedding weights from main model to draft model - m_new_parameters.clear(); - m_new_results.clear(); - if (m_layer_ids.size() == 1 && m_layer_ids[0] == -1) { - ov::pass::Manager manager; - manager.set_per_pass_validation(false); - manager.register_pass(m_new_results); - // input transform for draft - // here we apply a trick for the fc layer in draft model - manager.register_pass(m_new_parameters); - manager.run_passes(model); - - model->add_parameters(m_new_parameters); - model->add_results(m_new_results); - return true; - } else { - ov::pass::Manager manager; - manager.set_per_pass_validation(false); - manager.register_pass(m_layer_ids, m_hidden_layer_outputs); - manager.run_passes(model); - - if (!m_hidden_layer_outputs.empty()) { - auto concat = std::make_shared(m_hidden_layer_outputs, -1); - concat->set_friendly_name("eagle3_hidden_states_concat"); - - auto result = std::make_shared(concat); - std::string output_name = "last_hidden_state"; - result->output(0).set_names({output_name}); - result->set_friendly_name(output_name); - model->add_results({result}); - return true; - } - } - - return false; -} - -EagleInputTransform::EagleInputTransform(std::vector>& params) { - register_matcher( - std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), - ([¶ms, this](ov::pass::pattern::Matcher& m) { - auto node = m.get_match_root(); - try { - if (apply(node, params)) { - ++applied; - return true; - } - } catch (...) { - OPENVINO_ASSERT(false, "EagleTransform failed to apply"); - } - return false; - }) - ); -} -bool EagleInputTransform::apply(NodePtr node, std::vector>& params) { - if (ov::is_type(node)) { - auto matmul_node = ov::as_type_ptr(node); - // check the input of matmul node, if it is a node with name "hidden_states", then it's the node we want - auto input_node = matmul_node->get_input_node_shared_ptr(0); - if (!ov::as_type_ptr(input_node)) { - return false; - } - - auto shape = node->get_output_partial_shape(0); - auto internal_hidden_state = std::make_shared(node->get_element_type(), node->get_output_partial_shape(0)); - internal_hidden_state->output(0).set_names({"internal_hidden_states"}); - internal_hidden_state->set_friendly_name("internal_hidden_states"); - // create new eltwise node to add output of MatMul node and internal hidden state input from last cycle of itself - auto new_eltwise = std::make_shared(internal_hidden_state, matmul_node->output(0)); - ov::replace_node(matmul_node, new_eltwise); - params.push_back(internal_hidden_state); - return true; - } -} - -EagleBaseTransform::EagleBaseTransform(std::vector>& results) { - register_matcher( - std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), - ([&results, this](ov::pass::pattern::Matcher& m) { - auto node = m.get_match_root(); - try { - if (apply(node, results)) { - ++applied; - return true; - } - } catch (...) 
{ - OPENVINO_ASSERT(false, "EagleTransform failed to apply"); - } - return false; - }) - ); -} - -std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node, - std::set& visited_nodes) { - if (visited_nodes.count(start_node.get())) { - return nullptr; - } - - visited_nodes.insert(start_node.get()); - - if (ov::is_type(start_node)) { - // check the input nodes of MatMul, if found Gather node, return the gather node, otherwise ,retrun the matmul node - for (size_t i = 0; i < start_node->get_input_size(); ++i) { - auto input_node = start_node->get_input_node_shared_ptr(i); - if (!input_node) continue; - if (ov::as_type_ptr(input_node)) { - return start_node; // return the Add node itself - } - } - } - - for (size_t i = 0; i < start_node->get_input_size(); ++i) { - auto input_node = start_node->get_input_node_shared_ptr(i); - if (!input_node) continue; - - auto result = find_last_residual_node(input_node, visited_nodes); - if (result) { - return result; - } - } - return nullptr; -} - -std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node) { - std::set visited_nodes; - return find_last_residual_node(start_node, visited_nodes); -} - -bool EagleBaseTransform::apply(NodePtr node, std::vector>& results) { - { - // 1. without normalization layer 2. add extra input - if (ov::is_type(node)) { - // we are applying transformation to the last hidden state, eagle2 mode - NodePtr input_node = node->get_input_node_shared_ptr(0); - if (!input_node) { - return false; - } - auto last_residual_node = find_last_residual_node(input_node); - if (!last_residual_node) { - return false; - } - // - auto result = std::make_shared(last_residual_node); - std::string output_name = "last_hidden_state"; - result->output(0).set_names({output_name}); - result->set_friendly_name(output_name); - results.push_back(result); - return true; - } - return false; - } -} - -Eagle3Transform::Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs) : m_layers(layers) { - auto is_target_pattern = [&](const Output& output) { - auto add_node = ov::as_type_ptr(output.get_node_shared_ptr()); - auto add_node_name = add_node->get_friendly_name(); - if (add_node_name.find("self_attn") != std::string::npos) - return false; // Skip self-attention layers - bool layer_matched = false; - for (auto layer_idx : m_layers) { - if (add_node_name.find("layers." 
+ std::to_string(layer_idx) + "/") != std::string::npos) { - layer_matched = true; - break; - } - } - - if (!layer_matched) { - return false; // Skip layers that are not in the specified layers - } - auto input0 = add_node->get_input_node_shared_ptr(1); - if (!input0 || !ov::is_type(input0)) { - return false; - } - auto matmul_node = input0; - auto matmul_input = matmul_node->get_input_node_shared_ptr(0); - if (!matmul_input) { - return false; - } - - bool has_multiply = ov::is_type(matmul_input); // ACT(up) dot gate - return has_multiply; - }; - - auto hidden_layer = ov::pass::pattern::wrap_type(is_target_pattern); - register_matcher(std::make_shared(hidden_layer, "Eagle3Transform::hidden_extraction"), - [&hidden_state_outputs, this](ov::pass::pattern::Matcher& m) { - auto node = m.get_match_root(); - if (ov::is_type(node)) { - hidden_state_outputs.push_back(node->output(0)); - return true; - } - return false; - } - ); -} - -ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, - const ov::genai::ModelDesc& draft_model_desc, - const std::vector& hidden_layers) - : m_hidden_layers_to_abstract(hidden_layers) { - auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); - auto main_model = main_model_desc.model; - auto draft_model = draft_model_desc.model; - - auto main_device = main_model_desc.device; - std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; - - ov::AnyMap draft_properties = - draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties; - - // main and draft model can have different tokenizers - // to do: support retokenization: 154103 - Tokenizer main_model_tokenizer = main_model_desc.tokenizer; - Tokenizer draft_model_tokenizer = draft_model_desc.tokenizer; - m_tokenizer = main_model_tokenizer; - // for eagle model, we need to obtain hidden layer state as extra output - // apply transformations needed to run eagle model - // target model: hidden state extraction, draft model: hidden state import , hidden state extraction - // eagle3 specific : dt importing - share_embedding_weights(main_model, draft_model); - extract_hidden_state_generic(main_model, hidden_layers); - extract_hidden_state_generic(draft_model, { -1 }); - - // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode - m_main_pipeline = std::make_shared(main_model, - main_model_tokenizer, - main_model_desc.generation_config, - scheduler_configs.first, - main_device, - main_model_desc.properties, - true); - m_draft_pipeline = std::make_shared(draft_model, - draft_model_tokenizer, - draft_model_desc.generation_config, - scheduler_configs.second, - draft_device, - draft_properties, - false); - m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); - m_perf_metrics.raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; - m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; -} - -ov::Tensor ContinuousBatchingPipeline::EagleDecodingImpl::create_draft_input_ids(const ov::Tensor& original_input_ids) { - auto shape = original_input_ids.get_shape(); - if (shape.size() == 0 || shape.back() <= 1) { - return ov::Tensor(original_input_ids); - } - - size_t original_length = shape.back(); - size_t new_length = original_length - 1; - - ov::Tensor draft_input_ids(ov::element::i64, {1, new_length}); - - const int64_t* src_data = original_input_ids.data(); - int64_t* dst_data 
= draft_input_ids.data(); - - std::copy(src_data + 1, src_data + original_length, dst_data); - - return draft_input_ids; -} - -void ContinuousBatchingPipeline::EagleDecodingImpl::update_eagle_pipeline_params() { - auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); - auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); - m_main_eagle_pipeline->set_hidden_state_export_needed(true); - m_draft_eagle_pipeline->set_hidden_state_export_needed(true); - m_draft_eagle_pipeline->set_hidden_state_import_needed(true); - m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); - m_draft_eagle_pipeline->set_adjust_factor(m_hidden_layers_to_abstract.size() > 0 ? m_hidden_layers_to_abstract.size() : 1); -} - -GenerationHandle -ContinuousBatchingPipeline::EagleDecodingImpl::add_request(uint64_t request_id, - const ov::Tensor& input_ids, - ov::genai::GenerationConfig sampling_params, - std::optional token_type_ids) { - std::lock_guard lock(m_draft_generations_mutex); - auto draft_sampling_params = sampling_params; - draft_sampling_params.ignore_eos = true; - draft_sampling_params.stop_strings = {}; - update_eagle_pipeline_params(); - // remove first token from input_ids to create draft_input_ids - ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params, token_type_ids)}); - return m_main_pipeline->add_request(request_id, input_ids, sampling_params, token_type_ids); -} - -GenerationHandle -ContinuousBatchingPipeline::EagleDecodingImpl::add_request(uint64_t request_id, - const std::string& prompt, - ov::genai::GenerationConfig sampling_params) { - std::lock_guard lock(m_draft_generations_mutex); - auto draft_sampling_params = sampling_params; - draft_sampling_params.ignore_eos = true; - draft_sampling_params.stop_strings = {}; - update_eagle_pipeline_params(); - // remove first token from input_ids to create draft_input_ids - if (m_model_input_type == ModelInputType::TOKENS) { - static ManualTimer timer("tokenize"); - timer.start(); - ChatHistory history({{{"role", "user"}, {"content", prompt}}}); - auto templated_prompt = m_tokenizer.apply_chat_template(history, true); - auto input_ids = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids; - timer.end(); - ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); - return m_main_pipeline->add_request(request_id, input_ids, sampling_params); - } else { - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, prompt, draft_sampling_params)}); - return m_main_pipeline->add_request(request_id, prompt, sampling_params); - } - -} - -std::vector ContinuousBatchingPipeline::EagleDecodingImpl::generate( - const std::vector& input_ids, - const std::vector& sampling_params, - const StreamerVariant& streamer, - std::optional> token_type_ids) { - m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); - m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; - - OPENVINO_ASSERT(!has_non_finished_requests(), - "Generate cannot be called while ContinuousBatchingPipeline is already in running state. 
Use " - "ContinuousBatchingPipeline::add_request"); - - OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); - - ManualTimer generate_timer("speculative_decoding: generate()"); - generate_timer.start(); - - // checks that all requests has the same LoRA adapters property value - for (size_t i = 1; i < sampling_params.size(); ++i) { - OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, - "LoRA adapters value must be the same for all requests"); - } - auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); - auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); - - m_main_eagle_pipeline->set_adapters(sampling_params[0].adapters); - m_draft_eagle_pipeline->set_adapters(sampling_params[0].adapters); - update_eagle_pipeline_params(); - - const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); - - OPENVINO_ASSERT( - !streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || - (sampling_params[0].is_multinomial() && - sampling_params[0].num_return_sequences == 1)), - "Currently eagle streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); - - std::vector main_generations; - ov::Tensor new_input_ids; - for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { - auto new_input_ids = input_ids[request_id]; - OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); - auto main_sampling_params = sampling_params[request_id]; - OPENVINO_ASSERT( - main_sampling_params.assistant_confidence_threshold == 0.f, - "Eagle3 Speculative Decoding pipeline only supports `num_assistant_tokens` " - "as parameter in GenerationConfig and doesn't work with `assistant_confidence_threshold`.\nPlease " - "remove its specification or set it to 0.f."); - if (main_sampling_params.num_assistant_tokens == 0) { - main_sampling_params.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; - } - main_generations.push_back( - m_main_pipeline->add_request(request_id, new_input_ids, main_sampling_params)); - - auto draft_sampling_params = main_sampling_params; - // set the parameters do not stop draft generation without stopping of the same request for main pipeline - draft_sampling_params.ignore_eos = true; - draft_sampling_params.stop_strings = {}; - - // remove first token from input_ids to create draft_input_ids - ov::Tensor draft_input_ids = create_draft_input_ids(new_input_ids); - - std::lock_guard lock(m_draft_generations_mutex); - m_draft_generations.insert( - {request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); - } - auto all_requests = get_awaiting_requests(); - - GenerationHandle& generation = main_generations.at(0); - - streamer_ptr->start(); - - while (has_non_finished_requests()) { - try { - step(); - } catch (...) 
{ - drop_requests(); // remove all requests from pipeline state in case of exception - streamer_ptr->end(); - std::rethrow_exception(std::current_exception()); - } - stream_tokens(streamer_ptr, generation); - } - - // waiting for competion of streaming - streamer_ptr->end(); - - OPENVINO_ASSERT(is_requests_empty(), - "Internal error: current request is supposed to be dropped within step() function as completed"); - - std::vector results; - results.reserve(all_requests.size()); - - m_perf_metrics.draft_model_metrics.raw_metrics = m_draft_pipeline->raw_perf_metrics; - generate_timer.end(); - - for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { - const auto& request = all_requests[request_id]; - auto sampling_params = request->get_sampling_parameters(); - const auto& sequences = request->get_finished_sequences(); - size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size()); - - EncodedGenerationResult result; - result.m_request_id = request_id; - result.m_generation_ids.resize(num_outputs); - result.m_scores.resize(num_outputs); - result.m_status = request->get_generation_stream()->get_status(); - - for (size_t i = 0; i < num_outputs; ++i) { - const auto& sequence = sequences[i]; - const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) - : sequence->get_cumulative_log_prob(); - const auto& generated_ids = sequence->get_generated_ids(); - - if (sampling_params.echo) { - result.m_generation_ids[i] = request->get_prompt_ids(); - } - std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i])); - result.m_scores[i] = score; - } - - result.m_status = main_generations[request_id]->get_status(); - - // The same perf metrics for each sequence, only tokenization/detokenization will differ. 
- m_perf_metrics.raw_metrics.generate_durations.clear(); - m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_timer.get_duration_microsec()); - m_perf_metrics.num_input_tokens = request->get_prompt_len(); - m_perf_metrics.evaluate_statistics(generate_timer.get_start_time()); - - result.perf_metrics = m_perf_metrics; - result.extended_perf_metrics = std::make_shared(m_perf_metrics); - results.push_back(std::move(result)); - } - - OPENVINO_ASSERT(results.size() == input_ids.size()); - return results; -} } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 810876553b..162fcace1e 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -8,18 +8,6 @@ #include "speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp" #include "speculative_decoding/speculative_decoding_metrics.hpp" #include "openvino/genai/speculative_decoding/perf_metrics.hpp" -#include "openvino/op/add.hpp" -#include "openvino/op/constant.hpp" -#include "openvino/op/result.hpp" -#include "openvino/op/multiply.hpp" -#include "openvino/op/convert.hpp" -#include "openvino/op/gather.hpp" -#include "openvino/op/matmul.hpp" -#include "openvino/op/concat.hpp" -#include "openvino/pass/pattern/matcher.hpp" -#include "openvino/pass/pattern/op/wrap_type.hpp" -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/manager.hpp" #include "utils.hpp" namespace ov::genai { @@ -64,86 +52,4 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat SpeculativeDecodingMetrics get_speculative_decoding_metrics(); }; -class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingPipeline::SpeculativeDecodingImpl { -public: - EagleDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc, const std::vector& hidden_layers_to_abstract); - - std::vector - generate(const std::vector& input_ids, - const std::vector& sampling_params, - const StreamerVariant& streamer, - std::optional> token_type_ids = std::nullopt) override; - - GenerationHandle add_request(uint64_t request_id, - const ov::Tensor& input_ids, - ov::genai::GenerationConfig sampling_params, - std::optional token_type_ids = std::nullopt) override; - - GenerationHandle add_request(uint64_t request_id, - const std::string& prompt, - ov::genai::GenerationConfig sampling_params) override; - - void set_d2t_for_draft_decoding(std::shared_ptr& d2t_tensor) { - auto eagle_impl = std::dynamic_pointer_cast(m_draft_pipeline); - eagle_impl->set_d2t_for_draft_decoding(d2t_tensor); - }; -protected: - void update_eagle_pipeline_params(); - ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); - std::vector m_hidden_layers_to_abstract; -}; - -using NodePtr = std::shared_ptr; -using namespace ov::op; - -class EagleBaseTransform : public ov::pass::MatcherPass { -public: - using NodePtr = std::shared_ptr; - OPENVINO_MATCHER_PASS_RTTI("EagleBaseTransform"); - EagleBaseTransform(std::vector>& results); - - ~EagleBaseTransform() = default; - -private: - bool apply(NodePtr node, std::vector>& results); - size_t applied = 0; - std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node); - std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node, - std::set& visited_nodes); -}; -class EagleInputTransform : public ov::pass::MatcherPass { // eagle3 
specific for draft model -public: - using NodePtr = std::shared_ptr; - OPENVINO_MATCHER_PASS_RTTI("EagleInputTransform"); - EagleInputTransform(std::vector>& params); - - ~EagleInputTransform() = default; - -private: - bool apply(NodePtr node, std::vector>& params); - size_t applied = 0; -}; -class Eagle3Transform : public ov::pass::MatcherPass { -public: - using NodePtr = std::shared_ptr; - OPENVINO_MATCHER_PASS_RTTI("Eagle3Transform"); - Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs); - - ~Eagle3Transform() = default; - -private: - std::vector m_layers; // layers to be abstracted -}; - -class EagleModelTransform : public ov::pass::ModelPass { -public: - EagleModelTransform(const std::vector& layer_ids); - bool run_on_model(const std::shared_ptr& model) override; - -private: - const std::vector m_layer_ids; - std::vector> m_new_results; - std::vector> m_new_parameters; - std::vector> m_hidden_layer_outputs; -}; -} +} // namespace ov::genai \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/update_request_structs.hpp b/src/cpp/src/speculative_decoding/update_request_structs.hpp index 2f62ebbfa4..4426372507 100644 --- a/src/cpp/src/speculative_decoding/update_request_structs.hpp +++ b/src/cpp/src/speculative_decoding/update_request_structs.hpp @@ -11,7 +11,7 @@ struct GeneratedSequence { std::vector token_ids; std::vector log_probs; // Stores the hidden states tensor associated with the generated sequence. - // This field is primarily used for the "eagle speculative" decoding algorithm, + // This field is used for the "eagle speculative" decoding algorithm, // where hidden states are required to efficiently validate and extend speculative tokens. // If not using eagle speculative decoding, this field may remain empty. ov::Tensor hidden_states; From 0931c09e381e01ae8bbbe993aacd613a5740a516 Mon Sep 17 00:00:00 2001 From: fishbell Date: Tue, 21 Oct 2025 23:05:27 +0800 Subject: [PATCH 31/43] reuse common codes between spec decode and eagle3 decode Signed-off-by: fishbell --- .../speculative_decoding_eagle3_impl.cpp | 177 +++++------------- .../speculative_decoding_eagle3_impl.hpp | 10 + .../speculative_decoding_impl.cpp | 138 +++----------- .../speculative_decoding_impl.hpp | 144 ++++++++++++++ 4 files changed, 226 insertions(+), 243 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp index 2a6f26b134..11cdd1a1b3 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp @@ -326,13 +326,16 @@ ov::Tensor ContinuousBatchingPipeline::Eagle3DecodingImpl::create_draft_input_id } void ContinuousBatchingPipeline::Eagle3DecodingImpl::update_eagle_pipeline_params() { - auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); - auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); - m_main_eagle_pipeline->set_hidden_state_export_needed(true); - m_draft_eagle_pipeline->set_hidden_state_export_needed(true); - m_draft_eagle_pipeline->set_hidden_state_import_needed(true); - m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); - m_draft_eagle_pipeline->set_adjust_factor(m_hidden_layers_to_abstract.size() > 0 ? 
m_hidden_layers_to_abstract.size() : 1); + std::call_once(m_eagle_params_once, [this]() { + auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); + auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); + m_main_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_import_needed(true); + m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); + m_draft_eagle_pipeline->set_adjust_factor( + m_hidden_layers_to_abstract.size() > 0 ? m_hidden_layers_to_abstract.size() : 1); + }); } GenerationHandle @@ -383,135 +386,39 @@ std::vector ContinuousBatchingPipeline::Eagle3DecodingI const std::vector& sampling_params, const StreamerVariant& streamer, std::optional> token_type_ids) { - m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); - m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; - - OPENVINO_ASSERT(!has_non_finished_requests(), - "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use " - "ContinuousBatchingPipeline::add_request"); - - OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); - - ManualTimer generate_timer("speculative_decoding: generate()"); - generate_timer.start(); - - // checks that all requests has the same LoRA adapters property value - for (size_t i = 1; i < sampling_params.size(); ++i) { - OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, - "LoRA adapters value must be the same for all requests"); - } - auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); - auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); - - m_main_eagle_pipeline->set_adapters(sampling_params[0].adapters); - m_draft_eagle_pipeline->set_adapters(sampling_params[0].adapters); - update_eagle_pipeline_params(); - - const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); - - OPENVINO_ASSERT( - !streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || - (sampling_params[0].is_multinomial() && - sampling_params[0].num_return_sequences == 1)), - "Currently eagle streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); - - std::vector main_generations; - ov::Tensor new_input_ids; - for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { - auto new_input_ids = input_ids[request_id]; - OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); - auto main_sampling_params = sampling_params[request_id]; - OPENVINO_ASSERT( - main_sampling_params.assistant_confidence_threshold == 0.f, - "Eagle3 Speculative Decoding pipeline only supports `num_assistant_tokens` " - "as parameter in GenerationConfig and doesn't work with `assistant_confidence_threshold`.\nPlease " - "remove its specification or set it to 0.f."); - if (main_sampling_params.num_assistant_tokens == 0) { - main_sampling_params.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; - } - main_generations.push_back( - m_main_pipeline->add_request(request_id, new_input_ids, main_sampling_params)); - - auto draft_sampling_params = main_sampling_params; - // set the parameters do not stop draft generation without stopping of the same request for main pipeline - draft_sampling_params.ignore_eos = true; - draft_sampling_params.stop_strings = {}; - - // remove first token from input_ids to create 
draft_input_ids - ov::Tensor draft_input_ids = create_draft_input_ids(new_input_ids); - - std::lock_guard lock(m_draft_generations_mutex); - m_draft_generations.insert( - {request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); - } - auto all_requests = get_awaiting_requests(); - - GenerationHandle& generation = main_generations.at(0); - - streamer_ptr->start(); - - while (has_non_finished_requests()) { - try { - step(); - } catch (...) { - drop_requests(); // remove all requests from pipeline state in case of exception - streamer_ptr->end(); - std::rethrow_exception(std::current_exception()); - } - stream_tokens(streamer_ptr, generation); - } - - // waiting for competion of streaming - streamer_ptr->end(); - - OPENVINO_ASSERT(is_requests_empty(), - "Internal error: current request is supposed to be dropped within step() function as completed"); - - std::vector results; - results.reserve(all_requests.size()); - - m_perf_metrics.draft_model_metrics.raw_metrics = m_draft_pipeline->raw_perf_metrics; - generate_timer.end(); - - for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { - const auto& request = all_requests[request_id]; - auto sampling_params = request->get_sampling_parameters(); - const auto& sequences = request->get_finished_sequences(); - size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size()); - - EncodedGenerationResult result; - result.m_request_id = request_id; - result.m_generation_ids.resize(num_outputs); - result.m_scores.resize(num_outputs); - result.m_status = request->get_generation_stream()->get_status(); - - for (size_t i = 0; i < num_outputs; ++i) { - const auto& sequence = sequences[i]; - const float score = sampling_params.is_beam_search() ? 
sequence->get_beam_search_score(sampling_params) - : sequence->get_cumulative_log_prob(); - const auto& generated_ids = sequence->get_generated_ids(); - - if (sampling_params.echo) { - result.m_generation_ids[i] = request->get_prompt_ids(); - } - std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i])); - result.m_scores[i] = score; + GenerateStrategy strategy; + strategy.prepare_request = [this](size_t, + const ov::Tensor& in_ids, + GenerationConfig& main_cfg, + GenerationConfig& draft_cfg, + ov::Tensor& main_in, + ov::Tensor& draft_in) { + OPENVINO_ASSERT(main_cfg.assistant_confidence_threshold == 0.f, + "Eagle3 only supports num_assistant_tokens (assistant_confidence_threshold must be 0.f)"); + if (main_cfg.num_assistant_tokens == 0) { + main_cfg.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; } + draft_cfg.ignore_eos = true; + draft_cfg.stop_strings = {}; + main_in = in_ids; + draft_in = create_draft_input_ids(in_ids); + }; + strategy.pre_loop = [this](){ update_eagle_pipeline_params(); }; + strategy.check_streaming = [](const std::shared_ptr& streamer_ptr, + const std::vector& input_ids, + const std::vector& sampling_params) { + OPENVINO_ASSERT(!streamer_ptr->has_callback() || + (input_ids.size() == 1 && + (sampling_params[0].is_greedy_decoding())), + "Eagle3 streaming only supports batch size=1 with greedy"); + }; + strategy.start_timer = [](){ + return std::chrono::steady_clock::now(); + }; + strategy.stop_timer = [](TimePoint start){ + return PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start); + }; - result.m_status = main_generations[request_id]->get_status(); - - // The same perf metrics for each sequence, only tokenization/detokenization will differ. - m_perf_metrics.raw_metrics.generate_durations.clear(); - m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_timer.get_duration_microsec()); - m_perf_metrics.num_input_tokens = request->get_prompt_len(); - m_perf_metrics.evaluate_statistics(generate_timer.get_start_time()); - - result.perf_metrics = m_perf_metrics; - result.extended_perf_metrics = std::make_shared(m_perf_metrics); - results.push_back(std::move(result)); - } - - OPENVINO_ASSERT(results.size() == input_ids.size()); - return results; + return generate_common(this, input_ids, sampling_params, streamer, token_type_ids, strategy); } } // namespace ov::genai \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp index 4626fb767f..3b65335bba 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp @@ -20,6 +20,14 @@ namespace ov::genai { class ContinuousBatchingPipeline::Eagle3DecodingImpl : public ContinuousBatchingPipeline::SpeculativeDecodingImpl { public: + template + friend std::vector generate_common( + Impl*, + const std::vector&, + const std::vector&, + const StreamerVariant&, + std::optional>, + GenerateStrategy&); Eagle3DecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc, const std::vector& hidden_layers_to_abstract); std::vector @@ -45,6 +53,8 @@ class ContinuousBatchingPipeline::Eagle3DecodingImpl : public ContinuousBatching void update_eagle_pipeline_params(); ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); std::vector m_hidden_layers_to_abstract; +private: + 
std::once_flag m_eagle_params_once; }; using NodePtr = std::shared_ptr; diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 955cffe3eb..76c500a9a9 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -237,116 +237,38 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< const std::vector& sampling_params, const StreamerVariant& streamer, std::optional> token_type_ids) { - OPENVINO_ASSERT(!token_type_ids.has_value()); - m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); - m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; - - OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request"); - OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); - - const auto generate_start = std::chrono::steady_clock::now(); - - // checks that all requests has the same LoRA adapters property value - for (size_t i = 1; i < sampling_params.size(); ++i) { - OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, - "LoRA adapters value must be the same for all requests"); - } - m_main_pipeline->set_adapters(sampling_params[0].adapters); - m_draft_pipeline->set_adapters(sampling_params[0].adapters); - - const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); - - OPENVINO_ASSERT(!streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), - "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); - - std::vector main_generations; - for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { - OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); - auto main_sampling_params = sampling_params[request_id]; - if (main_sampling_params.assistant_confidence_threshold == 0.f) { - if (main_sampling_params.num_assistant_tokens == 0) { - main_sampling_params.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; - } - } - main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], main_sampling_params)); - - auto draft_sampling_params = main_sampling_params; - // set the parameters do not stop draft generation without stopping of the same request for main pipeline - draft_sampling_params.ignore_eos = true; - draft_sampling_params.stop_strings = {}; - std::lock_guard lock(m_draft_generations_mutex); - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params)}); - } - auto all_requests = get_awaiting_requests(); - - GenerationHandle& generation = main_generations.at(0); - - streamer_ptr->start(); - - while (has_non_finished_requests()) { - try { - step(); - } catch (...) 
{ - drop_requests(); // remove all requests from pipeline state in case of exception - streamer_ptr->end(); - std::rethrow_exception(std::current_exception()); - } - stream_tokens(streamer_ptr, generation); - } - - // waiting for competion of streaming - streamer_ptr->end(); - - OPENVINO_ASSERT(is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); - - std::vector results; - results.reserve(all_requests.size()); - - m_perf_metrics.draft_model_metrics.raw_metrics = m_draft_pipeline->raw_perf_metrics; - - const auto generate_end = std::chrono::steady_clock::now(); - const auto generate_duration = PerfMetrics::get_microsec(generate_end - generate_start); - - for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { - const auto& request = all_requests[request_id]; - auto sampling_params = request->get_sampling_parameters(); - const auto& sequences = request->get_finished_sequences(); - size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size()); - - EncodedGenerationResult result; - result.m_request_id = request_id; - result.m_generation_ids.resize(num_outputs); - result.m_scores.resize(num_outputs); - result.m_status = request->get_generation_stream()->get_status(); - - for (size_t i = 0; i < num_outputs; ++i) { - const auto & sequence = sequences[i]; - const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob(); - const auto & generated_ids = sequence->get_generated_ids(); - - if (sampling_params.echo) { - result.m_generation_ids[i] = request->get_prompt_ids(); + GenerateStrategy strategy; + strategy.prepare_request = [this](size_t, + const ov::Tensor& in_ids, + GenerationConfig& main_cfg, + GenerationConfig& draft_cfg, + ov::Tensor& main_in, + ov::Tensor& draft_in) { + if (main_cfg.assistant_confidence_threshold == 0.f) { + if (main_cfg.num_assistant_tokens == 0) { + main_cfg.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; } - std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i])); - result.m_scores[i] = score; } - - result.m_status = main_generations[request_id]->get_status(); - - // The same perf metrics for each sequence, only tokenization/detokenization will differ. 
- m_perf_metrics.raw_metrics.generate_durations.clear(); - m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_duration); - m_perf_metrics.num_input_tokens = request->get_prompt_len(); - m_perf_metrics.evaluate_statistics(generate_start); - - result.perf_metrics = m_perf_metrics; - result.extended_perf_metrics = std::make_shared(m_perf_metrics); - results.push_back(std::move(result)); - } - - OPENVINO_ASSERT(results.size() == input_ids.size()); - - return results; + draft_cfg.ignore_eos = true; + draft_cfg.stop_strings = {}; + main_in = in_ids; + draft_in = in_ids; + }; + strategy.pre_loop = nullptr; + strategy.check_streaming = [](const std::shared_ptr& streamer_ptr, + const std::vector& input_ids, + const std::vector& sampling_params) { + OPENVINO_ASSERT(!streamer_ptr->has_callback() || + (input_ids.size() == 1 && + (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial())), + "Streaming only supports batch size=1 with greedy/multinomial"); + }; + strategy.start_timer = [](){ return std::chrono::steady_clock::now(); }; + strategy.stop_timer = [](TimePoint start){ + return PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start); + }; + + return generate_common(this, input_ids, sampling_params, streamer, token_type_ids, strategy); } SpeculativeDecodingMetrics diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 162fcace1e..e3090d9a7d 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -11,6 +11,131 @@ #include "utils.hpp" namespace ov::genai { +struct GenerateStrategy { + std::function prepare_request; + std::function pre_loop; + std::function&, + const std::vector&, + const std::vector&)> check_streaming; + std::function start_timer; + std::function stop_timer; +}; + +template +std::vector generate_common( + Impl* self, + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids, + GenerateStrategy& strategy) { + + OPENVINO_ASSERT(!token_type_ids.has_value()); + self->perf_metrics() = ov::genai::SDPerModelsPerfMetrics(); + self->draft_pipeline()->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + + OPENVINO_ASSERT(!self->has_non_finished_requests(), + "Generate cannot be called while ContinuousBatchingPipeline is already running"); + OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); + + auto t_start = strategy.start_timer(); + + for (size_t i = 1; i < sampling_params.size(); ++i) { + OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, + "LoRA adapters must be same for all requests"); + } + self->main_pipeline()->set_adapters(sampling_params[0].adapters); + self->draft_pipeline()->set_adapters(sampling_params[0].adapters); + + auto streamer_ptr = std::make_shared(streamer, self->tokenizer()); + + strategy.check_streaming(streamer_ptr, input_ids, sampling_params); + if (strategy.pre_loop) strategy.pre_loop(); + + std::vector main_generations; + { + std::lock_guard lock(self->draft_generations_mutex()); + for (size_t rid = 0; rid < input_ids.size(); ++rid) { + GenerationConfig main_cfg = sampling_params[rid]; + GenerationConfig draft_cfg = main_cfg; + ov::Tensor main_in, draft_in; + strategy.prepare_request(rid, input_ids[rid], + main_cfg, draft_cfg, + main_in, draft_in); + 
main_generations.push_back(self->main_pipeline()->add_request(rid, main_in, main_cfg)); + self->draft_generations().insert({rid, + self->draft_pipeline()->add_request(rid, draft_in, draft_cfg)}); + } + } + + auto all_requests = self->get_awaiting_requests(); + GenerationHandle& generation = main_generations.at(0); + + streamer_ptr->start(); + while (self->has_non_finished_requests()) { + try { + self->step(); + } catch (...) { + self->drop_requests(); + streamer_ptr->end(); + std::rethrow_exception(std::current_exception()); + } + self->stream_tokens(streamer_ptr, generation); + } + streamer_ptr->end(); + + OPENVINO_ASSERT(self->is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); + + self->perf_metrics().draft_model_metrics.raw_metrics = self->draft_pipeline()->raw_perf_metrics; + uint64_t generate_duration_us = strategy.stop_timer(t_start); + + std::vector<EncodedGenerationResult> results; + results.reserve(all_requests.size()); + + for (size_t rid = 0; rid < all_requests.size(); ++rid) { + const auto& request = all_requests[rid]; + auto cfg = request->get_sampling_parameters(); + const auto& seqs = request->get_finished_sequences(); + size_t num_out = std::min(cfg.num_return_sequences, seqs.size()); + + EncodedGenerationResult result; + result.m_request_id = rid; + result.m_generation_ids.resize(num_out); + result.m_scores.resize(num_out); + result.m_status = main_generations[rid]->get_status(); + + for (size_t i = 0; i < num_out; ++i) { + const auto& seq = seqs[i]; + float score = cfg.is_beam_search() ? + seq->get_beam_search_score(cfg) : + seq->get_cumulative_log_prob(); + const auto& gen_ids = seq->get_generated_ids(); + if (cfg.echo) { + result.m_generation_ids[i] = request->get_prompt_ids(); + } + std::copy(gen_ids.begin(), gen_ids.end(), + std::back_inserter(result.m_generation_ids[i])); + result.m_scores[i] = score; + } + + self->perf_metrics().raw_metrics.generate_durations.clear(); + self->perf_metrics().raw_metrics.generate_durations.emplace_back(generate_duration_us); + self->perf_metrics().num_input_tokens = request->get_prompt_len(); + self->perf_metrics().evaluate_statistics(t_start); + + result.perf_metrics = self->perf_metrics(); + result.extended_perf_metrics = std::make_shared<SDPerModelsPerfMetrics>(self->perf_metrics()); + results.push_back(std::move(result)); + } + + OPENVINO_ASSERT(results.size() == input_ids.size()); + return results; +} class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline { protected: @@ -28,6 +153,15 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat std::vector<SequenceGroup::Ptr> get_awaiting_requests(); std::pair<SchedulerConfig, SchedulerConfig> init_speculative_models(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); public: + template <typename Impl> + friend std::vector<EncodedGenerationResult> generate_common( + Impl* self, + const std::vector<ov::Tensor>& input_ids, + const std::vector<GenerationConfig>& sampling_params, + const StreamerVariant& streamer, + std::optional<std::vector<ov::Tensor>> token_type_ids, + GenerateStrategy& strategy); + + SpeculativeDecodingImpl() = default; SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); @@ -50,6 +184,16 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat std::optional<std::vector<ov::Tensor>> token_type_ids = std::nullopt) override; SpeculativeDecodingMetrics get_speculative_decoding_metrics(); + SDPerModelsPerfMetrics& perf_metrics() { return m_perf_metrics; } + SDPerModelsPerfMetrics const& perf_metrics()
const { return m_perf_metrics; } + std::shared_ptr<ContinuousBatchingForSpeculativeDecodingImpl>& draft_pipeline() { return m_draft_pipeline; } + std::shared_ptr<ContinuousBatchingForSpeculativeDecodingImpl>& main_pipeline() { return m_main_pipeline; } + + Tokenizer& tokenizer() { return m_tokenizer; } + const Tokenizer& tokenizer() const { return m_tokenizer; } + + std::mutex& draft_generations_mutex() { return m_draft_generations_mutex; } + std::map<uint64_t, GenerationHandle>& draft_generations() { return m_draft_generations; } }; } // namespace ov::genai \ No newline at end of file From 3c977185f3b39bc002762befc23671634e308eb9 Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 22 Oct 2025 01:35:36 +0800 Subject: [PATCH 32/43] add missing default Signed-off-by: fishbell --- .../speculative_decoding/speculative_decoding_eagle3_impl.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp index 11cdd1a1b3..15731bca29 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp @@ -397,6 +397,7 @@ std::vector<EncodedGenerationResult> ContinuousBatchingPipeline::Eagle3DecodingI "Eagle3 only supports num_assistant_tokens (assistant_confidence_threshold must be 0.f)"); if (main_cfg.num_assistant_tokens == 0) { main_cfg.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; + draft_cfg.num_assistant_tokens = main_cfg.num_assistant_tokens; } draft_cfg.ignore_eos = true; draft_cfg.stop_strings = {}; From 5f16e65afcee05a01e38da050240a415984248ff Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 22 Oct 2025 02:09:26 +0800 Subject: [PATCH 33/43] enable test interface first Signed-off-by: fishbell --- tests/python_tests/utils/hugging_face.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/python_tests/utils/hugging_face.py b/tests/python_tests/utils/hugging_face.py index 6d96f0e1c4..fbcc28376c 100644 --- a/tests/python_tests/utils/hugging_face.py +++ b/tests/python_tests/utils/hugging_face.py @@ -166,9 +166,14 @@ def run_hugging_face( # download HF model or read converted model def get_huggingface_models(model_id: str | Path, model_class: Type[OVModel], local_files_only=False): - hf_tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, local_files_only=local_files_only)) - opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, trust_remote_code=isinstance(model_id, str), ov_config=get_default_llm_properties(), local_files_only=local_files_only)) - return opt_model, hf_tokenizer + if "eagle3" not in str(model_id).lower(): + hf_tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, local_files_only=local_files_only)) + opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, trust_remote_code=isinstance(model_id, str), ov_config=get_default_llm_properties(), local_files_only=local_files_only)) + return opt_model, hf_tokenizer + else: + hf_tokenizer = None + opt_model = retry_request(lambda: model_class.from_pretrained(model_id, eagle3=True, export=isinstance(model_id, str), compile=False, load_in_8bit=False, trust_remote_code=isinstance(model_id, str), ov_config=get_default_llm_properties(), local_files_only=local_files_only)) + return opt_model, hf_tokenizer def
convert_and_save_tokenizer(hf_tokenizer : AutoTokenizer, @@ -192,9 +197,10 @@ def convert_models(opt_model : OVModelForCausalLM, opt_model.config.save_pretrained(models_path) # to store tokenizer config jsons with special tokens - hf_tokenizer.save_pretrained(models_path) - # convert tokenizers as well - convert_and_save_tokenizer(hf_tokenizer, models_path, **tokenizer_kwargs) + if hf_tokenizer: + hf_tokenizer.save_pretrained(models_path) + # convert tokenizers as well + convert_and_save_tokenizer(hf_tokenizer, models_path, **tokenizer_kwargs) def download_and_convert_model(model_id: str, **tokenizer_kwargs): From bb8a39de5858752c69844d4d5f7524094f93545d Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 22 Oct 2025 19:29:40 +0800 Subject: [PATCH 34/43] apply review comments Signed-off-by: fishbell --- .../speculative_decoding_eagle3_impl.cpp | 45 +++++++------------ .../speculative_decoding_eagle3_impl.hpp | 2 - .../speculative_decoding_impl.cpp | 1 - .../speculative_decoding_impl.hpp | 2 - .../python_tests/test_continuous_batching.py | 2 +- 5 files changed, 18 insertions(+), 34 deletions(-) diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp index 15731bca29..d78c52323b 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp @@ -304,6 +304,9 @@ ContinuousBatchingPipeline::Eagle3DecodingImpl::Eagle3DecodingImpl(const ov::gen m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); m_perf_metrics.raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + + // specific params update for eagle pipeline + update_eagle_pipeline_params(); } ov::Tensor ContinuousBatchingPipeline::Eagle3DecodingImpl::create_draft_input_ids(const ov::Tensor& original_input_ids) { @@ -326,16 +329,14 @@ ov::Tensor ContinuousBatchingPipeline::Eagle3DecodingImpl::create_draft_input_id } void ContinuousBatchingPipeline::Eagle3DecodingImpl::update_eagle_pipeline_params() { - std::call_once(m_eagle_params_once, [this]() { - auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); - auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); - m_main_eagle_pipeline->set_hidden_state_export_needed(true); - m_draft_eagle_pipeline->set_hidden_state_export_needed(true); - m_draft_eagle_pipeline->set_hidden_state_import_needed(true); - m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); - m_draft_eagle_pipeline->set_adjust_factor( - m_hidden_layers_to_abstract.size() > 0 ? m_hidden_layers_to_abstract.size() : 1); - }); + auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); + auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); + m_main_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_import_needed(true); + m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); + m_draft_eagle_pipeline->set_adjust_factor( + m_hidden_layers_to_abstract.size() > 0 ? 
m_hidden_layers_to_abstract.size() : 1); } GenerationHandle @@ -347,7 +348,6 @@ ContinuousBatchingPipeline::Eagle3DecodingImpl::add_request(uint64_t request_id, auto draft_sampling_params = sampling_params; draft_sampling_params.ignore_eos = true; draft_sampling_params.stop_strings = {}; - update_eagle_pipeline_params(); // remove first token from input_ids to create draft_input_ids ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params, token_type_ids)}); @@ -362,23 +362,12 @@ ContinuousBatchingPipeline::Eagle3DecodingImpl::add_request(uint64_t request_id, auto draft_sampling_params = sampling_params; draft_sampling_params.ignore_eos = true; draft_sampling_params.stop_strings = {}; - update_eagle_pipeline_params(); // remove first token from input_ids to create draft_input_ids - if (m_model_input_type == ModelInputType::TOKENS) { - static ManualTimer timer("tokenize"); - timer.start(); - ChatHistory history({{{"role", "user"}, {"content", prompt}}}); - auto templated_prompt = m_tokenizer.apply_chat_template(history, true); - auto input_ids = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids; - timer.end(); - ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); - return m_main_pipeline->add_request(request_id, input_ids, sampling_params); - } else { - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, prompt, draft_sampling_params)}); - return m_main_pipeline->add_request(request_id, prompt, sampling_params); - } - + // add_special_tokens is false for better compress rate + auto input_ids = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)).input_ids; + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params); } std::vector ContinuousBatchingPipeline::Eagle3DecodingImpl::generate( @@ -404,7 +393,7 @@ std::vector ContinuousBatchingPipeline::Eagle3DecodingI main_in = in_ids; draft_in = create_draft_input_ids(in_ids); }; - strategy.pre_loop = [this](){ update_eagle_pipeline_params(); }; + strategy.check_streaming = [](const std::shared_ptr& streamer_ptr, const std::vector& input_ids, const std::vector& sampling_params) { diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp index 3b65335bba..4988586db2 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp @@ -53,8 +53,6 @@ class ContinuousBatchingPipeline::Eagle3DecodingImpl : public ContinuousBatching void update_eagle_pipeline_params(); ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); std::vector m_hidden_layers_to_abstract; -private: - std::once_flag m_eagle_params_once; }; using NodePtr = std::shared_ptr; diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 76c500a9a9..d776557aec 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ 
b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -254,7 +254,6 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< main_in = in_ids; draft_in = in_ids; }; - strategy.pre_loop = nullptr; strategy.check_streaming = [](const std::shared_ptr& streamer_ptr, const std::vector& input_ids, const std::vector& sampling_params) { diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index e3090d9a7d..d6b8c353a1 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -18,7 +18,6 @@ struct GenerateStrategy { GenerationConfig& draft_cfg, ov::Tensor& main_in, ov::Tensor& draft_in)> prepare_request; - std::function pre_loop; std::function&, const std::vector&, const std::vector&)> check_streaming; @@ -55,7 +54,6 @@ std::vector generate_common( auto streamer_ptr = std::make_shared(streamer, self->tokenizer()); strategy.check_streaming(streamer_ptr, input_ids, sampling_params); - if (strategy.pre_loop) strategy.pre_loop(); std::vector main_generations; { diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index 2c5f0c35dc..a2a9efc1a2 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -527,4 +527,4 @@ def test_speculative_decoding_extended_perf_metrics(pipeline_type): assert mean_gen_duration > 0 and mean_gen_duration < total_time assert std_gen_duration == 0 else: - assert extended_perf_metrics is None \ No newline at end of file + assert extended_perf_metrics is None From 428c25e0f355642df9ae618fbdb756b751ae8276 Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 22 Oct 2025 19:37:43 +0800 Subject: [PATCH 35/43] enable test for local val first Signed-off-by: fishbell --- tests/python_tests/test_eagle3.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/python_tests/test_eagle3.py b/tests/python_tests/test_eagle3.py index 1c5bcff7d8..9241e9d28c 100644 --- a/tests/python_tests/test_eagle3.py +++ b/tests/python_tests/test_eagle3.py @@ -9,11 +9,15 @@ from utils.ov_genai_pipelines import create_ov_pipeline, PipelineType, convert_decoded_results_to_generation_result eagle_models_and_input = [ - ("Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3", "Why is the Sun yellow?")] + ("Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3", """Code: +def add(a, b): + return a + b + +Question: Can you please add 2 and 3 +A:""")] devices = [ - ('CPU', 'CPU'), - ('GPU', 'GPU') + ('CPU', 'CPU') ] @pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) @pytest.mark.parametrize("main_device,draft_device", devices) @@ -30,7 +34,6 @@ def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_dev # Run reference HF model: ov_generation_config = ov_genai.GenerationConfig(max_new_tokens=20) - main_hf_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) ref_gen_results = run_hugging_face(main_opt_model, main_hf_tokenizer, [prompt], ov_generation_config) # Run OpenVINO GenAI pipeline: @@ -63,7 +66,6 @@ def test_eagle3_sd_extended_perf_metrics(main_model, main_device, draft_model, d assert not extended_perf_metrics is None assert not extended_perf_metrics.main_model_metrics is None assert not extended_perf_metrics.draft_model_metrics is None - assert extended_perf_metrics.get_num_accepted_tokens() > 0 
num_generated_tokens_main = extended_perf_metrics.main_model_metrics.get_num_generated_tokens() From efd854bcfe1d5c13a19f1436d368cd0b1411ad0f Mon Sep 17 00:00:00 2001 From: fishbell Date: Wed, 22 Oct 2025 21:35:31 +0800 Subject: [PATCH 36/43] use warning logger instead of cout Signed-off-by: fishbell --- src/cpp/src/continuous_batching/pipeline.cpp | 30 ++----------------- .../speculative_decoding_eagle3_impl.cpp | 11 +++++++ 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index 686771e5c0..b8028957df 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -105,16 +105,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); - // Eagle speculative decoding does not support dynamic_split_fuse mode - // because it requires hidden state interaction from main model to draft model - // to be implemented future - SchedulerConfig scheduler_config_copy = scheduler_config; - if (scheduler_config.dynamic_split_fuse) { - std::cout << "Note: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; - scheduler_config_copy.dynamic_split_fuse = false; - // Use scheduler_config_copy in subsequent code if modification is needed - } - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config_copy, generation_config); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); // parse d2t from safe tensors if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { @@ -167,13 +158,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( // Eagle speculative decoding does not support dynamic_split_fuse mode // because it requires hidden state interaction from main model to draft model // to be implemented future - SchedulerConfig scheduler_config_copy = scheduler_config; - if (scheduler_config.dynamic_split_fuse) { - std::cout << "Note: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; - scheduler_config_copy.dynamic_split_fuse = false; - // Use scheduler_config_copy in subsequent code if modification is needed - } - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config_copy, generation_config); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); // parse d2t from safe tensors if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { @@ -228,16 +213,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model, generation_config); } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { 
OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); - // Eagle speculative decoding does not support dynamic_split_fuse mode - // because it requires hidden state interaction from main model to draft model - // to be implemented future - SchedulerConfig scheduler_config_copy = scheduler_config; - if (scheduler_config.dynamic_split_fuse) { - std::cout << "Note: disable dynamic split fuse for eagle3 speculative decoding" << std::endl; - scheduler_config_copy.dynamic_split_fuse = false; - // Use scheduler_config_copy in subsequent code if modification is needed - } - auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config_copy, generation_config); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); // parse d2t from safe tensors if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp index d78c52323b..ffd8808b80 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp @@ -1,6 +1,7 @@ // Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #include "speculative_decoding_eagle3_impl.hpp" +#include "logger.hpp" namespace ov::genai { void share_embedding_weights(std::shared_ptr& main_model, std::shared_ptr& draft_model) { @@ -264,6 +265,16 @@ ContinuousBatchingPipeline::Eagle3DecodingImpl::Eagle3DecodingImpl(const ov::gen const std::vector& hidden_layers) : m_hidden_layers_to_abstract(hidden_layers) { auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); + // Eagle speculative decoding does not support dynamic_split_fuse mode + // because it requires hidden state interaction from main model to draft model + // to be implemented future + if (scheduler_configs.first.dynamic_split_fuse) { + Logger::warn( + "Note: disable dynamic split fuse for eagle3 speculative decoding" + ); + scheduler_configs.first.dynamic_split_fuse = false; + scheduler_configs.second.dynamic_split_fuse = false; + } auto main_model = main_model_desc.model; auto draft_model = draft_model_desc.model; From a9f68dbe63e94a1e5b66fd5aaf1bc8b8528b14f2 Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 24 Oct 2025 00:03:47 +0800 Subject: [PATCH 37/43] update per model d2t changes Signed-off-by: fishbell --- .github/workflows/linux.yml | 8 ++- .github/workflows/windows.yml | 8 ++- src/cpp/src/continuous_batching/pipeline.cpp | 27 --------- src/cpp/src/llm/pipeline.cpp | 3 - src/cpp/src/lora/adapter.cpp | 58 ++++++++++++++++++- src/cpp/src/{ => lora}/safetensors.c | 0 src/cpp/src/safe_tensor_wrapper.cpp | 43 -------------- src/cpp/src/safe_tensor_wrapper.hpp | 27 --------- .../speculative_decoding_eagle3_impl.cpp | 17 +++++- .../speculative_decoding_eagle3_impl.hpp | 7 +-- tests/python_tests/test_eagle3.py | 2 - 11 files changed, 86 insertions(+), 114 deletions(-) rename src/cpp/src/{ => lora}/safetensors.c (100%) delete mode 100644 src/cpp/src/safe_tensor_wrapper.cpp delete mode 100644 src/cpp/src/safe_tensor_wrapper.hpp diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 
921af304b4..76419337b5 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -522,7 +522,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }} timeout: 240 - name: 'LLM & VLM' - cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_eagle3.py --override-ini cache_dir=/mount/caches/pytest/' + cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }} timeout: 180 - name: 'GGUF Reader tests' @@ -551,6 +551,12 @@ jobs: python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }} timeout: 90 + - name: 'EAGLE3 speculative decoding tests' + cmd: | + python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f + python -m pytest -v ./tests/python_tests/test_eagle3.py' + run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).eagle3_speculative_decoding.test }} + timeout: 90 defaults: run: shell: bash diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index f4baab3817..cffe99ac1a 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -623,7 +623,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }} timeout: 240 - name: 'LLM & VLM' - cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_eagle3.py --override-ini cache_dir=/mount/caches/pytest/' + cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }} timeout: 180 - name: 'GGUF Reader tests' @@ -652,6 +652,12 @@ jobs: python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }} timeout: 90 + - name: 'EAGLE3 speculative decoding tests' + cmd: | + python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f + python -m pytest -v ./tests/python_tests/test_eagle3.py' + run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).eagle3_speculative_decoding.test }} + timeout: 90 defaults: run: shell: pwsh diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index b8028957df..ec419f8253 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -16,7 +16,6 @@ #include "continuous_batching/timer.hpp" #include "utils.hpp" #include 
"visual_language/inputs_embedder.hpp" -#include "safe_tensor_wrapper.hpp" #include "json_utils.hpp" using namespace ov::genai; @@ -51,11 +50,6 @@ extract_eagle_mode_from_config(ov::AnyMap& config, const std::filesystem::path& OPENVINO_ASSERT(num_decoder_layers > 3, "num_decoder_layers is too small to deduce hidden layers for extraction"); eagle_rt_info.hidden_layers_list = { 2, num_decoder_layers / 2, num_decoder_layers - 3 }; } - if (config.find("dt_mapping_path") != config.end()) { - eagle_rt_info.dt_mapping_table = config.at("dt_mapping_path").as(); - eagle_rt_info.dt_mapping_table = eagle_rt_info.dt_mapping_table / "eagle3.safetensors"; - config.erase("dt_mapping_path"); - } } return eagle_rt_info; } @@ -107,13 +101,6 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); - // parse d2t from safe tensors - if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { - ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); - if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional - std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); - } - } } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); @@ -160,13 +147,6 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( // to be implemented future auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); - // parse d2t from safe tensors - if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { - ConstantMap constant_tensors = safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); - if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional - std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); - } - } } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); @@ -215,13 +195,6 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); - // parse d2t from safe tensors - if (std::filesystem::exists(eagle_rt_info.dt_mapping_table)) { - ConstantMap constant_tensors = 
safetensor_to_constant_map(ov::read_tensor_data(eagle_rt_info.dt_mapping_table)); - if (constant_tensors.find("d2t") != constant_tensors.end()) { // d2t map can be optional - std::dynamic_pointer_cast(m_impl)->set_d2t_for_draft_decoding(constant_tensors["d2t"]); - } - } } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index b3bfddbc05..65dec743bd 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -39,9 +39,6 @@ inline void apply_eagle_rt_info(std::shared_ptr& model, ov::AnyMap& p properties["eagle3_mode"] = true; if (model->has_rt_info("hidden_layers_list")) properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); - if (!mapping_path.empty()) { - properties["dt_mapping_path"] = mapping_path; // d2t mapping path - } } } diff --git a/src/cpp/src/lora/adapter.cpp b/src/cpp/src/lora/adapter.cpp index fb4e427b2f..0d50156a13 100644 --- a/src/cpp/src/lora/adapter.cpp +++ b/src/cpp/src/lora/adapter.cpp @@ -40,10 +40,13 @@ #include "openvino/genai/lora_adapter.hpp" #include "utils.hpp" -#include "safe_tensor_wrapper.hpp" #include "lora/common.hpp" #include "lora/names_mapping.hpp" +extern "C" { + #include "safetensors.h" +} + // FIXME: Remove or move to a dedicated common header #ifdef NDEBUG #define DEBUG_PRINT(X) do {} while(false) @@ -66,6 +69,57 @@ using ConstantVector = std::vector>; using LoRANode = LoRAParts>; using LoRAPartsParser = LoRAParts(const std::string& name)>>; +// Converts Safetensors element type to OV element type. Only part of the types are supported. +ov::element::Type safetensors_to_ov_element_type (int dtype) { + switch(dtype) { + case SAFETENSORS_F32: + return ov::element::f32; + case SAFETENSORS_F16: + return ov::element::f16; + case SAFETENSORS_BF16: + return ov::element::bf16; + default: + OPENVINO_THROW("Not supported safetensors dtype: ", dtype); + } +} + +using ConstantMap = std::map>; + +// Safetensor file parser that deallocates temporary buffers automatically. +// Drop-in replacement for the third party safetensors_File struct. +struct AutoSafetensor: public safetensors_File { + ~AutoSafetensor () { + std::free(tensors); + std::free(metadata); + } +}; + +// The key in the map is a tensor name and the Constant uses a region of memory from the memory block. +// Each Constant holds a shared pointer to the block in the runtime info. +// The memory block will be deallocated when the last Constant is destroyed. +ConstantMap safetensor_to_constant_map(const ov::Tensor& safetensor) { + AutoSafetensor safe_tensors_file{}; + + OPENVINO_ASSERT(safetensors_file_init(safetensor.data(), safetensor.get_byte_size(), &safe_tensors_file) == nullptr, + "Cannot parse safetensor as a Safetensors file format. 
Safetensors file format is supported only" + ); + + ConstantMap tensors; + for (int i = 0; i < safe_tensors_file.num_tensors; i++) { + safetensors_TensorDescriptor tensor = safe_tensors_file.tensors[i]; + std::string name(tensor.name.ptr, tensor.name.ptr + tensor.name.len); + ov::Shape shape(tensor.shape, tensor.shape + tensor.n_dimensions); + void* ptr = tensor.ptr; // FIXME: needs a non-constant pointer because Tensor doesn't accept a constant pointer + + auto type = safetensors_to_ov_element_type(tensor.dtype); + auto constant = + std::make_shared(type, shape, ptr, nullptr); // wraps existing memory, no ownership + constant->get_rt_info()["__safetensors_buffer_holder"] = safetensor; // to automatically deallocate underlying memory buffer when last constant that holds it is destroyed + tensors[name] = constant; + } + return tensors; +} + // Reads a file with a given filename expecting Safetensors file format. // The file data is mmaped to tensor. ConstantMap read_safetensors(const std::filesystem::path& filename) { @@ -1713,4 +1767,4 @@ void AdapterConfig::set_adapters_and_alphas(const std::vector(), safetensor.get_byte_size(), &safe_tensors_file) == nullptr, - "Cannot parse safetensor as a Safetensors file format. Safetensors file format is supported only" - ); - - ConstantMap tensors; - for (int i = 0; i < safe_tensors_file.num_tensors; i++) { - safetensors_TensorDescriptor tensor = safe_tensors_file.tensors[i]; - std::string name(tensor.name.ptr, tensor.name.ptr + tensor.name.len); - ov::Shape shape(tensor.shape, tensor.shape + tensor.n_dimensions); - void* ptr = tensor.ptr; // FIXME: needs a non-constant pointer because Tensor doesn't accept a constant pointer - - auto type = safetensors_to_ov_element_type(tensor.dtype); - auto constant = - std::make_shared(type, shape, ptr, nullptr); // wraps existing memory, no ownership - constant->get_rt_info()["__safetensors_buffer_holder"] = safetensor; // to automatically deallocate underlying memory buffer when last constant that holds it is destroyed - tensors[name] = constant; - } - return tensors; -} \ No newline at end of file diff --git a/src/cpp/src/safe_tensor_wrapper.hpp b/src/cpp/src/safe_tensor_wrapper.hpp deleted file mode 100644 index dbb5f2c292..0000000000 --- a/src/cpp/src/safe_tensor_wrapper.hpp +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (C) 2023-2025 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -#include "openvino/runtime/core.hpp" -#include "openvino/op/constant.hpp" - -extern "C" { - #include "safetensors.h" -} - -// Converts Safetensors element type to OV element type. Only part of the types are supported. -ov::element::Type safetensors_to_ov_element_type (int dtype); - -using ConstantMap = std::map>; - -// Safetensor file parser that deallocates temporary buffers automatically. -// Drop-in replacement for the third party safetensors_File struct. -struct AutoSafetensor: public safetensors_File { - ~AutoSafetensor () { - std::free(tensors); - std::free(metadata); - } -}; - -// The key in the map is a tensor name and the Constant uses a region of memory from the memory block. -// Each Constant holds a shared pointer to the block in the runtime info. -// The memory block will be deallocated when the last Constant is destroyed. 
-ConstantMap safetensor_to_constant_map(const ov::Tensor& safetensor); \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp index ffd8808b80..cd3d80f8a1 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp @@ -51,6 +51,16 @@ void share_embedding_weights(std::shared_ptr& main_model, std::shared } } +std::shared_ptr extract_d2t_mapping_table(std::shared_ptr& model) { + // extract result nodes from model + for (const auto& result : model->get_results()) { + auto input_node = result->input_value(0).get_node_shared_ptr(); + if (ov::is_type(input_node) && input_node->get_friendly_name().find("d2t") != std::string::npos) { + return ov::as_type_ptr(input_node); + } + } + return nullptr; +} void extract_hidden_state_generic(std::shared_ptr& model, const std::vector& hidden_layers_to_abstract) { ov::pass::Manager pm; @@ -317,7 +327,9 @@ ContinuousBatchingPipeline::Eagle3DecodingImpl::Eagle3DecodingImpl(const ov::gen m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; // specific params update for eagle pipeline - update_eagle_pipeline_params(); + // check draft_model, retrieve d2t table if exists + auto d2t_tensor = extract_d2t_mapping_table(draft_model); + update_eagle_pipeline_params(d2t_tensor); } ov::Tensor ContinuousBatchingPipeline::Eagle3DecodingImpl::create_draft_input_ids(const ov::Tensor& original_input_ids) { @@ -339,7 +351,7 @@ ov::Tensor ContinuousBatchingPipeline::Eagle3DecodingImpl::create_draft_input_id return draft_input_ids; } -void ContinuousBatchingPipeline::Eagle3DecodingImpl::update_eagle_pipeline_params() { +void ContinuousBatchingPipeline::Eagle3DecodingImpl::update_eagle_pipeline_params(std::shared_ptr& d2t_tensor) { auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); m_main_eagle_pipeline->set_hidden_state_export_needed(true); @@ -348,6 +360,7 @@ void ContinuousBatchingPipeline::Eagle3DecodingImpl::update_eagle_pipeline_param m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); m_draft_eagle_pipeline->set_adjust_factor( m_hidden_layers_to_abstract.size() > 0 ? 
m_hidden_layers_to_abstract.size() : 1); + m_draft_eagle_pipeline->set_d2t_for_draft_decoding(d2t_tensor); } GenerationHandle diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp index 4988586db2..012ec09af5 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp @@ -44,13 +44,8 @@ class ContinuousBatchingPipeline::Eagle3DecodingImpl : public ContinuousBatching GenerationHandle add_request(uint64_t request_id, const std::string& prompt, ov::genai::GenerationConfig sampling_params) override; - - void set_d2t_for_draft_decoding(std::shared_ptr& d2t_tensor) { - auto eagle_impl = std::static_pointer_cast(m_draft_pipeline); - eagle_impl->set_d2t_for_draft_decoding(d2t_tensor); - }; protected: - void update_eagle_pipeline_params(); + void update_eagle_pipeline_params(std::shared_ptr& d2t_tensor); ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); std::vector m_hidden_layers_to_abstract; }; diff --git a/tests/python_tests/test_eagle3.py b/tests/python_tests/test_eagle3.py index 9241e9d28c..80def5f6b0 100644 --- a/tests/python_tests/test_eagle3.py +++ b/tests/python_tests/test_eagle3.py @@ -22,7 +22,6 @@ def add(a, b): @pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) @pytest.mark.parametrize("main_device,draft_device", devices) @pytest.mark.precommit -@pytest.mark.skip(reason="CVS-174959 enable model conversion for eagle3 and enable the test") def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_device, prompt): # Download and convert model: main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) @@ -48,7 +47,6 @@ def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_dev @pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) @pytest.mark.parametrize("main_device,draft_device", devices) @pytest.mark.precommit -@pytest.mark.skip(reason="CVS-174959 enable model conversion for eagle3 and enable the test") def test_eagle3_sd_extended_perf_metrics(main_model, main_device, draft_model, draft_device, prompt): import time extended_perf_metrics = None From 3bcdfc31bcc90ba40e7f86755099c9f56058cee9 Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 24 Oct 2025 00:39:02 +0800 Subject: [PATCH 38/43] add missing line for unchanged file Signed-off-by: fishbell --- src/cpp/src/lora/adapter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/src/lora/adapter.cpp b/src/cpp/src/lora/adapter.cpp index 0d50156a13..2b186f3fad 100644 --- a/src/cpp/src/lora/adapter.cpp +++ b/src/cpp/src/lora/adapter.cpp @@ -1767,4 +1767,4 @@ void AdapterConfig::set_adapters_and_alphas(const std::vector Date: Fri, 24 Oct 2025 01:07:13 +0800 Subject: [PATCH 39/43] move test Signed-off-by: fishbell --- .github/workflows/linux.yml | 4 +- .github/workflows/windows.yml | 4 +- .../python_tests/test_continuous_batching.py | 70 ++++++++++++-- tests/python_tests/test_eagle3.py | 96 ------------------- 4 files changed, 65 insertions(+), 109 deletions(-) delete mode 100644 tests/python_tests/test_eagle3.py diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 3f63d3d9df..366eb56684 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -534,7 +534,7 @@ jobs: run_condition: ${{ 
fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }} timeout: 60 - name: 'API tests' - cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' + cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "not eagle3" ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }} timeout: 60 - name: 'Rag tests' @@ -554,7 +554,7 @@ jobs: - name: 'EAGLE3 speculative decoding tests' cmd: | python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f - python -m pytest -v ./tests/python_tests/test_eagle3.py' + python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3" run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).eagle3_speculative_decoding.test }} timeout: 90 defaults: diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index d0b9490484..45259c9b30 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -635,7 +635,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }} timeout: 60 - name: 'API tests' - cmd: 'python -m pytest -s -v tests/python_tests/test_continuous_batching.py tests/python_tests/test_generation_config.py tests/python_tests/test_sampling.py tests/python_tests/test_text_streamer.py' + cmd: 'python -m pytest -s -v tests/python_tests/test_continuous_batching.py -k "not eagle3" tests/python_tests/test_generation_config.py tests/python_tests/test_sampling.py tests/python_tests/test_text_streamer.py' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }} timeout: 60 - name: 'Rag tests' @@ -655,7 +655,7 @@ jobs: - name: 'EAGLE3 speculative decoding tests' cmd: | python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f - python -m pytest -v ./tests/python_tests/test_eagle3.py' + python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3" run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).eagle3_speculative_decoding.test }} timeout: 90 defaults: diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index 90ea7ac316..2cdf3198c8 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -16,8 +16,9 @@ from utils.generation_config import get_greedy, get_beam_search, \ get_multinomial_all_parameters, get_multinomial_temperature_and_num_return_sequence, \ get_multinomial_temperature_and_top_k, get_multinomial_temperature, get_multinomial_temperature_and_top_p -from utils.hugging_face import download_and_convert_model -from utils.ov_genai_pipelines import create_ov_pipeline, create_ov_cb_pipeline, PipelineType, dict_to_scheduler_config, generate_and_compare, prepare_generation_config_by_pipe_type, 
GenerationChatInputsType +from utils.hugging_face import download_and_convert_model, run_hugging_face +from utils.ov_genai_pipelines import create_ov_pipeline, create_ov_cb_pipeline, PipelineType, dict_to_scheduler_config, generate_and_compare, prepare_generation_config_by_pipe_type, convert_decoded_results_to_generation_result, GenerationChatInputsType +from utils.comparation import compare_generation_results from data.models import get_chat_models_list from data.test_dataset import get_test_dataset @@ -489,22 +490,45 @@ def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_c return pipe, prompt, generation_config -def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType): +def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType, draft_model_id): _, _, model_path = download_and_convert_model(model_id) - ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type) + draft_model_path = None + if draft_model_id is not None: + _,_, draft_model_path = download_and_convert_model(draft_model_id) + ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type, draft_model_path = draft_model_path) return ov_pipe.generate([prompt], generation_config).extended_perf_metrics +eagle_models_and_input = [ + ("Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3", """Code: +def add(a, b): + return a + b +Question: Can you please add 2 and 3 +A:""")] + +speculative_cases = [ + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None, "Why is the Sun yellow?"), + eagle_models_and_input[0], +] @pytest.mark.parametrize("pipeline_type", [PipelineType.PAGED_ATTENTION, PipelineType.SPECULATIVE_DECODING]) +@pytest.mark.parametrize("main_model_id,draft_model_id, prompt", speculative_cases) @pytest.mark.precommit -def test_speculative_decoding_extended_perf_metrics(pipeline_type): +def test_speculative_decoding_extended_perf_metrics(pipeline_type, main_model_id, draft_model_id, prompt): import time start_time = time.perf_counter() - model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) - extended_perf_metrics = run_extended_perf_metrics_collection(model_id, generation_config, "Why is the Sun yellow?", pipeline_type) - total_time = (time.perf_counter() - start_time) * 1000 + extended_perf_metrics = None + if draft_model_id is None: + generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + extended_perf_metrics = run_extended_perf_metrics_collection(main_model_id, generation_config, prompt, pipeline_type, draft_model_id) + total_time = (time.perf_counter() - start_time) * 1000 + else: + if (pipeline_type == PipelineType.SPECULATIVE_DECODING): + generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + extended_perf_metrics = run_extended_perf_metrics_collection(main_model_id, generation_config, prompt, pipeline_type, draft_model_id) + total_time = (time.perf_counter() - start_time) * 1000 + + if (pipeline_type == PipelineType.SPECULATIVE_DECODING): assert not extended_perf_metrics is None assert not extended_perf_metrics.main_model_metrics is None @@ -542,3 +566,31 @@ def test_speculative_decoding_extended_perf_metrics(pipeline_type): assert std_gen_duration == 0 else: assert extended_perf_metrics is None + +devices = [ + 
('CPU', 'CPU') +] +@pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) +@pytest.mark.parametrize("main_device,draft_device", devices) +@pytest.mark.precommit +def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_device, prompt): + # Download and convert model: + main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) + __, __, draft_model_path = download_and_convert_model(draft_model) + + # Create OpenVINO GenAI pipeline: + + ov_pipe = create_ov_pipeline(main_model_path, pipeline_type = PipelineType.SPECULATIVE_DECODING, draft_model_path = draft_model_path) + + # Run reference HF model: + ov_generation_config = GenerationConfig(max_new_tokens=20) + ref_gen_results = run_hugging_face(main_opt_model, main_hf_tokenizer, [prompt], ov_generation_config) + + # Run OpenVINO GenAI pipeline: + ov_decoded_results = ov_pipe.generate([prompt], ov_generation_config) + ov_gen_results = convert_decoded_results_to_generation_result(ov_decoded_results, 1, 1, False) + + del ov_pipe + + # Compare results: + compare_generation_results([prompt], ref_gen_results, ov_gen_results, ov_generation_config) \ No newline at end of file diff --git a/tests/python_tests/test_eagle3.py b/tests/python_tests/test_eagle3.py deleted file mode 100644 index 80def5f6b0..0000000000 --- a/tests/python_tests/test_eagle3.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (C) 2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -import pytest -import openvino_genai as ov_genai -from utils.hugging_face import download_and_convert_model, run_hugging_face -from utils.comparation import compare_generation_results -from utils.ov_genai_pipelines import create_ov_pipeline, PipelineType, convert_decoded_results_to_generation_result - -eagle_models_and_input = [ - ("Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3", """Code: -def add(a, b): - return a + b - -Question: Can you please add 2 and 3 -A:""")] - -devices = [ - ('CPU', 'CPU') -] -@pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) -@pytest.mark.parametrize("main_device,draft_device", devices) -@pytest.mark.precommit -def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_device, prompt): - # Download and convert model: - main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) - __, __, draft_model_path = download_and_convert_model(draft_model) - - # Create OpenVINO GenAI pipeline: - - ov_pipe = create_ov_pipeline(main_model_path, pipeline_type = PipelineType.SPECULATIVE_DECODING, draft_model_path = draft_model_path) - - # Run reference HF model: - ov_generation_config = ov_genai.GenerationConfig(max_new_tokens=20) - ref_gen_results = run_hugging_face(main_opt_model, main_hf_tokenizer, [prompt], ov_generation_config) - - # Run OpenVINO GenAI pipeline: - ov_decoded_results = ov_pipe.generate([prompt], ov_generation_config) - ov_gen_results = convert_decoded_results_to_generation_result(ov_decoded_results, 1, 1, False) - - del ov_pipe - - # Compare results: - compare_generation_results([prompt], ref_gen_results, ov_gen_results, ov_generation_config) - -@pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) -@pytest.mark.parametrize("main_device,draft_device", devices) -@pytest.mark.precommit -def test_eagle3_sd_extended_perf_metrics(main_model, main_device, draft_model, draft_device, prompt): - import time - extended_perf_metrics = None - start_time = time.perf_counter() - 
generation_config = ov_genai.GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) - # Download and convert model: - main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) - __, __, draft_model_path = download_and_convert_model(draft_model) - - # Create OpenVINO GenAI pipeline: - ov_pipe = create_ov_pipeline(main_model_path, pipeline_type = PipelineType.SPECULATIVE_DECODING, draft_model_path = draft_model_path) - extended_perf_metrics = ov_pipe.generate([prompt], generation_config).extended_perf_metrics - total_time = (time.perf_counter() - start_time) * 1000 - - assert not extended_perf_metrics is None - assert not extended_perf_metrics.main_model_metrics is None - assert not extended_perf_metrics.draft_model_metrics is None - assert extended_perf_metrics.get_num_accepted_tokens() > 0 - - num_generated_tokens_main = extended_perf_metrics.main_model_metrics.get_num_generated_tokens() - assert num_generated_tokens_main > 0 and num_generated_tokens_main <= generation_config.max_new_tokens - - # max num_generated_tokens for draft model will be reached if it will generate num_assistant_tokens at each step - # plus fist token, which was generated by main model - num_generated_tokens_draft = extended_perf_metrics.draft_model_metrics.get_num_generated_tokens() - assert num_generated_tokens_draft > 0 and num_generated_tokens_draft < ((generation_config.max_new_tokens - 1) * generation_config.num_assistant_tokens + 1) - - total_iteration_number_main = len(extended_perf_metrics.main_model_metrics.raw_metrics.m_durations) - assert total_iteration_number_main > 0 and total_iteration_number_main <= generation_config.max_new_tokens - - total_iteration_number_draft = len(extended_perf_metrics.draft_model_metrics.raw_metrics.m_durations) - assert total_iteration_number_draft > 0 and total_iteration_number_draft < ((generation_config.max_new_tokens - 1) * generation_config.num_assistant_tokens + 1) - - for model_metrics in [extended_perf_metrics.main_model_metrics, extended_perf_metrics.draft_model_metrics]: - mean_ttst, std_ttst = model_metrics.get_ttst() - assert (mean_ttst, std_ttst) == (model_metrics.get_ttst().mean, model_metrics.get_ttst().std) - assert mean_ttst > 0 and mean_ttst < model_metrics.get_ttft().mean - assert std_ttst == 0 - - mean_latency, std_latency = model_metrics.get_latency() - assert (mean_latency, std_latency) == (model_metrics.get_latency().mean, model_metrics.get_latency().std) - assert mean_latency > 0 and mean_latency < 1000.0 - - mean_gen_duration, std_gen_duration = model_metrics.get_generate_duration() - assert (mean_gen_duration, std_gen_duration) == (model_metrics.get_generate_duration().mean, model_metrics.get_generate_duration().std) - assert mean_gen_duration > 0 and mean_gen_duration < total_time - assert std_gen_duration == 0 \ No newline at end of file From fc2d11f8dc89597c294dd9033df4303a91196461 Mon Sep 17 00:00:00 2001 From: fishbell Date: Fri, 24 Oct 2025 16:49:36 +0800 Subject: [PATCH 40/43] try trigger test Signed-off-by: fishbell enable test Signed-off-by: fishbell --- .github/workflows/linux.yml | 2 +- .github/workflows/manylinux_2_28.yml | 8 +++++++- .github/workflows/windows.yml | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 366eb56684..06835962ba 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -555,7 +555,7 @@ jobs: cmd: | python -m pip install 
               python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3"
-            run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).eagle3_speculative_decoding.test }}
+            run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }}
             timeout: 90
     defaults:
       run:
diff --git a/.github/workflows/manylinux_2_28.yml b/.github/workflows/manylinux_2_28.yml
index c9a7f26ec3..f1e9252c2e 100644
--- a/.github/workflows/manylinux_2_28.yml
+++ b/.github/workflows/manylinux_2_28.yml
@@ -472,7 +472,7 @@ jobs:
             run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }}
             timeout: 60
           - name: 'API tests'
-            cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py'
+            cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "not eagle3" ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py'
             run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }}
             timeout: 60
           - name: 'Rag tests'
@@ -489,6 +489,12 @@ jobs:
               python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava
             run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }}
             timeout: 90
+          - name: 'EAGLE3 speculative decoding tests'
+            cmd: |
+              python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f
+              python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3"
+            run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }}
+            timeout: 90
     defaults:
       run:
         shell: bash
diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml
index 45259c9b30..38a5821fec 100644
--- a/.github/workflows/windows.yml
+++ b/.github/workflows/windows.yml
@@ -656,7 +656,7 @@ jobs:
             cmd: |
               python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f
               python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3"
-            run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).eagle3_speculative_decoding.test }}
+            run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }}
             timeout: 90
     defaults:
       run:

From 1903c31079f66af49442a26850ca83779bcd5d51 Mon Sep 17 00:00:00 2001
From: fishbell
Date: Mon, 27 Oct 2025 22:11:31 +0800
Subject: [PATCH 41/43] upgrade version

Signed-off-by: fishbell
---
 .github/workflows/linux.yml          | 2 +-
 .github/workflows/manylinux_2_28.yml | 2 +-
 .github/workflows/windows.yml        | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml
index 7a3461419d..10aa490c29 100644
--- a/.github/workflows/linux.yml
+++ b/.github/workflows/linux.yml
@@ -553,7 +553,7 @@ jobs:
             timeout: 90
           - name: 'EAGLE3 speculative decoding tests'
             cmd: |
-              python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f
+              python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23
               python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3"
             run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }}
             timeout: 90
diff --git a/.github/workflows/manylinux_2_28.yml b/.github/workflows/manylinux_2_28.yml
index c3de2b2632..50e6257056 100644
--- a/.github/workflows/manylinux_2_28.yml
+++ b/.github/workflows/manylinux_2_28.yml
@@ -491,7 +491,7 @@ jobs:
             timeout: 90
           - name: 'EAGLE3 speculative decoding tests'
             cmd: |
-              python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f
+              python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23
               python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3"
             run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }}
             timeout: 90
diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml
index d72f3ee254..6b2e7cf8dc 100644
--- a/.github/workflows/windows.yml
+++ b/.github/workflows/windows.yml
@@ -654,7 +654,7 @@ jobs:
             timeout: 90
           - name: 'EAGLE3 speculative decoding tests'
             cmd: |
-              python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@e67abb1a20fb190b39c1dc0216cddb65b300210f
+              python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23
               python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3"
             run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }}
             timeout: 90

From 3e5ec4323de01aa3d5658b29ebdb4e67f41574fb Mon Sep 17 00:00:00 2001
From: fishbell
Date: Wed, 29 Oct 2025 18:04:48 +0800
Subject: [PATCH 42/43] apply review comments

Signed-off-by: fishbell
---
 src/cpp/src/continuous_batching/pipeline.cpp  |  1 +
 .../speculative_decoding_eagle3_impl.cpp      |  4 +-
 .../samples/test_speculative_decoding_lm.py   | 55 +++++++------
 .../python_tests/test_continuous_batching.py  |  2 +-
 tests/python_tests/utils/hugging_face.py      |  6 +-
 5 files changed, 26 insertions(+), 42 deletions(-)

diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp
index e143d0497d..666a5a530d 100644
--- a/src/cpp/src/continuous_batching/pipeline.cpp
+++ b/src/cpp/src/continuous_batching/pipeline.cpp
@@ -48,6 +48,7 @@ extract_eagle_mode_from_config(ov::AnyMap& config, const std::filesystem::path&
         int num_decoder_layers = 0;
         read_json_param(data, "num_hidden_layers", num_decoder_layers);
         OPENVINO_ASSERT(num_decoder_layers > 3, "num_decoder_layers is too small to deduce hidden layers for extraction");
+        // below corresponds to : https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/modeling_llama_kv.py#L1138
         eagle_rt_info.hidden_layers_list = { 2, num_decoder_layers / 2, num_decoder_layers - 3 };
     }
 }
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp
index 08abff7c33..04b539b9dc 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp
@@ -45,9 +45,9 @@ void share_embedding_weights(std::shared_ptr& main_model, std::shared
     try {
         draft_weight_node->output(0).replace(main_weight_node->output(0));
     } catch (const std::exception& e) {
-        std::cerr << "Error: failed to import embedding weights from main model to draft model. Exception: " << e.what() << std::endl;
+        Logger::warn(std::string("Error: failed to import embedding weights from main model to draft model. Exception: ") + e.what());
     } catch (...) {
-        std::cerr << "Error: failed to import embedding weights from main model to draft model due to unknown exception." << std::endl;
+        Logger::warn("Error: failed to import embedding weights from main model to draft model due to unknown exception.");
     }
 }
diff --git a/tests/python_tests/samples/test_speculative_decoding_lm.py b/tests/python_tests/samples/test_speculative_decoding_lm.py
index 09087b12e4..0b4ef570a2 100644
--- a/tests/python_tests/samples/test_speculative_decoding_lm.py
+++ b/tests/python_tests/samples/test_speculative_decoding_lm.py
@@ -10,6 +10,23 @@ convert_draft_model = convert_model
 
+def _run_spec_case(convert_model, convert_draft_model, sample_args, env):
+    cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'speculative_decoding_lm')
+    cpp_command =[cpp_sample, convert_model, convert_draft_model, sample_args]
+    cpp_result = run_sample(cpp_command, env=env)
+
+    py_script = os.path.join(SAMPLES_PY_DIR, "text_generation/speculative_decoding_lm.py")
+    py_command = [sys.executable, py_script, convert_model, convert_draft_model, sample_args]
+    py_result = run_sample(py_command, env=env)
+
+    cpp_sample_ref = os.path.join(SAMPLES_CPP_DIR, 'greedy_causal_lm')
+    cpp_command_ref = [cpp_sample_ref, convert_model, sample_args]
+    cpp_result_ref = run_sample(cpp_command_ref, env=env)
+
+    assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match"
+    assert cpp_result_ref.stdout.strip() in cpp_result.stdout.strip(), "Greedy and speculative decoding results should match"
+    return cpp_result, py_result, cpp_result_ref
+
 class TestSpeculativeDecodingLM:
     @pytest.mark.llm
     @pytest.mark.samples
@@ -26,24 +43,7 @@ def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model
             pytest.xfail("Ticket 173586")
         env = os.environ.copy()
         env["OPENVINO_LOG_LEVEL"] = "0"
-        # Test CPP sample
-        cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'speculative_decoding_lm')
-        cpp_command =[cpp_sample, convert_model, convert_draft_model, sample_args]
-        cpp_result = run_sample(cpp_command, env=env)
-
-        # Test Python sample
-        py_script = os.path.join(SAMPLES_PY_DIR, "text_generation/speculative_decoding_lm.py")
-        py_command = [sys.executable, py_script, convert_model, convert_draft_model, sample_args]
-        py_result = run_sample(py_command, env=env)
-
-        # Greedy decoding
-        cpp_sample_ref = os.path.join(SAMPLES_CPP_DIR, 'greedy_causal_lm')
-        cpp_command_ref = [cpp_sample_ref, convert_model, sample_args]
-        cpp_result_ref = run_sample(cpp_command_ref, env=env)
-
-        # Compare results
-        assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match"
-        assert cpp_result_ref.stdout.strip() in cpp_result.stdout.strip(), "Greedy and speculative decoding results should match"
+        _run_spec_case(convert_model, convert_draft_model, sample_args, env)
 
     test_prompt = """Code:
 def add(a, b):
     return a + b
 
 Question: Can you please add 2 and 3
 A:"""
@@ -66,21 +66,4 @@ def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model
             pytest.xfail("Ticket 173586")
         env = os.environ.copy()
         env["OPENVINO_LOG_LEVEL"] = "0"
-        # Test CPP sample
-        cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'speculative_decoding_lm')
-        cpp_command =[cpp_sample, convert_model, convert_draft_model, sample_args]
-        cpp_result = run_sample(cpp_command, env=env)
-
-        # Test Python sample
-        py_script = os.path.join(SAMPLES_PY_DIR, "text_generation/speculative_decoding_lm.py")
-        py_command = [sys.executable, py_script, convert_model, convert_draft_model, sample_args]
-        py_result = run_sample(py_command, env=env)
-
-        # Greedy decoding
-        cpp_sample_ref = os.path.join(SAMPLES_CPP_DIR, 'greedy_causal_lm')
-        cpp_command_ref = [cpp_sample_ref, convert_model, sample_args]
-        cpp_result_ref = run_sample(cpp_command_ref, env=env)
-
-        # Compare results
-        assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match"
-        assert cpp_result_ref.stdout.strip() in cpp_result.stdout.strip(), "Greedy and speculative decoding results should match"
\ No newline at end of file
+        _run_spec_case(convert_model, convert_draft_model, sample_args, env)
\ No newline at end of file
diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py
index 2cdf3198c8..d034c5f728 100644
--- a/tests/python_tests/test_continuous_batching.py
+++ b/tests/python_tests/test_continuous_batching.py
@@ -490,7 +490,7 @@ def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_c
     return pipe, prompt, generation_config
 
 
-def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType, draft_model_id):
+def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType, draft_model_id: str):
     _, _, model_path = download_and_convert_model(model_id)
     draft_model_path = None
     if draft_model_id is not None:
diff --git a/tests/python_tests/utils/hugging_face.py b/tests/python_tests/utils/hugging_face.py
index d10dd2b693..2c4884cf28 100644
--- a/tests/python_tests/utils/hugging_face.py
+++ b/tests/python_tests/utils/hugging_face.py
@@ -166,13 +166,13 @@ def run_hugging_face(
 
 # download HF model or read converted model
 def get_huggingface_models(model_id: str | Path, model_class: Type[OVModel], local_files_only=False):
-    if (not "eagle3" in str(model_id).lower()):
+    if "eagle3" not in str(model_id).lower():
         hf_tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, local_files_only=local_files_only))
         opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only))
         return opt_model, hf_tokenizer
-    else :
+    else:
         hf_tokenizer = None
-        opt_model = retry_request(lambda: model_class.from_pretrained(model_id, eagle3 = True, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only))
+        opt_model = retry_request(lambda: model_class.from_pretrained(model_id, eagle3=True, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only))
         return opt_model, hf_tokenizer

From 433d8b33c44b0b54749a808d18c586f640a7d9fb Mon Sep 17 00:00:00 2001
From: fishbell
Date: Thu, 30 Oct 2025 22:43:03 +0800
Subject: [PATCH 43/43] revert the tokenizer params

Signed-off-by: fishbell
---
 tools/llm_bench/benchmark.py            |  2 --
 tools/llm_bench/task/text_generation.py | 12 ++----------
 2 files changed, 2 insertions(+), 12 deletions(-)

diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py
index 4004a74e06..b4e166121c 100644
--- a/tools/llm_bench/benchmark.py
+++ b/tools/llm_bench/benchmark.py
@@ -170,8 +170,6 @@ def get_argprser():
                         help="Path to file with Continuous Batching Scheduler settings or dict for Speculative decoding of draft model")
     parser.add_argument("--num_assistant_tokens", required=False, default=None,
                         help="Config option num_assistant_tokens for Speculative decoding and Prompt Lookup decoding", type=int)
-    parser.add_argument("--eagle3_mode", action="store_true",
-                        help="flag to indicate whether to use eagle3 for speculative decoding")
     parser.add_argument("--assistant_confidence_threshold", required=False, default=None,
                         help="Config option assistant_confidence_threshold for Speculative decoding", type=float)
     parser.add_argument("--max_ngram_size", required=False, default=None,
diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py
index bf9c9cfff8..5c1f6ab7c7 100644
--- a/tools/llm_bench/task/text_generation.py
+++ b/tools/llm_bench/task/text_generation.py
@@ -284,11 +284,7 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data
     )
 
     tokenization_start = time.perf_counter()
-    if args.get("eagle3_mode"):
-        # eagle3 needs to disable special tokens to ensure compress rate
-        input_data = tokenizer.encode(input_text_list, add_special_tokens=False)
-    else:
-        input_data = tokenizer.encode(input_text_list)
+    input_data = tokenizer.encode(input_text_list)
     tokenization_end = time.perf_counter()
     tokenization_time = [(tokenization_end - tokenization_start) * 1000]
@@ -455,11 +451,7 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg
             "If it is not expected, please specify --disable_prompt_permutation in your benchmarking command to disable this behavior"
         )
     tok_encode_start = time.perf_counter()
-    if args.get("eagle3_mode"):
-        # eagle3 needs to disable special tokens to ensure compress rate
-        input_data = pipe_tokenizer.encode(input_text_list, add_special_tokens=False)
-    else:
-        input_data = pipe_tokenizer.encode(input_text_list)
+    input_data = pipe_tokenizer.encode(input_text_list)
     tok_encode_end = time.perf_counter()
     input_token_size = input_data.input_ids.shape[1]
     tok_encode_time = (tok_encode_end - tok_encode_start) * 1000
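
For reference, a minimal Python sketch of the EAGLE3 hidden-layer selection rule that patch 42 annotates in pipeline.cpp (the same rule as the linked SafeAILab EAGLE reference). This is illustrative only, not part of the patch series or the GenAI API; the function name and the config.json lookup are assumptions.

# Illustrative sketch (not part of this patch series): the layer-selection rule
# used by extract_eagle_mode_from_config() to decide which decoder layers'
# hidden states are extracted from the main model for the EAGLE3 draft model.
import json
from pathlib import Path

def eagle3_hidden_layers(main_model_dir: str) -> list[int]:
    # num_hidden_layers is read from the main model's config.json, as in the patch.
    config = json.loads((Path(main_model_dir) / "config.json").read_text())
    num_hidden_layers = config["num_hidden_layers"]
    # The patch asserts num_hidden_layers > 3 before deducing the extraction layers.
    if num_hidden_layers <= 3:
        raise ValueError("num_decoder_layers is too small to deduce hidden layers for extraction")
    # Same selection as https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/modeling_llama_kv.py#L1138
    return [2, num_hidden_layers // 2, num_hidden_layers - 3]

# Example: a 28-layer main model yields layers [2, 14, 25].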