diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index c031b22b87..d565cb84ae 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -534,7 +534,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }} timeout: 60 - name: 'API tests' - cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' + cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "not eagle3" ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }} timeout: 60 - name: 'Rag tests' @@ -551,6 +551,12 @@ jobs: python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }} timeout: 90 + - name: 'EAGLE3 speculative decoding tests' + cmd: | + python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23 + python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3" + run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }} + timeout: 90 defaults: run: shell: bash diff --git a/.github/workflows/manylinux_2_28.yml b/.github/workflows/manylinux_2_28.yml index 38fcd6d4e8..5cacec3d81 100644 --- a/.github/workflows/manylinux_2_28.yml +++ b/.github/workflows/manylinux_2_28.yml @@ -472,7 +472,7 @@ jobs: run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }} timeout: 60 - name: 'API tests' - cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' + cmd: 'python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "not eagle3" ./tests/python_tests/test_generation_config.py ./tests/python_tests/test_sampling.py ./tests/python_tests/test_text_streamer.py' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }} timeout: 60 - name: 'Rag tests' @@ -489,6 +489,12 @@ jobs: python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }} timeout: 90 + - name: 'EAGLE3 speculative decoding tests' + cmd: | + python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23 + python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3" + run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }} + timeout: 90 defaults: run: shell: bash diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index b0587ec758..f5c859b95a 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -623,7 +623,7 @@ jobs: run_condition: ${{ 
fromJSON(needs.smart_ci.outputs.affected_components).tokenizers.test }} timeout: 60 - name: 'API tests' - cmd: 'python -m pytest -s -v tests/python_tests/test_continuous_batching.py tests/python_tests/test_generation_config.py tests/python_tests/test_sampling.py tests/python_tests/test_text_streamer.py' + cmd: 'python -m pytest -s -v tests/python_tests/test_continuous_batching.py -k "not eagle3" tests/python_tests/test_generation_config.py tests/python_tests/test_sampling.py tests/python_tests/test_text_streamer.py' run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test || fromJSON(needs.smart_ci.outputs.affected_components).sampling.test || fromJSON(needs.smart_ci.outputs.affected_components).text_streamer.test }} timeout: 60 - name: 'Rag tests' @@ -640,6 +640,12 @@ jobs: python -m pytest -v ./tools/who_what_benchmark/tests -m nanollava run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).WWB.test }} timeout: 90 + - name: 'EAGLE3 speculative decoding tests' + cmd: | + python -m pip install git+https://github.com/xufang-lisa/optimum-intel.git@ea9607daf32919024cdd4390deec9693a7b64d23 + python -m pytest -v ./tests/python_tests/test_continuous_batching.py -k "eagle3" + run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).speculative_decoding.test }} + timeout: 90 defaults: run: shell: pwsh diff --git a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp index 6a5486e591..67a522d2fc 100644 --- a/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp +++ b/src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp @@ -65,13 +65,18 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline { class ContinuousBatchingImpl; class ContinuousBatchingForSpeculativeDecodingImpl; + class ContinuousBatchingForEagle3DecodingImpl; class ContinuousBatchingForPromptLookupImpl; class SpeculativeDecodingImpl; + class Eagle3DecodingImpl; class PromptLookupImpl; friend class ContinuousBatchingForSpeculativeDecodingImpl; + friend class ContinuousBatchingForPromptLookupImpl; + friend class ContinuousBatchingForEagle3DecodingImpl; friend class SpeculativeDecodingImpl; + friend class Eagle3DecodingImpl; friend class PromptLookupImpl; std::shared_ptr m_impl; diff --git a/src/cpp/src/continuous_batching/model_runner.hpp b/src/cpp/src/continuous_batching/model_runner.hpp index 2d8e814922..fded073f97 100644 --- a/src/cpp/src/continuous_batching/model_runner.hpp +++ b/src/cpp/src/continuous_batching/model_runner.hpp @@ -24,6 +24,27 @@ inline std::string get_paged_attention_score_output_for_decoder_layer(size_t dec return ss.str(); } +enum HiddenStateFlags : uint8_t { + HS_NONE = 0, + HS_EXPORT = 1 << 0, + HS_IMPORT = 1 << 1, + HS_INTERNAL = 1 << 2 +}; + +struct SequenceKey { + size_t request_id{}; + size_t grouped_sequence_id{}; + bool operator<(const SequenceKey& other) const { + return std::tie(request_id, grouped_sequence_id) < + std::tie(other.request_id, other.grouped_sequence_id); + } +}; + +struct HiddenStateRange { + size_t start_token_idx{}; + size_t length{}; +}; + /** * @brief Runs the LLM infer request, parsing the continuous batching scheduler output into proper inputs in terms of OV API (e.g. token input IDs, * KV cache block indices etc.) and returning the logit scores for the next token to be generated for each of the currently scheduled sequences. @@ -48,6 +69,11 @@ class ModelRunner { // Input shape: [N, conversation length]. 
// Output shape: [1, conversation length, hidden_size]. EmbeddingsModel::Ptr m_embedding; + uint8_t m_hidden_state_flags = HS_NONE; + std::map m_sequence_hidden_state_mapping; + // a container which use sequence group id and request id as key to store hidden states + std::map m_initial_hidden_states; // shape: [N, seq_len, hidden_size] + size_t m_adjust_factor = 1; // to adjust the hidden size of draft model input std::shared_ptr m_inputs_embedder; @@ -107,11 +133,19 @@ class ModelRunner { return m_request; } + void enable_hidden_state_export(bool on) { on ? m_hidden_state_flags |= HS_EXPORT : m_hidden_state_flags &= ~HS_EXPORT; } + void enable_hidden_state_import(bool on) { on ? m_hidden_state_flags |= HS_IMPORT : m_hidden_state_flags &= ~HS_IMPORT; } + void enable_hidden_state_internal(bool on) { on ? m_hidden_state_flags |= HS_INTERNAL : m_hidden_state_flags &= ~HS_INTERNAL; } + void set_inputs_embedder(const std::shared_ptr& inputs_embedder) { m_inputs_embedder = inputs_embedder; m_embedding = inputs_embedder->get_embedding_model(); } + void set_adjust_factor(size_t adjust_factor) { + m_adjust_factor = adjust_factor; + } + /** * @return A map of sequence IDs to vectors of ov::Tensor per-token attention scores. Each vector element is associated with its own * decoder layer, in order of their execution in the model. Each ov::Tensor has a shape of {N_k}, where N_k is the length of @@ -134,6 +168,10 @@ class ModelRunner { m_cache_rotation_deltas_for_each_layer = std::move(rotation_deltas_for_each_layer); } + void set_initial_hidden_state(uint64_t request_id, const ov::Tensor& hidden_state) { + m_initial_hidden_states[request_id] = hidden_state; + } + /** * Runs the forward inference call on the underlying LLM's ov::InferRequest, scheduling for inferencing tokens for given sequences * taking into account the supplied scheduler output struct. @@ -142,6 +180,7 @@ class ModelRunner { * @return An ov::Tensor with next-token logit scores for each sequence processed during this `forward` call. 
*/ ov::Tensor forward(const std::vector & sequence_groups, const Scheduler::Output& scheduler_output) { + m_sequence_hidden_state_mapping.clear(); size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size(); size_t batch_size_in_sequences = 0; @@ -185,6 +224,12 @@ class ModelRunner { ov::Tensor score_aggregation_window = _get_or_resize_tensor(m_cached_score_aggregation_window, "score_aggregation_window", {batch_size_in_sequences}, ov::element::i32); + ov::Tensor hidden_state_input = _prepare_hidden_state_input(total_num_tokens, hidden_size); + float* hidden_state_data = nullptr; + if (hidden_state_input) { + hidden_state_data = hidden_state_input.data(); + } + ov::Tensor generated_ids_embeds; float *generated_ids_embeds_data = nullptr; @@ -234,6 +279,7 @@ class ModelRunner { matmul_gathering_is_available = true; } catch (const ov::Exception&) {} + size_t current_token_idx = 0; std::map> seq_id_to_skipped_blocks_map; size_t position_ids_idx = 0; for (size_t i = 0; i < num_sequence_groups; ++i) { @@ -265,6 +311,64 @@ class ModelRunner { output_seq_len = 0; Sequence::CPtr sequence = running_sequences[seq_idx]; + if (_is_hs_export()) { + size_t start_token_idx = current_token_idx; + size_t sequence_length = num_scheduled_tokens; + + SequenceKey key{sequence_group->get_request_id(), sequence->get_grouped_id()}; + m_sequence_hidden_state_mapping[key] = HiddenStateRange{start_token_idx, sequence_length}; + } + if (_is_hs_import()) { + auto it = m_initial_hidden_states.find(sequence_group->get_request_id()); + + if (it != m_initial_hidden_states.end()) { + const auto& stored_hidden_state = it->second; + + if (stored_hidden_state.get_size() > 0) { + auto stored_shape = stored_hidden_state.get_shape(); + + if (stored_shape.size() >= 2) { + size_t stored_seq_len = stored_shape[0]; + size_t stored_hidden_size = stored_shape[stored_shape.size() - 1]; + + if (stored_hidden_size == hidden_size) { + if (stored_seq_len == total_num_tokens) { + hidden_state_input = stored_hidden_state; // all tokens from eagle are accepted + } else { + size_t copy_length = std::min(stored_seq_len, num_scheduled_tokens); + + size_t source_start_idx = + stored_seq_len >= copy_length ? stored_seq_len - copy_length : 0; + _copy_roi_between_tensors(stored_hidden_state, source_start_idx, copy_length, hidden_state_input, current_token_idx); + } + } + } + } else { + OPENVINO_ASSERT(false, "missing hidden state from target model to eagle draft model"); + } + } + } else if (_is_hs_internal()) { + // fill hidden_state_data with m_hidden_states + if (hidden_state_data) { + OPENVINO_ASSERT(num_scheduled_tokens == 1, "unexpected num_scheduled_tokens in speculative drafting stage in eagle3 mode"); + std::memset(hidden_state_data + current_token_idx * hidden_size, + 0, + num_scheduled_tokens * hidden_size * sizeof(float)); + auto hidden_state = running_sequences[seq_idx]->get_hidden_state(); + if (hidden_state.get_size() > 0) { + auto shape = hidden_state.get_shape(); + if (shape.size() >= 2 && shape[shape.size() - 1] == hidden_size) { + size_t seq_len = shape[0]; + size_t copy_length = std::min(seq_len, num_scheduled_tokens); + + size_t src_start_idx = seq_len >= copy_length ? 
seq_len - copy_length : 0; + auto target_shape = ov::Shape{num_scheduled_tokens, 1, hidden_size}; + ov::Tensor target_base(ov::element::f32, target_shape, hidden_state_data + current_token_idx * hidden_size); + _copy_roi_between_tensors(hidden_state, src_start_idx, copy_length, target_base, 0); + } + } + } + } for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) { // compute token for current sequence if (sequence_group_type == SequenceGroupType::TOKENS) { @@ -343,6 +447,7 @@ class ModelRunner { *score_aggregation_window_data = 1; } } + current_token_idx += num_scheduled_tokens; past_lens_data += 1; subsequence_begins_data += 1; block_indices_begins_data += 1; @@ -367,6 +472,31 @@ class ModelRunner { m_request.set_tensor("token_type_ids", token_type_ids); } } + if (hidden_state_input && hidden_state_input.get_size() > 0) { + if (_is_hs_import()) { + try { + m_request.set_tensor("hidden_states", hidden_state_input); + auto shape = hidden_state_input.get_shape(); + shape[shape.size() - 1] = shape[shape.size() - 1] / m_adjust_factor; + ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); + auto fake_data = fake_tensor.data(); + std::memset(fake_data, 0, fake_tensor.get_byte_size()); + m_request.set_tensor("internal_hidden_states", fake_tensor); + } catch (const ov::Exception& e) { + } + } else { + try { + m_request.set_tensor("internal_hidden_states", hidden_state_input); + auto shape = hidden_state_input.get_shape(); + shape[shape.size() - 1] = shape[shape.size() - 1] * m_adjust_factor; + ov::Tensor fake_tensor = ov::Tensor(hidden_state_input.get_element_type(), shape); + auto fake_data = fake_tensor.data(); + std::memset(fake_data, 0, fake_tensor.get_byte_size()); + m_request.set_tensor("hidden_states", fake_tensor); + } catch (const ov::Exception& e) { + } + } + } if (position_ids.get_shape().size() == 3) { // flatten positions ids for 3D position ids case position_ids.set_shape({ov::shape_size(position_ids.get_shape())}); @@ -424,6 +554,23 @@ class ModelRunner { _reset_cache_rotation_coefficients(); + if (_is_hs_export()) { + try { + m_hidden_states = m_request.get_tensor("last_hidden_state"); + for (size_t i = 0; i < num_sequence_groups; ++i) { + size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i]; + SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id]; + std::vector running_sequences = sequence_group->get_running_sequences(); + for (size_t seq_idx = 0; seq_idx < running_sequences.size(); ++seq_idx) { + Sequence::Ptr sequence = running_sequences[seq_idx]; + sequence->update_hidden_state( + _get_hidden_state(sequence_group->get_request_id(), sequence->get_grouped_id())); + } + } + } catch (const ov::Exception&) { + m_hidden_states = ov::Tensor(); + } + } // return logits return m_request.get_tensor("logits"); } @@ -490,6 +637,116 @@ class ModelRunner { } private: + ov::Tensor m_hidden_states; + + // Hidden state flags and helpers + bool _is_hs_export() const { return m_hidden_state_flags & HS_EXPORT; } + bool _is_hs_import() const { return m_hidden_state_flags & HS_IMPORT; } + bool _is_hs_internal() const { return m_hidden_state_flags & HS_INTERNAL; } + + ov::Tensor _get_hidden_state(uint64_t request_id, uint64_t seq_grouped_id) const { + if (m_hidden_states.get_size() == 0) { + return ov::Tensor(); + } + + SequenceKey key{request_id, seq_grouped_id}; + const auto it = m_sequence_hidden_state_mapping.find(key); + if (it == 
m_sequence_hidden_state_mapping.end()) { + return ov::Tensor(); + } + + size_t start_idx = it->second.start_token_idx; + size_t length = it->second.length; + + auto shape = m_hidden_states.get_shape(); + if (shape.size() < 2) { + return ov::Tensor(); + } + + ov::Coordinate start_coord(shape.size(), 0); + ov::Coordinate end_coord(shape.size(), 0); + + start_coord[0] = start_idx; + end_coord[0] = start_idx + length; + + for (size_t i = 1; i < shape.size(); ++i) { + start_coord[i] = 0; + end_coord[i] = shape[i]; + } + + return ov::Tensor(m_hidden_states, start_coord, end_coord); + } + + ov::Tensor _prepare_hidden_state_input(size_t total_num_tokens, + size_t& hidden_size /*in/out*/) { + if (!(m_hidden_state_flags & (HS_IMPORT | HS_INTERNAL))) { + return {}; + } + + if (hidden_size == 0) { + for (const auto& kv : m_initial_hidden_states) { + const auto& initial_hidden_states = kv.second; + if (initial_hidden_states && initial_hidden_states.get_shape().size() >= 2) { + auto hidden_states_shape = initial_hidden_states.get_shape(); + hidden_size = hidden_states_shape.back(); + if (!(m_hidden_state_flags & HS_IMPORT)) { + hidden_size /= m_adjust_factor; + } + break; + } + } + } + if (hidden_size == 0) { + return {}; + } + + ov::Tensor hs(ov::element::f32, {total_num_tokens, 1, hidden_size}); + std::memset(hs.data(), 0, hs.get_byte_size()); + return hs; + } + + // Common helper to copy a contiguous slice (first-dim range) from src to dst using ROI tensors. + // src_start_idx: start index along src first dimension + // copy_length: number of elements along first dim to copy + // dst_base: destination base tensor (may be full buffer or a wrapper around a raw pointer) + // dst_first_dim_start: start index in first dimension of dst_base where copy should be placed + static void _copy_roi_between_tensors(const ov::Tensor& src, + size_t src_start_idx, + size_t copy_length, + const ov::Tensor& dst_base, + size_t dst_first_dim_start) { + if (copy_length == 0) { + return; + } + + // prepare source ROI coords + const auto src_shape = src.get_shape(); + OPENVINO_ASSERT(!src_shape.empty(), "source tensor rank is zero"); + ov::Coordinate src_start(src_shape.size(), 0), src_end(src_shape.size(), 0); + src_start[0] = src_start_idx; + src_end[0] = src_start_idx + copy_length; + for (size_t d = 1; d < src_shape.size(); ++d) { + src_start[d] = 0; + src_end[d] = src_shape[d]; + } + ov::Tensor src_roi(src, src_start, src_end); + + // prepare destination ROI coords + const auto dst_shape = dst_base.get_shape(); + OPENVINO_ASSERT(!dst_shape.empty(), "destination tensor rank is zero"); + ov::Coordinate tgt_start(dst_shape.size(), 0), tgt_end(dst_shape.size(), 0); + tgt_start[0] = dst_first_dim_start; + tgt_end[0] = dst_first_dim_start + copy_length; + for (size_t d = 1; d < dst_shape.size(); ++d) { + tgt_start[d] = 0; + tgt_end[d] = dst_shape[d]; + } + ov::Tensor tgt_roi(dst_base, tgt_start, tgt_end); + + // bulk copy + src_roi.copy_to(tgt_roi); + } + ov::Tensor _get_or_resize_tensor(ov::Tensor& cached_tensor, const std::string& tensor_name, const ov::Shape& required_shape, diff --git a/src/cpp/src/continuous_batching/pipeline.cpp b/src/cpp/src/continuous_batching/pipeline.cpp index 16eb169de7..eeac9a150b 100644 --- a/src/cpp/src/continuous_batching/pipeline.cpp +++ b/src/cpp/src/continuous_batching/pipeline.cpp @@ -11,14 +11,54 @@ #include "openvino/genai/tokenizer.hpp" #include "continuous_batching/pipeline_impl.hpp" #include "speculative_decoding/speculative_decoding_impl.hpp" +#include 
"speculative_decoding/speculative_decoding_eagle3_impl.hpp" #include "prompt_lookup/prompt_lookup_impl.hpp" #include "continuous_batching/timer.hpp" #include "utils.hpp" #include "visual_language/inputs_embedder.hpp" +#include "json_utils.hpp" using namespace ov::genai; namespace { +struct Eagle3RTInfo { + bool eagle3_mode = false; + std::vector hidden_layers_list; + std::filesystem::path dt_mapping_table; +}; + +Eagle3RTInfo +extract_eagle_mode_from_config(ov::AnyMap& config, const std::filesystem::path& models_path) { + Eagle3RTInfo eagle_rt_info; + if (config.find("eagle3_mode") != config.end()) { + eagle_rt_info.eagle3_mode = config.at("eagle3_mode").as(); + config.erase("eagle3_mode"); + if (config.find("hidden_layers_list") != config.end()) { + eagle_rt_info.hidden_layers_list = config.at("hidden_layers_list").as>(); + config.erase("hidden_layers_list"); + } else { + // compute the layers from number of hidden layers + auto config_file_path = models_path / "config.json"; + if (!std::filesystem::exists(config_file_path)) + OPENVINO_THROW("cannot deduce layers for hidden layer extraction"); + std::ifstream file(config_file_path); + + nlohmann::json data = nlohmann::json::parse(file); + using ov::genai::utils::read_json_param; + int num_decoder_layers = 0; + read_json_param(data, "num_hidden_layers", num_decoder_layers); + OPENVINO_ASSERT(num_decoder_layers > 3, "num_decoder_layers is too small to deduce hidden layers for extraction"); + // The following default hidden layer selection corresponds to the EAGLE reference implementation: + // https://github.com/SafeAILab/EAGLE/blob/0ea94696/eagle/model/modeling_llama_kv.py#L1138 + // These layers (2, num_decoder_layers / 2, num_decoder_layers - 3) are chosen to capture features from + // early, middle, and late stages of the decoder, as recommended by the EAGLE authors. + // If you wish to use different layers, provide the "hidden_layers_list" parameter in the config. 
+ eagle_rt_info.hidden_layers_list = { 2, num_decoder_layers / 2, num_decoder_layers - 3 }; + } + } + return eagle_rt_info; +} + bool extract_prompt_lookup_from_config(ov::AnyMap& config) { bool res = false; @@ -45,6 +85,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p auto properties_without_draft_model = properties; auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, models_path); auto model = utils::read_model(models_path, properties); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); @@ -63,6 +104,10 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( const std::filesystem::p OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); + } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); @@ -87,13 +132,12 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); - + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, models_path); auto model = utils::read_model(models_path, properties_without_draft_model); auto [properties_without_draft_model_without_gguf, enable_save_ov_model] = utils::extract_gguf_properties(properties_without_draft_model); properties_without_draft_model_without_gguf[ov::cache_model_path.name()] = models_path; auto generation_config = utils::from_config_json_if_exists(models_path); - std::shared_ptr embedder; if (std::filesystem::exists(models_path / "openvino_text_embeddings_model.xml")) { embedder = std::make_shared(models_path, device, properties_without_draft_model_without_gguf); @@ -105,6 +149,13 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model_without_gguf, generation_config); + } else if 
(draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + // Eagle speculative decoding does not support dynamic_split_fuse mode + // because it requires hidden state interaction from main model to draft model + // to be implemented future + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model_without_gguf, scheduler_config, generation_config); @@ -131,6 +182,7 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( auto properties_without_draft_model = properties; auto draft_model_desr = utils::extract_draft_model_from_config(properties_without_draft_model); auto is_prompt_lookup_enabled = extract_prompt_lookup_from_config(properties_without_draft_model); + auto eagle_rt_info = extract_eagle_mode_from_config(draft_model_desr.properties, std::filesystem::path(model_str)); auto model = utils::singleton_core().read_model(model_str, weights_tensor); auto rt_info = model->get_rt_info(); @@ -150,6 +202,10 @@ ContinuousBatchingPipeline::ContinuousBatchingPipeline( OPENVINO_ASSERT(draft_model_desr.model == nullptr, "Speculative decoding and prompt lookup decoding are mutually exclusive"); OPENVINO_ASSERT(embedder == nullptr, "Prompt lookup decoding is not supported for models with embeddings"); m_impl = std::make_shared(model, tokenizer, scheduler_config, device, properties_without_draft_model, generation_config); + } else if (draft_model_desr.model != nullptr && eagle_rt_info.eagle3_mode) { + OPENVINO_ASSERT(embedder == nullptr, "Eagle speculative decoding is not supported for models with embeddings"); + auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); + m_impl = std::make_shared(main_model_descr, draft_model_desr, eagle_rt_info.hidden_layers_list); } else if (draft_model_desr.model != nullptr) { OPENVINO_ASSERT(embedder == nullptr, "Speculative decoding is not supported for models with embeddings"); auto main_model_descr = ov::genai::ModelDesc(model, tokenizer, device, properties_without_draft_model, scheduler_config, generation_config); diff --git a/src/cpp/src/llm/pipeline.cpp b/src/cpp/src/llm/pipeline.cpp index e057d5da72..40f8576713 100644 --- a/src/cpp/src/llm/pipeline.cpp +++ b/src/cpp/src/llm/pipeline.cpp @@ -88,6 +88,20 @@ std::pair generation_config(const GenerationConfig& config) { return {utils::CONFIG_ARG_NAME, Any::make(config)}; } +inline void apply_eagle_rt_info(std::shared_ptr& model, ov::AnyMap& properties, const std::filesystem::path& mapping_path) { + if (model->has_rt_info("eagle3_mode") && model->get_rt_info("eagle3_mode")) { + properties["eagle3_mode"] = true; + if (model->has_rt_info("hidden_layers_list")) + properties["hidden_layers_list"] = model->get_rt_info>("hidden_layers_list"); + } +} + +inline void apply_eagle_rt_info(std::shared_ptr& model, + ov::AnyMap& properties, + const std::string& mapping_path) { + apply_eagle_rt_info(model, properties, std::filesystem::path(mapping_path)); +} + std::pair 
draft_model( const std::filesystem::path& models_path, const std::string& device, @@ -96,6 +110,7 @@ std::pair draft_model( std::filesystem::path openvino_model_name = "openvino_model.xml"; auto model = utils::singleton_core().read_model(models_path / openvino_model_name, {}, plugin_config); + apply_eagle_rt_info(model, plugin_config, models_path); auto generation_config = utils::from_config_json_if_exists(models_path); auto tokenizer = ov::genai::Tokenizer(models_path); return { utils::DRAFT_MODEL_ARG_NAME, Any::make(model, tokenizer, device, plugin_config, scheduler_config, generation_config) }; @@ -111,6 +126,7 @@ std::pair draft_model( auto [plugin_config, scheduler_config] = utils::extract_scheduler_config(properties); auto model = utils::singleton_core().read_model(model_str, weights_tensor); + apply_eagle_rt_info(model, plugin_config, model_str); return { utils::DRAFT_MODEL_ARG_NAME, Any::make(model, tokenizer, device, plugin_config, scheduler_config, generation_config) }; } diff --git a/src/cpp/src/sampling/sampler.cpp b/src/cpp/src/sampling/sampler.cpp index e1c1b2ffa5..6a58925657 100644 --- a/src/cpp/src/sampling/sampler.cpp +++ b/src/cpp/src/sampling/sampler.cpp @@ -853,6 +853,11 @@ SequenceGroupSamplingInfo Sampler::sample_from_sequence_group(SequenceGroup::Ptr } } } + if (!is_validation_mode_enabled && m_draft2target_mapping) { // compute token offset for draft model in speculative sampling + ov::Tensor d2t_tensor = m_draft2target_mapping->get_tensor_view(); + auto d2t = d2t_tensor.data(); + sampled_token.m_index = sampled_token.m_index + (d2t? d2t[sampled_token.m_index] : 0); + } // flag to add sampled token to generated sequence or extend logit processors only bool is_extend_sequence = logit_token_offset == 0 || is_generate_n_tokens || !is_validation_passed; if (is_validation_mode_enabled && !is_extend_sequence) { diff --git a/src/cpp/src/sampling/sampler.hpp b/src/cpp/src/sampling/sampler.hpp index ffbbcac3e3..c4def2f871 100644 --- a/src/cpp/src/sampling/sampler.hpp +++ b/src/cpp/src/sampling/sampler.hpp @@ -99,6 +99,7 @@ class Sampler { Tokenizer m_tokenizer; ThreadPool m_thread_pool; + std::shared_ptr m_draft2target_mapping; // Tensor to store draft2target mapping for eagle model public: Sampler(const Sampler& rhs) = delete; Sampler(Sampler&& rhs) = delete; @@ -125,6 +126,10 @@ class Sampler { // pair with map with backend name and corresponding compiler init time, and vector of compile times for each concrete grammar std::pair, std::vector> get_structured_output_times(); void clear_structured_output_compile_times(); + + void set_d2t_for_decoding(std::shared_ptr& d2t) { + m_draft2target_mapping = d2t; + }; }; class Sampler::GroupBeamSearcher { diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index e1e2187498..54bd46c37a 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -44,6 +44,7 @@ class Sequence { LogProbs m_generated_log_probs; uint64_t m_grouped_id; uint64_t m_id = _get_next_global_sequence_id(); + ov::Tensor m_hidden_state = ov::Tensor(); SequenceStatus m_status = SequenceStatus::RUNNING; GenerationFinishReason m_finish_reason = GenerationFinishReason::NONE; float m_cumulative_log_prob = 0.0f; @@ -70,6 +71,7 @@ class Sequence { m_generated_ids(seq.m_generated_ids), m_generated_log_probs(seq.m_generated_log_probs), m_grouped_id(id), + m_hidden_state(seq.m_hidden_state), m_status(seq.m_status), m_cumulative_log_prob(seq.m_cumulative_log_prob), m_sequence_group(seq.m_sequence_group), @@ -142,6 +144,14 @@ 
class Sequence { m_generated_ids.push_back(token_id); } + void update_hidden_state(const ov::Tensor& tensor) { + m_hidden_state = tensor; + } + + ov::Tensor get_hidden_state() const { + return m_hidden_state; + } + // removes n last tokens and updates cumulative log prob // used to remove stop_string from the output void remove_last_tokens(int n) { @@ -643,7 +653,7 @@ class SequenceGroup : public std::enable_shared_from_this { m_num_validation_tokens = k; } - size_t get_num_tokens_to_validate() { + size_t get_num_tokens_to_validate() const { return m_num_validation_tokens; } diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index 4853b8bac6..bcc00f9a8a 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -63,12 +63,48 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::get_ge for (const auto& sequence : request->get_running_sequences()) { const auto& sequence_id = sequence->get_grouped_id(); OPENVINO_ASSERT(!generated_request.count(sequence_id)); - generated_request.insert({{sequence_id, { sequence->get_generated_ids(), sequence->get_generated_log_probs() } }}); + generated_request.insert({{sequence_id, { sequence->get_generated_ids(), sequence->get_generated_log_probs(), sequence->get_hidden_state() } }}); } } return result; } +ov::Tensor truncate_hidden_state_from_end(const ov::Tensor& hidden_state, size_t tokens_to_remove) { + if (hidden_state.get_size() == 0 || tokens_to_remove == 0) { + return hidden_state; + } + + auto shape = hidden_state.get_shape(); + if (shape.size() < 2) { + return hidden_state; + } + + size_t seq_len_dim = 0; + size_t current_seq_len = shape[seq_len_dim]; + + if (tokens_to_remove >= current_seq_len) { + ov::Shape new_shape = shape; + new_shape[seq_len_dim] = 0; + return ov::Tensor(hidden_state.get_element_type(), new_shape); + } + + size_t new_seq_len = current_seq_len - tokens_to_remove; + + ov::Coordinate start_coord(shape.size(), 0); + ov::Coordinate end_coord(shape.size(), 0); + + for (size_t i = 0; i < shape.size(); ++i) { + start_coord[i] = 0; + if (i == seq_len_dim) { + end_coord[i] = new_seq_len; + } else { + end_coord[i] = shape[i]; + } + } + + return ov::Tensor(hidden_state, start_coord, end_coord); +} + // { min_len_of_prefix, min_length_of_candidate } std::pair get_prefix_len( @@ -227,6 +263,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update std::vector running_sequences = request->get_running_sequences(); OPENVINO_ASSERT(running_sequences.size() > 0); size_t min_generated_tokens, min_candidate_len; + size_t validate_length = 0; if (running_sequences.front()->get_generated_len() == 0 && !request->get_num_tokens_to_validate()) { m_sampler->create_logit_processor(request_id, request->get_sampling_parameters(), request->get_prompt_ids()); auto& logit_processor = m_sampler->get_logit_processor(request_id); @@ -234,6 +271,9 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update min_generated_tokens = result.inserted_tokens_cnt; running_sequences = request->get_running_sequences(); min_candidate_len = result.inserted_tokens_cnt; + if (eagle_mode_enabled && !m_is_validation_mode_enabled) + m_model_runner->set_initial_hidden_state(request_id, + candidates.begin()->second.hidden_states); } else 
{ // update existing sequences by the candidates auto& logit_processor = m_sampler->get_logit_processor(request_id); @@ -252,6 +292,16 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update candidate_token_ids.resize(min_candidate_len); candidate_token_log_probs.resize(min_candidate_len); result.inserted_tokens_cnt = insert_tokens_to_sequence(running_sequence, candidate_token_ids, candidate_token_log_probs, logit_processor, is_update_logit_processor); + // handle hidden states for eagle mode + if (eagle_mode_enabled && !m_is_validation_mode_enabled && result.inserted_tokens_cnt > 0) { // update hidden states for draft model + // at least there should be one bonus token from main + auto& hidden_state = candidate_sequence.hidden_states; + ov::Tensor pruned_hidden_state = truncate_hidden_state_from_end(hidden_state, result.removed_tokens_cnt); + m_model_runner->set_initial_hidden_state(request_id, + pruned_hidden_state); + const auto& shape = pruned_hidden_state.get_shape(); + validate_length = shape.size() > 0 ? shape[0] : 0; + } } // we should update a logit processor just for draft model to generate the same tokens // logit processors of main model will be updated in sampler while validation mode @@ -266,14 +316,22 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::update updated_context_len = min_candidate_len + prompt_len, max_new_tokens = request->get_max_new_tokens(); size_t generated_len = request->get_context_len() >= request->get_prompt_len() ? request->get_context_len() - request->get_prompt_len() + 1 : 0; - if (generated_len > 0 && result.removed_tokens_cnt > 0) { - request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1); + if (validate_length > 0) { + if (generated_len > 0) { + request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1 - (validate_length - 1)); + } + } else { // fast draft or main model for eagle speculative + if (generated_len > 0 && result.removed_tokens_cnt > 0) { + request->update_processed_tokens_num(num_processed_tokens - result.removed_tokens_cnt + 1); + } } - if (result.inserted_tokens_cnt > 0 && result.removed_tokens_cnt == 0) { + if (validate_length == 0 && result.inserted_tokens_cnt > 0 && result.removed_tokens_cnt == 0) { request->set_num_validated_tokens(result.inserted_tokens_cnt); + } else if (validate_length > 0) { + request->set_num_validated_tokens(validate_length - 1); // in generation stage } // to pause `draft_model` generation in case of `generated_len >= max_new_tokens - 1` to generate last token by `main_model` - if (!m_is_validation_mode_enabled) { + if (!m_is_validation_mode_enabled && result.inserted_tokens_cnt != 0) { bool pause_gen_status = false; generated_len -= result.removed_tokens_cnt; generated_len += result.inserted_tokens_cnt; @@ -328,6 +386,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m raw_perf_metrics.m_batch_sizes.emplace_back(num_generated_tokens); } + if (eagle_mode_enabled) + m_model_runner->enable_hidden_state_import(false); to_generate = false; for (auto& request : m_requests) { const auto& sampling_params = request->get_sampling_parameters(); @@ -351,5 +411,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::m to_generate |= request->can_generate_tokens(); } } + if (eagle_mode_enabled) + m_model_runner->enable_hidden_state_import(true); } } diff --git 
a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp index 40db6a2ddd..5d6d220028 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.hpp @@ -40,5 +40,56 @@ class ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl : protected: void finish_request(SequenceGroup::Ptr request); void _pull_awaiting_requests() override {}; + bool eagle_mode_enabled = false; +}; + +class ContinuousBatchingPipeline::ContinuousBatchingForEagle3DecodingImpl + : public ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl { +public: + ContinuousBatchingForEagle3DecodingImpl() = default; + + ContinuousBatchingForEagle3DecodingImpl(const std::shared_ptr& model, + const Tokenizer& tokenizer, + const GenerationConfig& generation_config, + const SchedulerConfig& scheduler_config, + const std::string& device, + const ov::AnyMap& plugin_config, + bool is_validation_mode_enabled) : ContinuousBatchingForSpeculativeDecodingImpl( + model, tokenizer, generation_config, + scheduler_config, device, plugin_config, + is_validation_mode_enabled) { + eagle_mode_enabled = true; + }; + + bool is_requests_empty(); + + void set_d2t_for_draft_decoding(std::shared_ptr& d2t) { + if (m_sampler) { + m_sampler->set_d2t_for_decoding(d2t); + } + } + void set_hidden_state_export_needed(bool is_needed) { + if (m_model_runner) { + m_model_runner->enable_hidden_state_export(is_needed); + } + } + + void set_hidden_state_import_needed(bool is_needed) { + if (m_model_runner) { + m_model_runner->enable_hidden_state_import(is_needed); + } + } + + void set_hidden_state_internal_needed(bool is_needed) { + if (m_model_runner) { + m_model_runner->enable_hidden_state_internal(is_needed); + } + } + + void set_adjust_factor(size_t adjust_factor) { + if (m_model_runner) { + m_model_runner->set_adjust_factor(adjust_factor); + } + } }; } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp new file mode 100644 index 0000000000..7cc8448325 --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.cpp @@ -0,0 +1,438 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +#include "speculative_decoding_eagle3_impl.hpp" +#include "logger.hpp" + +namespace ov::genai { +void share_embedding_weights(std::shared_ptr& main_model, std::shared_ptr& draft_model) { + // extract embedding weight from main model + auto find_embedding_gather = [](const std::shared_ptr& model) + -> std::shared_ptr { + constexpr size_t MIN_VOCAB_SIZE_THRESHOLD = 1000; + for (const auto& node : model->get_ordered_ops()) { + auto gather = std::dynamic_pointer_cast(node); + if (!gather) continue; + // [vocab, hidden_size] * [batch, seq_len] -> [batch, seq_len, hidden_size] + auto data_node = gather->input_value(0).get_node_shared_ptr(); + auto indices_node = gather->input_value(1).get_node_shared_ptr(); + if (!data_node || !indices_node) continue; + // indices_node should be on parameter path, maybe this is better rule + ov::PartialShape ps = data_node->get_output_partial_shape(0); + if (ps.rank().is_static() && ps.rank().get_length() >= 2) { + if (ps[0].is_static() && ps[0].get_length() > MIN_VOCAB_SIZE_THRESHOLD) { // Heuristic: vocab size > 
1000 + return gather; + } + } + std::string fname = data_node->get_friendly_name(); + if (fname.find("embed_tokens") != std::string::npos || + fname.find("embedding") != std::string::npos) { + return gather; + } + } + return nullptr; + }; + auto main_gather = find_embedding_gather(main_model); + auto draft_gather = find_embedding_gather(draft_model); + if (!main_gather || !draft_gather) { + return; + } + auto main_weight_node = main_gather->input_value(0).get_node_shared_ptr(); + auto draft_weight_node = draft_gather->input_value(0).get_node_shared_ptr(); + + if (main_weight_node.get() == draft_weight_node.get()) { + return; + } + + try { + draft_weight_node->output(0).replace(main_weight_node->output(0)); + } catch (const std::exception& e) { + Logger::warn(std::string("Error: failed to import embedding weights from main model to draft model. Exception: ") + e.what()); + } catch (...) { + Logger::warn("Error: failed to import embedding weights from main model to draft model due to unknown exception."); + } +} + +std::shared_ptr extract_d2t_mapping_table(std::shared_ptr& model) { + // extract result nodes from model + for (const auto& result : model->get_results()) { + auto input_node = result->input_value(0).get_node_shared_ptr(); + if (ov::is_type(input_node) && input_node->get_friendly_name().find("d2t") != std::string::npos) { + return ov::as_type_ptr(input_node); + } + } + return nullptr; +} +void extract_hidden_state_generic(std::shared_ptr& model, + const std::vector& hidden_layers_to_abstract) { + ov::pass::Manager pm; + pm.register_pass(hidden_layers_to_abstract); + pm.run_passes(model); +} + +EagleModelTransform::EagleModelTransform(const std::vector& layers) : m_layer_ids(layers) { +} + +bool EagleModelTransform::run_on_model(const std::shared_ptr& model) { + // share the embedding weights from main model to draft model + m_new_parameters.clear(); + m_new_results.clear(); + if (m_layer_ids.size() == 1 && m_layer_ids[0] == -1) { + ov::pass::Manager manager; + manager.set_per_pass_validation(false); + manager.register_pass(m_new_results); + // input transform for draft + // here we apply a trick for the fc layer in draft model + manager.register_pass(m_new_parameters); + manager.run_passes(model); + + model->add_parameters(m_new_parameters); + model->add_results(m_new_results); + return true; + } else { + ov::pass::Manager manager; + manager.set_per_pass_validation(false); + manager.register_pass(m_layer_ids, m_hidden_layer_outputs); + manager.run_passes(model); + if (!m_hidden_layer_outputs.empty()) { + auto concat = std::make_shared(m_hidden_layer_outputs, -1); + concat->set_friendly_name("eagle3_hidden_states_concat"); + + auto result = std::make_shared(concat); + std::string output_name = "last_hidden_state"; + result->output(0).set_names({output_name}); + result->set_friendly_name(output_name); + model->add_results({result}); + return true; + } + } + + return false; +} + +EagleInputTransform::EagleInputTransform(std::vector>& params) { + register_matcher( + std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), + ([¶ms, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + try { + if (apply(node, params)) { + ++applied; + return true; + } + } catch (...) 
{ + OPENVINO_ASSERT(false, "EagleTransform failed to apply"); + } + return false; + }) + ); +} +bool EagleInputTransform::apply(NodePtr node, std::vector>& params) { + if (ov::is_type(node)) { + auto matmul_node = ov::as_type_ptr(node); + // check the input of the matmul node; if it is a node named "hidden_states", then it's the node we want + auto input_node = matmul_node->get_input_node_shared_ptr(0); + if (!ov::as_type_ptr(input_node)) { + return false; + } + + auto shape = node->get_output_partial_shape(0); + auto internal_hidden_state = std::make_shared(node->get_element_type(), node->get_output_partial_shape(0)); + internal_hidden_state->output(0).set_names({"internal_hidden_states"}); + internal_hidden_state->set_friendly_name("internal_hidden_states"); + // create a new eltwise node to add the output of the MatMul node and the internal hidden state input from the previous cycle + auto new_eltwise = std::make_shared(internal_hidden_state, matmul_node->output(0)); + ov::replace_node(matmul_node, new_eltwise); + params.push_back(internal_hidden_state); + return true; + } + return false; +} + +EagleBaseTransform::EagleBaseTransform(std::vector>& results) { + register_matcher( + std::make_shared(ov::pass::pattern::wrap_type(), this->get_type_info().name), + ([&results, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + try { + if (apply(node, results)) { + ++applied; + return true; + } + } catch (...) { + OPENVINO_ASSERT(false, "EagleTransform failed to apply"); + } + return false; + }) + ); +} + +std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node, + std::set& visited_nodes) { + if (visited_nodes.count(start_node.get())) { + return nullptr; + } + + visited_nodes.insert(start_node.get()); + + if (ov::is_type(start_node)) { + // check the input nodes of the MatMul; if a Gather node is found, return the gather node, otherwise return the matmul node + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + if (ov::as_type_ptr(input_node)) { + return start_node; // return the Add node itself + } + } + } + + for (size_t i = 0; i < start_node->get_input_size(); ++i) { + auto input_node = start_node->get_input_node_shared_ptr(i); + if (!input_node) continue; + + auto result = find_last_residual_node(input_node, visited_nodes); + if (result) { + return result; + } + } + return nullptr; +} + +std::shared_ptr EagleBaseTransform::find_last_residual_node(const std::shared_ptr& start_node) { + std::set visited_nodes; + return find_last_residual_node(start_node, visited_nodes); +} + +bool EagleBaseTransform::apply(NodePtr node, std::vector>& results) { + { + // 1. without normalization layer 2.
add extra input + if (ov::is_type(node)) { + // we are applying transformation to the last hidden state, eagle2 mode + NodePtr input_node = node->get_input_node_shared_ptr(0); + if (!input_node) { + return false; + } + auto last_residual_node = find_last_residual_node(input_node); + if (!last_residual_node) { + return false; + } + auto result = std::make_shared(last_residual_node); + std::string output_name = "last_hidden_state"; + result->output(0).set_names({output_name}); + result->set_friendly_name(output_name); + results.push_back(result); + return true; + } + return false; + } +} + +Eagle3Transform::Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs) : m_layers(layers) { + auto is_target_pattern = [&](const Output& output) { + auto add_node = ov::as_type_ptr(output.get_node_shared_ptr()); + auto add_node_name = add_node->get_friendly_name(); + if (add_node_name.find("self_attn") != std::string::npos) + return false; // Skip self-attention layers + bool layer_matched = false; + for (auto layer_idx : m_layers) { + if (add_node_name.find("layers." + std::to_string(layer_idx) + "/") != std::string::npos) { + layer_matched = true; + break; + } + } + + if (!layer_matched) { + return false; // Skip layers that are not in the specified layers + } + auto input0 = add_node->get_input_node_shared_ptr(1); + if (!input0 || !ov::is_type(input0)) { + return false; + } + auto matmul_node = input0; + auto matmul_input = matmul_node->get_input_node_shared_ptr(0); + if (!matmul_input) { + return false; + } + + bool has_multiply = ov::is_type(matmul_input); // ACT(up) dot gate + return has_multiply; + }; + + auto hidden_layer = ov::pass::pattern::wrap_type(is_target_pattern); + register_matcher(std::make_shared(hidden_layer, "Eagle3Transform::hidden_extraction"), + [&hidden_state_outputs, this](ov::pass::pattern::Matcher& m) { + auto node = m.get_match_root(); + if (ov::is_type(node)) { + hidden_state_outputs.push_back(node->output(0)); + return true; + } + return false; + } + ); +} + +ContinuousBatchingPipeline::Eagle3DecodingImpl::Eagle3DecodingImpl(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc, + const std::vector& hidden_layers) + : m_hidden_layers_to_abstract(hidden_layers) { + auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); + // Eagle speculative decoding does not support dynamic_split_fuse mode + // because it requires hidden state interaction from main model to draft model + // to be implemented future + if (scheduler_configs.first.dynamic_split_fuse) { + Logger::warn( + "Note: disable dynamic split fuse for eagle3 speculative decoding" + ); + scheduler_configs.first.dynamic_split_fuse = false; + scheduler_configs.second.dynamic_split_fuse = false; + } + auto main_model = main_model_desc.model; + auto draft_model = draft_model_desc.model; + + auto main_device = main_model_desc.device; + std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; + + ov::AnyMap draft_properties = + draft_model_desc.properties.empty() ? 
main_model_desc.properties : draft_model_desc.properties; + + // main and draft model can have different tokenizers + // to do: support retokenization: 154103 + Tokenizer main_model_tokenizer = main_model_desc.tokenizer; + Tokenizer draft_model_tokenizer = draft_model_desc.tokenizer; + m_tokenizer = main_model_tokenizer; + // for eagle model, we need to obtain hidden layer state as extra output + // apply transformations needed to run eagle model + // target model: hidden state extraction, draft model: hidden state import , hidden state extraction + // eagle3 specific : dt importing + share_embedding_weights(main_model, draft_model); + extract_hidden_state_generic(main_model, hidden_layers); + extract_hidden_state_generic(draft_model, { -1 }); + + // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode + m_main_pipeline = std::make_shared(main_model, + main_model_tokenizer, + main_model_desc.generation_config, + scheduler_configs.first, + main_device, + main_model_desc.properties, + true); + m_draft_pipeline = std::make_shared(draft_model, + draft_model_tokenizer, + draft_model_desc.generation_config, + scheduler_configs.second, + draft_device, + draft_properties, + false); + m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); + m_perf_metrics.raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; + m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + + // specific params update for eagle pipeline + // check draft_model, retrieve d2t table if exists + auto d2t_tensor = extract_d2t_mapping_table(draft_model); + update_eagle_pipeline_params(d2t_tensor); +} + +ov::Tensor ContinuousBatchingPipeline::Eagle3DecodingImpl::create_draft_input_ids(const ov::Tensor& original_input_ids) { + auto shape = original_input_ids.get_shape(); + if (shape.size() == 0 || shape.back() <= 1) { + return ov::Tensor(original_input_ids); + } + + size_t original_length = shape.back(); + size_t new_length = original_length - 1; + + ov::Tensor draft_input_ids(ov::element::i64, {1, new_length}); + + const int64_t* src_data = original_input_ids.data(); + int64_t* dst_data = draft_input_ids.data(); + + std::copy(src_data + 1, src_data + original_length, dst_data); + + return draft_input_ids; +} + +void ContinuousBatchingPipeline::Eagle3DecodingImpl::update_eagle_pipeline_params(std::shared_ptr& d2t_tensor) { + auto m_main_eagle_pipeline = std::dynamic_pointer_cast(m_main_pipeline); + auto m_draft_eagle_pipeline = std::dynamic_pointer_cast(m_draft_pipeline); + m_main_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_export_needed(true); + m_draft_eagle_pipeline->set_hidden_state_import_needed(true); + m_draft_eagle_pipeline->set_hidden_state_internal_needed(true); + m_draft_eagle_pipeline->set_adjust_factor( + m_hidden_layers_to_abstract.size() > 0 ? 
m_hidden_layers_to_abstract.size() : 1); + m_draft_eagle_pipeline->set_d2t_for_draft_decoding(d2t_tensor); +} + +GenerationHandle +ContinuousBatchingPipeline::Eagle3DecodingImpl::add_request(uint64_t request_id, + const ov::Tensor& input_ids, + const ov::genai::GenerationConfig& sampling_params, + std::optional token_type_ids) { + std::lock_guard lock(m_draft_generations_mutex); + auto draft_sampling_params = sampling_params; + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + // remove first token from input_ids to create draft_input_ids + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params, token_type_ids)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params, token_type_ids); +} + +GenerationHandle +ContinuousBatchingPipeline::Eagle3DecodingImpl::add_request(uint64_t request_id, + const std::string& prompt, + const ov::genai::GenerationConfig& sampling_params) { + std::lock_guard lock(m_draft_generations_mutex); + auto draft_sampling_params = sampling_params; + draft_sampling_params.ignore_eos = true; + draft_sampling_params.stop_strings = {}; + // remove first token from input_ids to create draft_input_ids + // add_special_tokens is false for better compress rate + auto input_ids = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)).input_ids; + ov::Tensor draft_input_ids = create_draft_input_ids(input_ids); + m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, draft_input_ids, draft_sampling_params)}); + return m_main_pipeline->add_request(request_id, input_ids, sampling_params); +} + +std::vector ContinuousBatchingPipeline::Eagle3DecodingImpl::generate( + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + const std::optional>& token_type_ids, + const std::optional>>>& position_ids) { + GenerateStrategy strategy; + strategy.prepare_request = [this](size_t, + const ov::Tensor& in_ids, + GenerationConfig& main_cfg, + GenerationConfig& draft_cfg, + ov::Tensor& main_in, + ov::Tensor& draft_in) { + OPENVINO_ASSERT(main_cfg.assistant_confidence_threshold == 0.f, + "Eagle3 only supports num_assistant_tokens (assistant_confidence_threshold must be 0.f)"); + if (main_cfg.num_assistant_tokens == 0) { + main_cfg.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; + draft_cfg.num_assistant_tokens = main_cfg.num_assistant_tokens; + } + draft_cfg.ignore_eos = true; + draft_cfg.stop_strings = {}; + main_in = in_ids; + draft_in = create_draft_input_ids(in_ids); + }; + + strategy.check_streaming = [](const std::shared_ptr& streamer_ptr, + const std::vector& input_ids, + const std::vector& sampling_params) { + OPENVINO_ASSERT(!streamer_ptr->has_callback() || + (input_ids.size() == 1 && + (sampling_params[0].is_greedy_decoding())), + "Eagle3 streaming only supports batch size=1 with greedy"); + }; + strategy.start_timer = [](){ + return std::chrono::steady_clock::now(); + }; + strategy.stop_timer = [](TimePoint start){ + return PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start); + }; + + return generate_common(this, input_ids, sampling_params, streamer, token_type_ids, strategy); +} +} // namespace ov::genai \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp 
b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp new file mode 100644 index 0000000000..f8aab78eff --- /dev/null +++ b/src/cpp/src/speculative_decoding/speculative_decoding_eagle3_impl.hpp @@ -0,0 +1,107 @@ +// Copyright (C) 2023-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "speculative_decoding_impl.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/result.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/manager.hpp" + +namespace ov::genai { +class ContinuousBatchingPipeline::Eagle3DecodingImpl : public ContinuousBatchingPipeline::SpeculativeDecodingImpl { +public: + template + friend std::vector generate_common( + Impl*, + const std::vector&, + const std::vector&, + const StreamerVariant&, + std::optional>, + GenerateStrategy&); + Eagle3DecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc, const std::vector& hidden_layers_to_abstract); + + std::vector + generate(const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + const std::optional>& token_type_ids = std::nullopt, + const std::optional>>>& position_ids = std::nullopt) override; + + GenerationHandle add_request(uint64_t request_id, + const ov::Tensor& input_ids, + const ov::genai::GenerationConfig& sampling_params, + std::optional token_type_ids = std::nullopt) override; + + GenerationHandle add_request(uint64_t request_id, + const std::string& prompt, + const ov::genai::GenerationConfig& sampling_params) override; +protected: + void update_eagle_pipeline_params(std::shared_ptr& d2t_tensor); + ov::Tensor create_draft_input_ids(const ov::Tensor& original_input_ids); + std::vector m_hidden_layers_to_abstract; +}; + +using NodePtr = std::shared_ptr; +using namespace ov::op; + +class EagleBaseTransform : public ov::pass::MatcherPass { +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("EagleBaseTransform"); + EagleBaseTransform(std::vector>& results); + + ~EagleBaseTransform() = default; + +private: + bool apply(NodePtr node, std::vector>& results); + size_t applied = 0; + std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node); + std::shared_ptr find_last_residual_node(const std::shared_ptr& start_node, + std::set& visited_nodes); +}; +class EagleInputTransform : public ov::pass::MatcherPass { // eagle3 specific for draft model +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("EagleInputTransform"); + EagleInputTransform(std::vector>& params); + + ~EagleInputTransform() = default; + +private: + bool apply(NodePtr node, std::vector>& params); + size_t applied = 0; +}; +class Eagle3Transform : public ov::pass::MatcherPass { +public: + using NodePtr = std::shared_ptr; + OPENVINO_MATCHER_PASS_RTTI("Eagle3Transform"); + Eagle3Transform(const std::vector& layers, std::vector>& hidden_state_outputs); + + ~Eagle3Transform() = default; + +private: + std::vector m_layers; // layers to be abstracted +}; + +class EagleModelTransform : public ov::pass::ModelPass { +public: + EagleModelTransform(const std::vector& layer_ids); + bool run_on_model(const std::shared_ptr& model) 
override; + +private: + const std::vector m_layer_ids; + std::vector> m_new_results; + std::vector> m_new_parameters; + std::vector> m_hidden_layer_outputs; +}; +} diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 1ca63e9de7..4293e2e03c 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -25,26 +25,19 @@ bool are_tokenizers_equal(Tokenizer& lhs, Tokenizer& rhs) { lhs.get_bos_token_id() == rhs.get_bos_token_id() && lhs.get_pad_token_id() == rhs.get_pad_token_id(); } -ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, - const ov::genai::ModelDesc& draft_model_desc) { - auto main_model = main_model_desc.model; - auto draft_model = draft_model_desc.model; - - OPENVINO_ASSERT(main_model != nullptr, "Main model cannot be null"); - OPENVINO_ASSERT(draft_model != nullptr, "Draft model cannot be null"); +std::pair +ContinuousBatchingPipeline::SpeculativeDecodingImpl::init_speculative_models(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc) { + OPENVINO_ASSERT(main_model_desc.model != nullptr, "Main model cannot be null"); + OPENVINO_ASSERT(draft_model_desc.model != nullptr, "Draft model cannot be null"); + utils::apply_paged_attention_transformations(main_model_desc.model, main_model_desc.scheduler_config.use_cache_eviction); + utils::apply_paged_attention_transformations(draft_model_desc.model, main_model_desc.scheduler_config.use_cache_eviction); - auto main_scheduler_config = main_model_desc.scheduler_config; - auto main_device = main_model_desc.device; + utils::apply_gather_before_matmul_transformation(main_model_desc.model); + utils::apply_gather_before_matmul_transformation(draft_model_desc.model); - utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction); - utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction); - - utils::apply_gather_before_matmul_transformation(main_model); - utils::apply_gather_before_matmul_transformation(draft_model); - - std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; bool is_draft_scheduler_undefined = draft_model_desc.scheduler_config == SchedulerConfig(); + auto main_scheduler_config = main_model_desc.scheduler_config; ov::genai::SchedulerConfig main_scheduler_config_updated = main_scheduler_config, draft_scheduler_config = is_draft_scheduler_undefined ? 
main_scheduler_config : draft_model_desc.scheduler_config; @@ -63,8 +56,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con } return total_hidden_size; }; - float main_model_hidden_size = compute_total_hidden_size(main_model), - draft_model_hidden_size = compute_total_hidden_size(draft_model); + float main_model_hidden_size = compute_total_hidden_size(main_model_desc.model), + draft_model_hidden_size = compute_total_hidden_size(draft_model_desc.model); auto k = draft_model_hidden_size / (main_model_hidden_size + draft_model_hidden_size); // TODO: work with KV blocks as it will be more precise instead of GBs @@ -82,8 +75,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con draft_scheduler_config.max_num_batched_tokens = main_scheduler_config_updated.max_num_batched_tokens; } - ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? main_model_desc.properties : draft_model_desc.properties; + return std::make_pair(main_scheduler_config_updated, draft_scheduler_config); +} +ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, + const ov::genai::ModelDesc& draft_model_desc) { + auto scheduler_configs = init_speculative_models(main_model_desc, draft_model_desc); + + auto main_device = main_model_desc.device; + std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device; // main and draft model can have different tokenizers // to do: support retokenization: 154103 Tokenizer main_model_tokenizer = main_model_desc.tokenizer; @@ -91,16 +91,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con // todo: remove this condition after support of CVS-154103 OPENVINO_ASSERT(are_tokenizers_equal(main_model_tokenizer, draft_model_tokenizer), "Tokenizers for draft and main models are different!"); - m_tokenizer = main_model_tokenizer; - + ov::AnyMap draft_properties = draft_model_desc.properties.empty() ? 
main_model_desc.properties : draft_model_desc.properties; // to create `main_pipeline` with enabled validation_mode and `draft_pipeline` with disabled validation mode m_main_pipeline = std::make_shared( - main_model, main_model_tokenizer, main_model_desc.generation_config, - main_scheduler_config_updated, main_device, main_model_desc.properties, true); + main_model_desc.model, main_model_tokenizer, main_model_desc.generation_config, + scheduler_configs.first, main_device, main_model_desc.properties, true); m_draft_pipeline = std::make_shared( - draft_model, draft_model_tokenizer, draft_model_desc.generation_config, - draft_scheduler_config, draft_device, draft_properties, false); + draft_model_desc.model, draft_model_tokenizer, draft_model_desc.generation_config, + scheduler_configs.second, draft_device, draft_properties, false); m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; @@ -241,116 +240,37 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< const StreamerVariant& streamer, const std::optional>& token_type_ids, const std::optional>>>& position_ids) { - OPENVINO_ASSERT(!token_type_ids.has_value()); - m_perf_metrics = ov::genai::SDPerModelsPerfMetrics(); - m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; - - OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request"); - OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); - - const auto generate_start = std::chrono::steady_clock::now(); - - // checks that all requests has the same LoRA adapters property value - for (size_t i = 1; i < sampling_params.size(); ++i) { - OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, - "LoRA adapters value must be the same for all requests"); - } - m_main_pipeline->set_adapters(sampling_params[0].adapters); - m_draft_pipeline->set_adapters(sampling_params[0].adapters); - - const auto streamer_ptr = std::make_shared(streamer, m_tokenizer); - - OPENVINO_ASSERT(!streamer_ptr->has_callback() || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), - "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); - - std::vector main_generations; - for (size_t request_id = 0; request_id < input_ids.size(); ++request_id) { - OPENVINO_ASSERT(1 == input_ids[request_id].get_shape().at(0), "Use multiple tensors to pass a batch."); - auto main_sampling_params = sampling_params[request_id]; - if (main_sampling_params.assistant_confidence_threshold == 0.f) { - if (main_sampling_params.num_assistant_tokens == 0) { - main_sampling_params.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; + GenerateStrategy strategy; + strategy.prepare_request = [this](size_t, + const ov::Tensor& in_ids, + GenerationConfig& main_cfg, + GenerationConfig& draft_cfg, + ov::Tensor& main_in, + ov::Tensor& draft_in) { + if (main_cfg.assistant_confidence_threshold == 0.f) { + if (main_cfg.num_assistant_tokens == 0) { + main_cfg.num_assistant_tokens = m_main_pipeline->default_num_assistant_tokens; } } - main_generations.push_back(m_main_pipeline->add_request(request_id, input_ids[request_id], main_sampling_params)); - - auto draft_sampling_params = main_sampling_params; - // set the parameters do not stop draft generation 
without stopping of the same request for main pipeline - draft_sampling_params.ignore_eos = true; - draft_sampling_params.stop_strings = {}; - std::lock_guard lock(m_draft_generations_mutex); - m_draft_generations.insert({request_id, m_draft_pipeline->add_request(request_id, input_ids[request_id], draft_sampling_params)}); - } - auto all_requests = get_awaiting_requests(); - - GenerationHandle& generation = main_generations.at(0); - - streamer_ptr->start(); - - while (has_non_finished_requests()) { - try { - step(); - } catch (...) { - drop_requests(); // remove all requests from pipeline state in case of exception - streamer_ptr->end(); - std::rethrow_exception(std::current_exception()); - } - stream_tokens(streamer_ptr, generation); - } - - // waiting for completion of streaming - streamer_ptr->end(); - - OPENVINO_ASSERT(is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); - - std::vector results; - results.reserve(all_requests.size()); - - m_perf_metrics.draft_model_metrics.raw_metrics = m_draft_pipeline->raw_perf_metrics; - - const auto generate_end = std::chrono::steady_clock::now(); - const auto generate_duration = PerfMetrics::get_microsec(generate_end - generate_start); - - for (size_t request_id = 0; request_id < all_requests.size(); ++request_id) { - const auto& request = all_requests[request_id]; - auto sampling_params = request->get_sampling_parameters(); - const auto& sequences = request->get_finished_sequences(); - size_t num_outputs = std::min(sampling_params.num_return_sequences, sequences.size()); - - EncodedGenerationResult result; - result.m_request_id = request_id; - result.m_generation_ids.resize(num_outputs); - result.m_scores.resize(num_outputs); - result.m_status = request->get_generation_stream()->get_status(); - - for (size_t i = 0; i < num_outputs; ++i) { - const auto & sequence = sequences[i]; - const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob(); - const auto & generated_ids = sequence->get_generated_ids(); - - if (sampling_params.echo) { - result.m_generation_ids[i] = request->get_prompt_ids(); - } - std::copy(generated_ids.begin(), generated_ids.end(), std::back_inserter(result.m_generation_ids[i])); - result.m_scores[i] = score; - } - - result.m_status = main_generations[request_id]->get_status(); - - // The same perf metrics for each sequence, only tokenization/detokenization will differ. 
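The duplicated body being removed here is replaced by a `GenerateStrategy`-driven `generate_common` helper (shown just below in the new code). A conceptual Python analogue of that refactor, not the pipeline code itself, might look like the following; the hook names mirror the C++ struct, everything else is illustrative:

```python
# Illustration only: a Python analogue of the GenerateStrategy refactor.
# The shared generate loop stays generic; each decoding flavour only supplies hooks.
from dataclasses import dataclass
from typing import Callable
import time

@dataclass
class GenerateStrategy:
    # returns (main_cfg, draft_cfg, main_input, draft_input) for one request
    prepare_request: Callable[[int, list], tuple]
    check_streaming: Callable[[bool, int], None]
    start_timer: Callable[[], float] = time.perf_counter
    stop_timer: Callable[[float], float] = lambda t0: (time.perf_counter() - t0) * 1e6

def generate_common(requests, strategy: GenerateStrategy):
    t0 = strategy.start_timer()
    prepared = [strategy.prepare_request(rid, ids) for rid, ids in enumerate(requests)]
    # ... schedule main/draft pipelines, step() until finished, collect results ...
    return prepared, strategy.stop_timer(t0)

# Plain speculative decoding reuses the prompt for the draft model;
# the EAGLE3 strategy would instead hand the draft a shifted copy of it.
vanilla = GenerateStrategy(
    prepare_request=lambda rid, ids: ({"num_assistant_tokens": 5}, {"ignore_eos": True}, ids, ids),
    check_streaming=lambda has_callback, batch_size: None,
)
print(generate_common([[1, 2, 3]], vanilla))
```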
- m_perf_metrics.raw_metrics.generate_durations.clear(); - m_perf_metrics.raw_metrics.generate_durations.emplace_back(generate_duration); - m_perf_metrics.num_input_tokens = request->get_prompt_len(); - m_perf_metrics.evaluate_statistics(generate_start); - - result.perf_metrics = m_perf_metrics; - result.extended_perf_metrics = std::make_shared(m_perf_metrics); - results.push_back(std::move(result)); - } - - OPENVINO_ASSERT(results.size() == input_ids.size()); - - return results; + draft_cfg.ignore_eos = true; + draft_cfg.stop_strings = {}; + main_in = in_ids; + draft_in = in_ids; + }; + strategy.check_streaming = [](const std::shared_ptr& streamer_ptr, + const std::vector& input_ids, + const std::vector& sampling_params) { + OPENVINO_ASSERT(!streamer_ptr->has_callback() || + (input_ids.size() == 1 && + (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial())), + "Streaming only supports batch size=1 with greedy/multinomial"); + }; + strategy.start_timer = [](){ return std::chrono::steady_clock::now(); }; + strategy.stop_timer = [](TimePoint start){ + return PerfMetrics::get_microsec(std::chrono::steady_clock::now() - start); + }; + + return generate_common(this, input_ids, sampling_params, streamer, token_type_ids, strategy); } SpeculativeDecodingMetrics @@ -375,4 +295,5 @@ std::vector ContinuousBatchingPipeline::SpeculativeDecodingI OPENVINO_ASSERT(main_awaiting_requests.size() == draft_awaiting_requests.size()); return main_awaiting_requests; } + } diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp index 87ae8ab60d..d913fb2979 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp @@ -11,6 +11,129 @@ #include "utils.hpp" namespace ov::genai { +struct GenerateStrategy { + std::function prepare_request; + std::function&, + const std::vector&, + const std::vector&)> check_streaming; + std::function start_timer; + std::function stop_timer; +}; + +template +std::vector generate_common( + Impl* self, + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids, + GenerateStrategy& strategy) { + + OPENVINO_ASSERT(!token_type_ids.has_value()); + self->perf_metrics() = ov::genai::SDPerModelsPerfMetrics(); + self->draft_pipeline()->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }}; + + OPENVINO_ASSERT(!self->has_non_finished_requests(), + "Generate cannot be called while ContinuousBatchingPipeline is already running"); + OPENVINO_ASSERT(input_ids.size() == sampling_params.size()); + + auto t_start = strategy.start_timer(); + + for (size_t i = 1; i < sampling_params.size(); ++i) { + OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters, + "LoRA adapters must be same for all requests"); + } + self->main_pipeline()->set_adapters(sampling_params[0].adapters); + self->draft_pipeline()->set_adapters(sampling_params[0].adapters); + + auto streamer_ptr = std::make_shared(streamer, self->tokenizer()); + + strategy.check_streaming(streamer_ptr, input_ids, sampling_params); + + std::vector main_generations; + { + std::lock_guard lock(self->draft_generations_mutex()); + for (size_t rid = 0; rid < input_ids.size(); ++rid) { + GenerationConfig main_cfg = sampling_params[rid]; + GenerationConfig draft_cfg = main_cfg; + ov::Tensor main_in, draft_in; + strategy.prepare_request(rid, input_ids[rid], 
+ main_cfg, draft_cfg, + main_in, draft_in); + main_generations.push_back(self->main_pipeline()->add_request(rid, main_in, main_cfg)); + self->draft_generations().insert({rid, + self->draft_pipeline()->add_request(rid, draft_in, draft_cfg)}); + } + } + + auto all_requests = self->get_awaiting_requests(); + GenerationHandle& generation = main_generations.at(0); + + streamer_ptr->start(); + while (self->has_non_finished_requests()) { + try { + self->step(); + } catch (...) { + self->drop_requests(); + streamer_ptr->end(); + std::rethrow_exception(std::current_exception()); + } + self->stream_tokens(streamer_ptr, generation); + } + streamer_ptr->end(); + + OPENVINO_ASSERT(self->is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); + + self->perf_metrics().draft_model_metrics.raw_metrics = self->draft_pipeline()->raw_perf_metrics; + uint64_t generate_duration_us = strategy.stop_timer(t_start); + + std::vector results; + results.reserve(all_requests.size()); + + for (size_t rid = 0; rid < all_requests.size(); ++rid) { + const auto& request = all_requests[rid]; + auto cfg = request->get_sampling_parameters(); + const auto& seqs = request->get_finished_sequences(); + size_t num_out = std::min(cfg.num_return_sequences, seqs.size()); + + EncodedGenerationResult result; + result.m_request_id = rid; + result.m_generation_ids.resize(num_out); + result.m_scores.resize(num_out); + result.m_status = main_generations[rid]->get_status(); + + for (size_t i = 0; i < num_out; ++i) { + const auto& seq = seqs[i]; + float score = cfg.is_beam_search() ? + seq->get_beam_search_score(cfg) : + seq->get_cumulative_log_prob(); + const auto& gen_ids = seq->get_generated_ids(); + if (cfg.echo) { + result.m_generation_ids[i] = request->get_prompt_ids(); + } + std::copy(gen_ids.begin(), gen_ids.end(), + std::back_inserter(result.m_generation_ids[i])); + result.m_scores[i] = score; + } + + self->perf_metrics().raw_metrics.generate_durations.clear(); + self->perf_metrics().raw_metrics.generate_durations.emplace_back(generate_duration_us); + self->perf_metrics().num_input_tokens = request->get_prompt_len(); + self->perf_metrics().evaluate_statistics(t_start); + + result.perf_metrics = self->perf_metrics(); + result.extended_perf_metrics = std::make_shared(self->perf_metrics()); + results.push_back(std::move(result)); + } + + OPENVINO_ASSERT(results.size() == input_ids.size()); + return results; +} class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline { protected: @@ -26,8 +149,18 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat void drop_requests(); bool is_requests_empty(); std::vector get_awaiting_requests(); - + std::pair init_speculative_models(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); public: + template + friend std::vector generate_common( + Impl* self, + const std::vector& input_ids, + const std::vector& sampling_params, + const StreamerVariant& streamer, + std::optional> token_type_ids, + GenerateStrategy& strategy); + + SpeculativeDecodingImpl() = default; SpeculativeDecodingImpl(const ov::genai::ModelDesc& main_model_desc, const ov::genai::ModelDesc& draft_model_desc); GenerationHandle add_request(uint64_t request_id, @@ -50,6 +183,16 @@ class ContinuousBatchingPipeline::SpeculativeDecodingImpl : public ContinuousBat const std::optional>>>& position_ids = std::nullopt) override; 
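EAGLE3's `prepare_request` hook feeds the draft pipeline a shifted copy of the prompt via `create_draft_input_ids` earlier in this patch: the draft predicts token t+1 from the target's hidden state at position t, so its input starts one position later than the main model's. A minimal conceptual sketch of that shift (not the pipeline code):

```python
# Conceptual sketch of the draft-prompt shift performed by create_draft_input_ids().
import numpy as np

def create_draft_input_ids(input_ids: np.ndarray) -> np.ndarray:
    """input_ids has shape [1, seq_len]; returns [1, seq_len - 1], or a copy if too short."""
    if input_ids.ndim == 0 or input_ids.shape[-1] <= 1:
        return input_ids.copy()
    return input_ids[:, 1:].copy()

main_input = np.array([[151644, 872, 198, 9707]], dtype=np.int64)  # hypothetical token ids
draft_input = create_draft_input_ids(main_input)
assert draft_input.tolist() == [[872, 198, 9707]]
```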
SpeculativeDecodingMetrics get_speculative_decoding_metrics(); + SDPerModelsPerfMetrics& perf_metrics() { return m_perf_metrics; } + SDPerModelsPerfMetrics const& perf_metrics() const { return m_perf_metrics; } + std::shared_ptr& draft_pipeline() { return m_draft_pipeline; } + std::shared_ptr& main_pipeline() { return m_main_pipeline; } + + Tokenizer& tokenizer() { return m_tokenizer; } + const Tokenizer& tokenizer() const { return m_tokenizer; } + + std::mutex& draft_generations_mutex() { return m_draft_generations_mutex; } + std::map& draft_generations() { return m_draft_generations; } }; -} +} // namespace ov::genai \ No newline at end of file diff --git a/src/cpp/src/speculative_decoding/update_request_structs.hpp b/src/cpp/src/speculative_decoding/update_request_structs.hpp index 68f79268f5..4426372507 100644 --- a/src/cpp/src/speculative_decoding/update_request_structs.hpp +++ b/src/cpp/src/speculative_decoding/update_request_structs.hpp @@ -10,11 +10,17 @@ namespace ov::genai { struct GeneratedSequence { std::vector token_ids; std::vector log_probs; - + // Stores the hidden states tensor associated with the generated sequence. + // This field is used for the "eagle speculative" decoding algorithm, + // where hidden states are required to efficiently validate and extend speculative tokens. + // If not using eagle speculative decoding, this field may remain empty. + ov::Tensor hidden_states; GeneratedSequence(const std::vector& generated_token_ids, - const std::vector& generated_log_probs) : + const std::vector& generated_log_probs, + const ov::Tensor generated_hidden_states = {}) : token_ids(generated_token_ids), - log_probs(generated_log_probs) {}; + log_probs(generated_log_probs), + hidden_states(generated_hidden_states) {}; }; struct UpdateRequestResult { diff --git a/tests/python_tests/samples/conftest.py b/tests/python_tests/samples/conftest.py index 8a011ccfe2..f6e0b0eeeb 100644 --- a/tests/python_tests/samples/conftest.py +++ b/tests/python_tests/samples/conftest.py @@ -143,6 +143,14 @@ "tiny-random-SpeechT5ForTextToSpeech": { "name": "hf-internal-testing/tiny-random-SpeechT5ForTextToSpeech", "convert_args": ["--model-kwargs", json.dumps({"vocoder": "fxmarty/speecht5-hifigan-tiny"})] + }, + "Qwen3-1.7B": { + "name": "Qwen/Qwen3-1.7B", + "convert_args": ["--task", "text-generation-with-past", '--trust-remote-code'] + }, + "qwen3_1.7b_eagle3": { + "name": "AngelSlim/Qwen3-1.7B_eagle3", + "convert_args": ["--task", "text-generation-with-past", "--trust-remote-code", "--eagle3"] } } diff --git a/tests/python_tests/samples/test_speculative_decoding_lm.py b/tests/python_tests/samples/test_speculative_decoding_lm.py index 35a1bb285a..0b4ef570a2 100644 --- a/tests/python_tests/samples/test_speculative_decoding_lm.py +++ b/tests/python_tests/samples/test_speculative_decoding_lm.py @@ -10,6 +10,23 @@ convert_draft_model = convert_model +def _run_spec_case(convert_model, convert_draft_model, sample_args, env): + cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'speculative_decoding_lm') + cpp_command =[cpp_sample, convert_model, convert_draft_model, sample_args] + cpp_result = run_sample(cpp_command, env=env) + + py_script = os.path.join(SAMPLES_PY_DIR, "text_generation/speculative_decoding_lm.py") + py_command = [sys.executable, py_script, convert_model, convert_draft_model, sample_args] + py_result = run_sample(py_command, env=env) + + cpp_sample_ref = os.path.join(SAMPLES_CPP_DIR, 'greedy_causal_lm') + cpp_command_ref = [cpp_sample_ref, convert_model, sample_args] + cpp_result_ref = 
run_sample(cpp_command_ref, env=env) + + assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match" + assert cpp_result_ref.stdout.strip() in cpp_result.stdout.strip(), "Greedy and speculative decoding results should match" + return cpp_result, py_result, cpp_result_ref + class TestSpeculativeDecodingLM: @pytest.mark.llm @pytest.mark.samples @@ -26,22 +43,27 @@ def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model pytest.xfail("Ticket 173586") env = os.environ.copy() env["OPENVINO_LOG_LEVEL"] = "0" - # Test CPP sample - cpp_sample = os.path.join(SAMPLES_CPP_DIR, 'speculative_decoding_lm') - cpp_command =[cpp_sample, convert_model, convert_draft_model, sample_args] - cpp_result = run_sample(cpp_command, env=env) - - # Test Python sample - py_script = os.path.join(SAMPLES_PY_DIR, "text_generation/speculative_decoding_lm.py") - py_command = [sys.executable, py_script, convert_model, convert_draft_model, sample_args] - py_result = run_sample(py_command, env=env) - - # Greedy decoding - cpp_sample_ref = os.path.join(SAMPLES_CPP_DIR, 'greedy_causal_lm') - cpp_command_ref = [cpp_sample_ref, convert_model, sample_args] - cpp_result_ref = run_sample(cpp_command_ref, env=env) - - # Compare results - assert cpp_result_ref.stdout.strip() in py_result.stdout.strip(), "Python and CPP results should match" - assert cpp_result_ref.stdout.strip() in cpp_result.stdout.strip(), "Greedy and speculative decoding results should match" + _run_spec_case(convert_model, convert_draft_model, sample_args, env) + +test_prompt = """Code: +def add(a, b): + return a + b +Question: Can you please add 2 and 3 +A:""" +class TestEagle3SpeculativeDecodingLM: + @pytest.mark.llm + @pytest.mark.samples + @pytest.mark.parametrize( + "convert_model, convert_draft_model, sample_args", + [ + pytest.param("Qwen3-1.7B", "qwen3_1.7b_eagle3", test_prompt, marks=pytest.mark.skip(reason = 'CVS-171947, CVS-171943, CVS-174959')), + ], + indirect=["convert_model", "convert_draft_model"], + ) + def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model, sample_args): + if sys.platform == 'darwin': + pytest.xfail("Ticket 173586") + env = os.environ.copy() + env["OPENVINO_LOG_LEVEL"] = "0" + _run_spec_case(convert_model, convert_draft_model, sample_args, env) \ No newline at end of file diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py index b32ba5e8e7..975952406c 100644 --- a/tests/python_tests/test_continuous_batching.py +++ b/tests/python_tests/test_continuous_batching.py @@ -16,8 +16,9 @@ from utils.generation_config import get_greedy, get_beam_search, \ get_multinomial_all_parameters, get_multinomial_temperature_and_num_return_sequence, \ get_multinomial_temperature_and_top_k, get_multinomial_temperature, get_multinomial_temperature_and_top_p -from utils.hugging_face import download_and_convert_model -from utils.ov_genai_pipelines import create_ov_pipeline, create_ov_cb_pipeline, PipelineType, dict_to_scheduler_config, generate_and_compare, prepare_generation_config_by_pipe_type, GenerationChatInputsType +from utils.hugging_face import download_and_convert_model, run_hugging_face +from utils.ov_genai_pipelines import create_ov_pipeline, create_ov_cb_pipeline, PipelineType, dict_to_scheduler_config, generate_and_compare, prepare_generation_config_by_pipe_type, convert_decoded_results_to_generation_result, GenerationChatInputsType +from utils.comparation import compare_generation_results 
from data.models import get_chat_models_list from data.test_dataset import get_test_dataset @@ -478,21 +479,44 @@ def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_c return pipe, prompt, generation_config -def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType): +def run_extended_perf_metrics_collection(model_id, generation_config: GenerationConfig, prompt: str, pipeline_type: PipelineType, draft_model_id: str): _, _, model_path = download_and_convert_model(model_id) - ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type) + draft_model_path = None + if draft_model_id is not None: + _,_, draft_model_path = download_and_convert_model(draft_model_id) + ov_pipe = create_ov_pipeline(model_path, pipeline_type=pipeline_type, draft_model_path = draft_model_path) return ov_pipe.generate([prompt], generation_config).extended_perf_metrics +eagle_models_and_input = [ + ("Qwen/Qwen3-1.7B", "AngelSlim/Qwen3-1.7B_eagle3", """Code: +def add(a, b): + return a + b +Question: Can you please add 2 and 3 +A:""")] + +speculative_cases = [ + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None, "Why is the Sun yellow?"), + eagle_models_and_input[0], +] @pytest.mark.parametrize("pipeline_type", [PipelineType.PAGED_ATTENTION, PipelineType.SPECULATIVE_DECODING]) -def test_speculative_decoding_extended_perf_metrics(pipeline_type): +@pytest.mark.parametrize("main_model_id,draft_model_id, prompt", speculative_cases) +@pytest.mark.precommit +def test_speculative_decoding_extended_perf_metrics(pipeline_type, main_model_id, draft_model_id, prompt): import time start_time = time.perf_counter() - model_id : str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) - extended_perf_metrics = run_extended_perf_metrics_collection(model_id, generation_config, "Why is the Sun yellow?", pipeline_type) - total_time = (time.perf_counter() - start_time) * 1000 + extended_perf_metrics = None + if draft_model_id is None: + generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + extended_perf_metrics = run_extended_perf_metrics_collection(main_model_id, generation_config, prompt, pipeline_type, draft_model_id) + total_time = (time.perf_counter() - start_time) * 1000 + else: + if (pipeline_type == PipelineType.SPECULATIVE_DECODING): + generation_config = GenerationConfig(do_sample=False, max_new_tokens=20, ignore_eos=True, num_assistant_tokens=5) + extended_perf_metrics = run_extended_perf_metrics_collection(main_model_id, generation_config, prompt, pipeline_type, draft_model_id) + total_time = (time.perf_counter() - start_time) * 1000 + if (pipeline_type == PipelineType.SPECULATIVE_DECODING): assert not extended_perf_metrics is None assert not extended_perf_metrics.main_model_metrics is None @@ -530,3 +554,31 @@ def test_speculative_decoding_extended_perf_metrics(pipeline_type): assert std_gen_duration == 0 else: assert extended_perf_metrics is None + +devices = [ + ('CPU', 'CPU') +] +@pytest.mark.parametrize("main_model,draft_model,prompt", eagle_models_and_input) +@pytest.mark.parametrize("main_device,draft_device", devices) +@pytest.mark.precommit +def test_eagle3_sd_string_inputs(main_model, main_device, draft_model, draft_device, prompt): + # Download and convert model: + main_opt_model, main_hf_tokenizer, main_model_path = download_and_convert_model(main_model) + 
__, __, draft_model_path = download_and_convert_model(draft_model) + + # Create OpenVINO GenAI pipeline: + + ov_pipe = create_ov_pipeline(main_model_path, pipeline_type = PipelineType.SPECULATIVE_DECODING, draft_model_path = draft_model_path) + + # Run reference HF model: + ov_generation_config = GenerationConfig(max_new_tokens=20) + ref_gen_results = run_hugging_face(main_opt_model, main_hf_tokenizer, [prompt], ov_generation_config) + + # Run OpenVINO GenAI pipeline: + ov_decoded_results = ov_pipe.generate([prompt], ov_generation_config) + ov_gen_results = convert_decoded_results_to_generation_result(ov_decoded_results, 1, 1, False) + + del ov_pipe + + # Compare results: + compare_generation_results([prompt], ref_gen_results, ov_gen_results, ov_generation_config) \ No newline at end of file diff --git a/tests/python_tests/utils/hugging_face.py b/tests/python_tests/utils/hugging_face.py index ec2535dcbe..dcfe3dc7cc 100644 --- a/tests/python_tests/utils/hugging_face.py +++ b/tests/python_tests/utils/hugging_face.py @@ -166,9 +166,14 @@ def run_hugging_face( # download HF model or read converted model def get_huggingface_models(model_id: str | Path, model_class: Type[OVModel], local_files_only=False): - hf_tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, local_files_only=local_files_only)) - opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only)) - return opt_model, hf_tokenizer + if "eagle3" not in str(model_id).lower(): + hf_tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id, local_files_only=local_files_only)) + opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only)) + return opt_model, hf_tokenizer + else: + hf_tokenizer = None + opt_model = retry_request(lambda: model_class.from_pretrained(model_id, eagle3=True, export=isinstance(model_id, str), compile=False, load_in_8bit=False, ov_config=get_default_llm_properties(), local_files_only=local_files_only)) + return opt_model, hf_tokenizer def convert_and_save_tokenizer(hf_tokenizer : AutoTokenizer, @@ -192,9 +197,10 @@ def convert_models(opt_model : OVModelForCausalLM, opt_model.config.save_pretrained(models_path) # to store tokenizer config jsons with special tokens - hf_tokenizer.save_pretrained(models_path) - # convert tokenizers as well - convert_and_save_tokenizer(hf_tokenizer, models_path, **tokenizer_kwargs) + if hf_tokenizer: + hf_tokenizer.save_pretrained(models_path) + # convert tokenizers as well + convert_and_save_tokenizer(hf_tokenizer, models_path, **tokenizer_kwargs) def download_and_convert_model(model_id: str, **tokenizer_kwargs):
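A hedged sketch of how the new eagle3 branch in `get_huggingface_models` behaves, assuming it is run from tests/python_tests with an optimum-intel build that accepts the `eagle3=True` flag used above: the draft checkpoint is loaded without an HF tokenizer, so `convert_models()` skips tokenizer export for it and the draft reuses the main model's tokenizer at runtime.

```python
from optimum.intel import OVModelForCausalLM
from utils.hugging_face import get_huggingface_models

opt_model, hf_tokenizer = get_huggingface_models("AngelSlim/Qwen3-1.7B_eagle3", OVModelForCausalLM)
assert hf_tokenizer is None          # no tokenizer is converted for the EAGLE3 draft
print(type(opt_model).__name__)
```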