
Conversation


@songbell songbell commented Sep 17, 2025

Copilot AI review requested due to automatic review settings September 17, 2025 07:28
@github-actions github-actions bot added category: continuous batching Continuous batching category: LLM LLM pipeline (stateful, static) category: sampling Sampling / Decoding algorithms category: speculative decoding Speculative decoding category: LoRA Low rank adapters category: cmake / build Cmake scripts category: LLM samples GenAI LLM samples category: CPP API Changes in GenAI C++ public headers no-match-files labels Sep 17, 2025
@songbell songbell changed the title eagle impl with top-1 proposal eagle3 cb impl with top-1 proposal Sep 17, 2025
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR implements Eagle speculative decoding functionality for top-1 proposal generation. The implementation adds support for Eagle3 mode, which enables accelerated text generation through speculative decoding with hidden state sharing between main and draft models.

Key changes include:

  • Added Eagle decoding implementation with model transformation pipelines for hidden state extraction
  • Integrated safetensor parsing for Eagle3 configuration data (d2t mappings)
  • Extended continuous batching pipeline to support Eagle mode with hidden state management
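
For orientation, the user-facing flow is close to the existing speculative-decoding sample, with an Eagle3 draft model attached to the main pipeline. Below is a minimal usage sketch built only from the already-public GenAI API; the model paths are placeholders, and whatever Eagle3-specific properties this PR introduces (hidden-state sharing, d2t mapping location) are deliberately not shown because their exact API is an assumption here.

```cpp
// Minimal sketch, assuming the public openvino.genai speculative-decoding API;
// the Eagle3-specific switches added by this PR are not shown.
#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>

int main(int argc, char* argv[]) {
    if (argc < 3) return 1;
    std::string main_model_path = argv[1];   // placeholder path to the main model
    std::string draft_model_path = argv[2];  // placeholder path to the Eagle3 draft model

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;
    // Top-1 proposal: the draft proposes a single candidate chain of this length per step.
    config.num_assistant_tokens = 5;

    // The draft model is attached as a property of the main pipeline.
    ov::genai::LLMPipeline pipe(main_model_path, "CPU",
                                ov::genai::draft_model(draft_model_path, "CPU"));

    std::string result = pipe.generate("Why is the Sun yellow?", config);
    std::cout << result << std::endl;
    return 0;
}
```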

Reviewed Changes

Copilot reviewed 24 out of 25 changed files in this pull request and generated 8 comments.

Reviewed files:

  • tools/continuous_batching/accuracy/continuous_batching_eagle_decoding.cpp: new Eagle decoding accuracy test tool
  • src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp: Eagle decoding class definitions and model transformation passes
  • src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp: core Eagle decoding implementation with model transformations
  • src/cpp/src/continuous_batching/pipeline.cpp: integration of Eagle mode into pipeline construction
  • src/cpp/src/continuous_batching/model_runner.hpp: hidden state management functionality
  • samples/cpp/text_generation/eagle_speculative_lm.cpp: new Eagle speculative decoding sample
  • src/cpp/src/safe_tensor_wrapper.hpp: new safetensor parsing utilities

Comments suppressed due to low confidence (1)

src/cpp/src/continuous_batching/model_runner.hpp:1

  • This appears to be modifying the token index without bounds checking on the d2t array. Add bounds checking to prevent potential buffer overflow.
// Copyright (C) 2023-2025 Intel Corporation
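
For illustration, a minimal sketch of the kind of bounds check this comment asks for; `remap_draft_token` and `d2t` are placeholder names, not the identifiers used in model_runner.hpp.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Remap a draft-vocabulary token id into the target vocabulary via the d2t
// offset table, refusing to index past the end of the table.
int64_t remap_draft_token(int64_t token_id, const std::vector<int64_t>& d2t) {
    if (token_id < 0 || static_cast<size_t>(token_id) >= d2t.size()) {
        throw std::out_of_range("draft token id is outside the d2t mapping table");
    }
    // EAGLE-3 stores offsets, so: target_id = draft_id + d2t[draft_id].
    return token_id + d2t[static_cast<size_t>(token_id)];
}
```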


@github-actions github-actions bot added the category: llm_bench Label for tool/llm_bench folder label Sep 18, 2025
Copilot AI (Contributor) left a comment

Pull Request Overview

Copilot reviewed 28 out of 29 changed files in this pull request and generated 10 comments.



timeout: 240
- name: 'LLM & VLM'
-     cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/'
+     cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_eagle3.py --override-ini cache_dir=/mount/caches/pytest/'
Collaborator:

Move it to a separate matrix element. That way it will run on a separate runner and avoid slowing the other tests down.

As a side effect, you should be able to install the appropriate optimum-intel here and remove @pytest.mark.skip(reason="CVS-174959 enable model conversion for eagle3 and enable the test")

assert std_gen_duration == 0
else:
assert extended_perf_metrics is None
assert extended_perf_metrics is None
Collaborator:

Put the line back :)

@MaximProshin MaximProshin mentioned this pull request Oct 20, 2025
@apaniukov (Contributor) commented:

I am trying to run the tests with this command

pytest tests/python_tests/test_eagle3.py

@songbell could you share the environment?

My current env:
GenAI from this branch: https://github.com/songbell/openvino.genai/tree/bell/eagle_cb_impl
Optimum-intel from this branch: https://github.com/xufang-lisa/optimum-intel/tree/xufang/add_eagle3_draft_model_conversion

With transformers-4.55.4 the original tokenizer cannot be loaded:

.venv/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:2069: in from_pretrained
    return cls._from_pretrained(
.venv/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:2107: in _from_pretrained
    slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
.venv/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:2315: in _from_pretrained
    tokenizer = cls(*init_inputs, **init_kwargs)
.venv/lib/python3.10/site-packages/transformers/models/llama/tokenization_llama.py:171: in __init__
    self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
.venv/lib/python3.10/site-packages/transformers/models/llama/tokenization_llama.py:198: in get_spm_processor
    tokenizer.Load(self.vocab_file)
.venv/lib/python3.10/site-packages/sentencepiece/__init__.py:961: in Load
    return self.LoadFromFile(model_file)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <sentencepiece.SentencePieceProcessor; proxy of <Swig Object of type 'sentencepiece::SentencePieceProcessor *' at 0x71b823d09a70> >, arg = None

    def LoadFromFile(self, arg):
>       return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
E       TypeError: not a string

.venv/lib/python3.10/site-packages/sentencepiece/__init__.py:316: TypeError

With newer transformers-4.57.1:

tests/python_tests/utils/hugging_face.py:201: in download_and_convert_model
    return _download_and_convert_model(model_id, OVModelForCausalLM, **tokenizer_kwargs)
tests/python_tests/utils/hugging_face.py:233: in _download_and_convert_model
    opt_model, hf_tokenizer = get_huggingface_models(model_id, model_class, local_files_only=False)
tests/python_tests/utils/hugging_face.py:170: in get_huggingface_models
    opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, trust_remote_code=isinstance(model_id, str), ov_config=get_default_llm_properties(), local_files_only=local_files_only))
tests/python_tests/utils/network.py:37: in retry_request
    return func()
tests/python_tests/utils/hugging_face.py:170: in <lambda>
    opt_model = retry_request(lambda: model_class.from_pretrained(model_id, export=isinstance(model_id, str), compile=False, load_in_8bit=False, trust_remote_code=isinstance(model_id, str), ov_config=get_default_llm_properties(), local_files_only=local_files_only))
.venv/lib/python3.10/site-packages/optimum/intel/openvino/modeling_base.py:504: in from_pretrained
    return super().from_pretrained(
.venv/lib/python3.10/site-packages/optimum/modeling_base.py:407: in from_pretrained
    return from_pretrained_method(
.venv/lib/python3.10/site-packages/optimum/intel/openvino/modeling_decoder.py:345: in _export
    main_export(
.venv/lib/python3.10/site-packages/optimum/exporters/openvino/__main__.py:562: in main_export
    submodel_paths = export_from_model(
.venv/lib/python3.10/site-packages/optimum/exporters/openvino/convert.py:745: in export_from_model
    export_models(
.venv/lib/python3.10/site-packages/optimum/exporters/openvino/convert.py:514: in export_models
    export(
.venv/lib/python3.10/site-packages/optimum/exporters/openvino/convert.py:216: in export
    return export_pytorch(
.venv/lib/python3.10/site-packages/optimum/exporters/openvino/convert.py:422: in export_pytorch
    ov_model = convert_model(
../openvino_toolkit_ubuntu22_2025.4.0.dev20251017_x86_64/python/openvino/tools/ovc/convert.py:105: in convert_model
    ov_model, _ = _convert(cli_parser, params, True)
../openvino_toolkit_ubuntu22_2025.4.0.dev20251017_x86_64/python/openvino/tools/ovc/convert_impl.py:578: in _convert
    raise e
../openvino_toolkit_ubuntu22_2025.4.0.dev20251017_x86_64/python/openvino/tools/ovc/convert_impl.py:518: in _convert
    ov_model = driver(argv, {"conversion_parameters": non_default_params})
../openvino_toolkit_ubuntu22_2025.4.0.dev20251017_x86_64/python/openvino/tools/ovc/convert_impl.py:256: in driver
    ov_model = moc_emit_ir(prepare_ir(argv), argv)
../openvino_toolkit_ubuntu22_2025.4.0.dev20251017_x86_64/python/openvino/tools/ovc/convert_impl.py:195: in prepare_ir
    ov_model = moc_pipeline(argv, moc_front_end)
../openvino_toolkit_ubuntu22_2025.4.0.dev20251017_x86_64/python/openvino/tools/ovc/moc_frontend/pipeline.py:142: in moc_pipeline
    ov_model = moc_front_end.convert(input_model)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <FrontEnd 'pytorch'>, model = <openvino._pyopenvino.InputModel object at 0x79b16aeefe20>

    def convert(self, model: Union[Model, InputModel]) -> Model:
>       converted_model = super().convert(model)
E       openvino._pyopenvino.OpConversionFailure: Check 'is_conversion_successful' failed at src/frontends/pytorch/src/frontend.cpp:180:
E       FrontEnd API failed with OpConversionFailure:
E       Model wasn't fully converted. Failed operations detailed log:
E       -- aten::cat with a message:
E       Exception happened during conversion of operation aten::cat with schema aten::cat(Tensor[] tensors, int dim=0) -> Tensor
E       Check 'is_axis_valid(axis, r)' failed at src/core/src/validation_util.cpp:332:
E       While validating node 'opset1::Concat Concat_68 (opset1::Convert aten::to/Convert[0]:f32[0], util::PtFrameworkNode prim::TupleUnpack[1]:f32[?,?,?,?]) -> (dynamic[...])' with friendly_name 'Concat_68':
E       Axis -2 out of the tensor rank range [-1, 0].
E       
E       -- prim::ListConstruct with a message:
E       Exception happened during conversion of operation prim::ListConstruct with schema (no schema)
E       Check '(c_node)' failed at src/frontends/pytorch/src/op/list_construct.cpp:25:
E       FrontEnd API failed with OpConversionFailure:
E       [PyTorch Frontend] Translation for prim::ListConstruct support only constant inputs
E       
E       -- prim::ListUnpack with a message:
E       Exception happened during conversion of operation prim::ListUnpack with schema (no schema)
E       Exception from src/core/src/node.cpp:593:
E       node index is out of range
E       
E       Summary:
E       -- normalize step failed with: Exception from src/core/src/pass/graph_rewrite.cpp:298:
E       [ov::frontend::pytorch::pass::AtenCatToConcat] END: node: util::PtFrameworkNode aten::cat (util::PtFrameworkNode prim::ListConstruct[0]:dynamic[...], opset1::Constant 111[0]:i64[]) -> (f32[?,?,?,?]) CALLBACK HAS THROWN: Check 'is_axis_valid(axis, r)' failed at src/core/src/validation_util.cpp:332:
E       While validating node 'opset1::Concat Concat_34833 (opset1::Convert aten::to/Convert[0]:f32[0], util::PtFrameworkNode prim::TupleUnpack[0]:f32[?,?,?,?]) -> (dynamic[...])' with friendly_name 'Concat_34833':
E       Axis -2 out of the tensor rank range [-1, 0].
E       
E       
E       
E       -- No conversion rule found for operations: prim::TupleConstruct, prim::TupleUnpack
E       -- Conversion is failed for: aten::cat, prim::ListConstruct, prim::ListUnpack

../openvino_toolkit_ubuntu22_2025.4.0.dev20251017_x86_64/python/openvino/frontend/frontend.py:18: OpConversionFailure

@pytest.mark.parametrize("main_device,draft_device", devices)
@pytest.mark.precommit
@pytest.mark.skip(reason="CVS-174959 enable model conversion for eagle3 and enable the test")
def test_eagle3_sd_extended_perf_metrics(main_model, main_device, draft_model, draft_device, prompt):
Contributor:

What is different from the test https://github.com/openvinotoolkit/openvino.genai/blob/master/tests/python_tests/test_continuous_batching.py#L486 ?
This is a test of performance metrics, not of whether the pipeline runs for an acceptable amount of time. If you still need to test performance metrics, you can parameterize the original test.

],
indirect=["convert_model", "convert_draft_model"],
)
def test_sample_speculative_decoding_lm(self, convert_model, convert_draft_model, sample_args):
Contributor:

Please parameterize the TestSpeculativeDecodingLM test or move the common code into a separate function.

    }
    if (config.find("dt_mapping_path") != config.end()) {
        eagle_rt_info.dt_mapping_table = config.at("dt_mapping_path").as<std::filesystem::path>();
        eagle_rt_info.dt_mapping_table = eagle_rt_info.dt_mapping_table / "eagle3.safetensors";
Collaborator:

The optimum-cli tool must not generate safetensors; that is not its responsibility at all.

Author (@songbell):

Problem to solve:
the original d2t tensor is wrapped inside https://huggingface.co/Tengyunw/qwen3_8b_eagle3/blob/main/pytorch_model.bin,
and GenAI only needs this d2t so the draft model can generate correct tokens.

Current solution:
as in this PR, we split this part out of pytorch_model.bin and use it separately in GenAI.

@rkazants, this d2t tensor is also not part of the model representation; or do you think we can still wrap this part into the IR?
cc @wangleis

Collaborator:

  1. Please educate me on the EAGLE-3 algorithm and explain why this third model, represented as eagle3.safetensors, is needed. That is something new compared to the arch review.
  2. Can .safetensors be replaced with IR?
  3. Is it possible to fuse that third model into the draft model as a model state and activate it as needed?
  4. Is it possible to fuse that third model at runtime into the main model?

@peterchen-intel (Collaborator) commented Oct 21, 2025:

@Wovchena @rkazants It has nothing to do with safetensors; a wrong suffix is being used. It is a plain binary that includes offset values for mapping from the draft model's vocabulary (smaller size) to the base/target model's vocabulary (bigger size). Is it OK to save the offset values in a *.bin file, since they are used in the GenAI pipeline, not in inference?

@wangleis commented Oct 22, 2025:

dt_mapping_table is an implementation detail of the draft model's training, so it is not part of the arch review.

If the main concern is the format of dt_mapping_table, options 2 and 3 in @Wovchena's proposal are potential solutions.

May I know if either option 2 or 3 is OK for you, @Wovchena and @rkazants?

@rkazants (Collaborator) commented Oct 22, 2025:

Hi @wangleis,

I think that dt_mapping_table should be part of a separate (dedicated) IR, because it will be fused into the main model at runtime, and in that case there is no relation to the draft model.
Fusing that third model into the draft model looks more complicated to implement than fusing it into the main model, because we would have to differentiate the first and subsequent generations in the speculative decoding and use the dt mapping table only for the first iteration.

Hi @rkazants,

Thanks for the comments. We would like to keep the main model unchanged in the eagle3 pipeline, so we will update the PR to create a new IR for dt_mapping_table.

@Wovchena May I know if you have more comments?
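
To make that direction concrete, here is a minimal sketch (not the PR's code) of reading such a dedicated IR back with the public OpenVINO C++ API, under the assumption that the table is stored as the single i64 output of an input-less model; the file name and layout are hypothetical.

```cpp
#include <openvino/openvino.hpp>
#include <string>
#include <vector>

// Load the d2t offsets from a dedicated IR that is assumed to contain the
// mapping table as a constant folded into its only output.
std::vector<int64_t> load_d2t_from_ir(const std::string& xml_path) {
    ov::Core core;
    auto model = core.read_model(xml_path);            // e.g. a hypothetical dt_mapping.xml/.bin pair
    auto compiled = core.compile_model(model, "CPU");
    ov::InferRequest request = compiled.create_infer_request();
    request.infer();                                    // no inputs: the model just emits the constant
    ov::Tensor out = request.get_output_tensor();
    const int64_t* data = out.data<int64_t>();
    return std::vector<int64_t>(data, data + out.get_size());
}
```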

Contributor:

Where is dt_mapping_table used in the speculative decoding pipeline?

Author (@songbell):

It is used in draft sampling to adjust the draft tokens, because the draft model uses a smaller LM head; see the reference:
https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py#L712
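
In other words, the table converts draft-vocabulary ids into main-model vocabulary ids before verification. A small illustrative sketch of that adjustment, mirroring the referenced cnets.py logic (the identifiers are placeholders, and the bounds check shown earlier in this thread is omitted for brevity):

```cpp
#include <cstdint>
#include <vector>

// After the draft model's smaller LM head picks top-1 ids in its reduced
// vocabulary, translate them into the main model's vocabulary so the main
// model can verify them: target_id = draft_id + d2t[draft_id].
std::vector<int64_t> adjust_draft_tokens(const std::vector<int64_t>& draft_ids,
                                         const std::vector<int64_t>& d2t) {
    std::vector<int64_t> target_ids;
    target_ids.reserve(draft_ids.size());
    for (int64_t id : draft_ids) {
        target_ids.push_back(id + d2t[static_cast<size_t>(id)]);
    }
    return target_ids;
}
```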

@rkazants (Collaborator) left a comment:

The concern is about safetensors, which should not be generated by the optimum-cli tool.
optimum-cli generates IRs for the different parts of the GenAI pipeline.
We should reconsider the solution to avoid safetensors.

@MaximProshin (Collaborator) commented:

WWB Similarity results: #2812 (comment)

@moslex moslex added the priority: high High priority label Oct 21, 2025
@github-actions github-actions bot removed the category: LoRA Low rank adapters label Oct 23, 2025
