[WWB] Add eagle3 pipeline #2812
base: master
@@ -232,6 +232,35 @@ def parse_args():
         "If the base/target model is a local path, gguf-file should be just the filename (e.g., 'model.gguf'). "
         "If the base/target model is a HuggingFace model ID, gguf-file should be a relative path.",
     )
+    parser.add_argument(
+        "--draft-model",
+        default=None,
+        help="Path to the draft model folder (OpenVINO IR files) used for speculative decoding generation.",
+    )
+    parser.add_argument(
+        "--draft-device",
+        type=str,
+        default=None,
+        help="Inference device for the speculative decoding draft model, e.g. 'CPU', 'GPU'.",
+    )
+    parser.add_argument(
+        "--draft-cb-config",
+        type=str,
+        default=None,
+        help="Path to a file with Continuous Batching Scheduler settings, or a dict, for the speculative decoding draft model.",
+    )
+    parser.add_argument(
+        "--num-assistant-tokens",
+        type=int,
+        default=None,
+        help="Config option num_assistant_tokens for speculative decoding and prompt lookup decoding.",
+    )
+    parser.add_argument(
+        "--assistant-confidence-threshold",
+        type=float,
+        default=None,
+        help="Config option assistant_confidence_threshold for speculative decoding.",
+    )

     return parser.parse_args()
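For reference, here is a minimal standalone sketch of how the new flags behave when parsed. The stripped-down parser below is illustrative only and is not the real WWB `parse_args()`; the sample folder name is hypothetical.

```python
# Minimal sketch: only the new speculative-decoding flags, mirroring the defaults added above.
import argparse

parser = argparse.ArgumentParser(description="Speculative decoding options (illustrative subset)")
parser.add_argument("--draft-model", default=None)
parser.add_argument("--draft-device", type=str, default=None)
parser.add_argument("--draft-cb-config", type=str, default=None)
parser.add_argument("--num-assistant-tokens", type=int, default=None)
parser.add_argument("--assistant-confidence-threshold", type=float, default=None)

args = parser.parse_args([
    "--draft-model", "./draft_model_ir",   # hypothetical local IR folder
    "--draft-device", "CPU",
    "--num-assistant-tokens", "5",
])
print(args.draft_model, args.draft_device, args.num_assistant_tokens)
# -> ./draft_model_ir CPU 5
```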
@@ -387,11 +416,13 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str:
     return "".join(output)


-def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
-    return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens, apply_chat_template=use_chat_template)
+def genai_gen_text(model, tokenizer, question, gen_config, skip_question):
+    return model.generate(question, gen_config)


-def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
+def llamacpp_gen_text(model, tokenizer, question, gen_config, skip_question):
+    max_new_tokens = gen_config.max_new_tokens
+    use_chat_template = gen_config.apply_chat_template
     if use_chat_template:
         output = model.create_chat_completion(messages=[{"role": "user", "content": question}], max_tokens=max_new_tokens, temperature=0.0)
         text = output["choices"][0]["message"]["content"]
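With this change, generation settings travel through a single `openvino_genai.GenerationConfig` object instead of loose keyword arguments. A rough sketch of the resulting call pattern, assuming a hypothetical `model_dir` containing an exported OpenVINO IR model (and noting that the `apply_chat_template` field may depend on the openvino_genai version):

```python
import openvino_genai

# Assumption: model_dir points at a folder with an exported OpenVINO IR LLM.
model_dir = "./llm_ir"
pipe = openvino_genai.LLMPipeline(model_dir, "CPU")

gen_config = openvino_genai.GenerationConfig()
gen_config.max_new_tokens = 128
gen_config.do_sample = False
gen_config.apply_chat_template = True

# Mirrors the new genai_gen_text: a prompt plus a single config object.
answer = pipe.generate("What is speculative decoding?", gen_config)
print(answer)
```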
@@ -491,6 +522,7 @@ def create_evaluator(base_model, args):
     task = args.model_type

     try:
+        import openvino_genai
         EvaluatorCLS = EVALUATOR_REGISTRY[task]
         prompts = load_prompts(args)
@@ -507,6 +539,21 @@ def create_evaluator(base_model, args):
         use_chat_template = (
             tokenizer is not None and tokenizer.chat_template is not None and not args.omit_chat_template
         )

+        gen_config = openvino_genai.GenerationConfig()
+        gen_config.max_new_tokens = 128
+        gen_config.apply_chat_template = use_chat_template
+        gen_config.do_sample = False
+        if args.draft_model is not None:
+            config_info = "Speculative decoding config: "
+            if args.num_assistant_tokens is not None:
+                gen_config.num_assistant_tokens = int(args.num_assistant_tokens)
+                config_info += f" num_assistant_tokens {gen_config.num_assistant_tokens}"
+            if args.assistant_confidence_threshold is not None:
+                gen_config.assistant_confidence_threshold = float(args.assistant_confidence_threshold)
+                config_info += f" assistant_confidence_threshold {gen_config.assistant_confidence_threshold}"
+            logger.info(config_info)

         return EvaluatorCLS(
             base_model=base_model,
             gt_data=args.gt_data,
@@ -516,6 +563,8 @@ def create_evaluator(base_model, args):
             num_samples=args.num_samples,
             language=args.language,
             gen_answer_fn=gen_answer_fn,
+            generation_config=gen_config,
+            seqs_per_request=1,
             use_chat_template=use_chat_template,
             long_prompt=args.long_prompt,
         )
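The two options wired into `gen_config` here map onto the two candidate-generation strategies openvino_genai exposes for speculative decoding: a fixed draft length (`num_assistant_tokens`) or a confidence cutoff (`assistant_confidence_threshold`), which are normally used one at a time. A small illustrative sketch follows; the values 5 and 0.4 are examples, not WWB defaults:

```python
import openvino_genai

gen_config = openvino_genai.GenerationConfig()
gen_config.max_new_tokens = 128
gen_config.do_sample = False

# Static strategy: the draft model always proposes a fixed number of candidate tokens per step.
gen_config.num_assistant_tokens = 5

# Alternative strategy (typically used instead of num_assistant_tokens): the draft model keeps
# proposing tokens only while its own confidence stays above this threshold.
# gen_config.assistant_confidence_threshold = 0.4
```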
@@ -715,11 +764,21 @@ def main():
             kwargs["alphas"] = args.alphas
         else:
             kwargs["alphas"] = [1.0] * len(args.adapters)

     kwargs["empty_adapters"] = args.empty_adapters
     kwargs["embeds_pooling"] = args.embeds_pooling_type
     kwargs["embeds_normalize"] = args.embeds_normalize
     kwargs["embeds_padding_side"] = args.embeds_padding_side

+    if args.draft_model is not None:
+        kwargs["draft_model"] = args.draft_model
+        if args.draft_device is not None:
+            kwargs["draft_device"] = args.draft_device
+        if args.draft_cb_config is not None:
+            kwargs["draft_cb_config"] = read_cb_config(args.draft_cb_config)
+        else:
+            kwargs["draft_cb_config"] = None

     if args.gt_data and os.path.exists(args.gt_data):
         evaluator = create_evaluator(None, args)
     else:
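Downstream, these kwargs are expected to end up in an openvino_genai pipeline built with a draft model. A hedged sketch of that wiring, using hypothetical model folders and the `draft_model()` helper from openvino_genai's speculative decoding API (the actual loader code is outside this diff):

```python
import openvino_genai

# Assumptions: target_dir and draft_dir are hypothetical OpenVINO IR folders;
# the draft_model= kwarg enables the speculative decoding pipeline.
target_dir = "./target_model_ir"
draft_dir = "./draft_model_ir"

draft = openvino_genai.draft_model(draft_dir, "CPU")
pipe = openvino_genai.LLMPipeline(target_dir, "CPU", draft_model=draft)

gen_config = openvino_genai.GenerationConfig()
gen_config.max_new_tokens = 128
gen_config.num_assistant_tokens = 5  # draft proposes 5 candidate tokens per step
print(pipe.generate("Hello", gen_config))
```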