diff --git a/tools/who_what_benchmark/whowhatbench/model_loaders.py b/tools/who_what_benchmark/whowhatbench/model_loaders.py
index 160ae28bd4..09b6f39c0c 100644
--- a/tools/who_what_benchmark/whowhatbench/model_loaders.py
+++ b/tools/who_what_benchmark/whowhatbench/model_loaders.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 import logging
 import json
 import torch
@@ -76,6 +77,17 @@ def load_text_genai_pipeline(model_dir, device="CPU", ov_config=None, **kwargs):
             ov_adapter = openvino_genai.Adapter(adapter)
             adapter_config.add(ov_adapter, alpha)
 
+    draft_model_path = kwargs.get("draft_model", '')
+    if draft_model_path:
+        if not Path(draft_model_path).exists():
+            raise RuntimeError(f"Error: Draft model path does not exist: {draft_model_path}")
+        draft_device = kwargs.get("draft_device", None) or device
+        draft_model_load_kwargs = (
+            {"scheduler_config": get_scheduler_config_genai(kwargs["draft_cb_config"])}
+            if kwargs["draft_cb_config"] is not None else {}
+        )
+        ov_config["draft_model"] = openvino_genai.draft_model(draft_model_path, draft_device.upper(), **draft_model_load_kwargs)
+
     is_continuous_batching = kwargs.get("cb_config", None) is not None
 
     if is_continuous_batching:
diff --git a/tools/who_what_benchmark/whowhatbench/text_evaluator.py b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
index 72a1dcc632..e033bd2be0 100644
--- a/tools/who_what_benchmark/whowhatbench/text_evaluator.py
+++ b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -36,7 +36,9 @@ def __init__(
         seqs_per_request=None,
         use_chat_template=None,
         long_prompt=False,
-        empty_adapters=False
+        empty_adapters=False,
+        num_assistant_tokens=0,
+        assistant_confidence_threshold=0.0
     ) -> None:
         assert (
             base_model is not None or gt_data is not None
@@ -53,6 +55,8 @@ def __init__(
         self.seqs_per_request = seqs_per_request
         self.generation_fn = gen_answer_fn
         self.use_chat_template = use_chat_template
+        self.num_assistant_tokens = num_assistant_tokens
+        self.assistant_confidence_threshold = assistant_confidence_threshold
         if self.generation_config is not None:
             assert self.seqs_per_request is not None
         self.empty_adapters = empty_adapters
@@ -135,7 +139,8 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
         return res
 
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
-        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False, empty_adapters=False):
+        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False, empty_adapters=False,
+                               num_assistant_tokens=0, assistant_confidence_threshold=0.0):
             is_awq = getattr(model, "is_awq", None) is not None
             device = "cpu"
             if hasattr(model, "device"):
@@ -196,7 +201,9 @@ def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question,
                         self.max_new_tokens,
                         self._crop_question,
                         self.use_chat_template,
-                        empty_adapters=self.empty_adapters
+                        self.empty_adapters,
+                        self.num_assistant_tokens,
+                        self.assistant_confidence_threshold
                     )
                 )
             else:
diff --git a/tools/who_what_benchmark/whowhatbench/wwb.py b/tools/who_what_benchmark/whowhatbench/wwb.py
index 709ae91575..432090c5a6 100644
--- a/tools/who_what_benchmark/whowhatbench/wwb.py
+++ b/tools/who_what_benchmark/whowhatbench/wwb.py
@@ -232,6 +232,35 @@ def parse_args():
         "If the base/target model is a local path, gguf-file should be just the filename (e.g., 'model.gguf'). "
         "If the base/target model is a HuggingFace model ID, gguf-file should be a relative path.",
     )
+    parser.add_argument(
+        "--draft-model",
+        default=None,
+        help="Path to draft model folder including IR files for Speculative decoding generation.",
+    )
+    parser.add_argument(
+        "--draft-device",
+        type=str,
+        default=None,
+        help="Inference device for Speculative decoding of draft model, e.g. 'CPU', 'GPU'.",
+    )
+    parser.add_argument(
+        "--draft-cb-config",
+        type=str,
+        default=None,
+        help="Path to file with Continuous Batching Scheduler settings or dict for Speculative decoding of draft model",
+    )
+    parser.add_argument(
+        "--num-assistant-tokens",
+        type=int,
+        default=None,
+        help="Config option num_assistant_tokens for Speculative decoding and Prompt Lookup decoding.",
+    )
+    parser.add_argument(
+        "--assistant-confidence-threshold",
+        type=float,
+        default=None,
+        help="Config option assistant_confidence_threshold for Speculative decoding.",
+    )
 
     return parser.parse_args()
 
@@ -387,16 +416,26 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str:
     return "".join(output)
 
 
-def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False, empty_adapters=False):
+def genai_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False, empty_adapters=False,
+                   num_assistant_tokens=0, assistant_confidence_threshold=0.0):
     kwargs = {}
     if empty_adapters:
         import openvino_genai
         kwargs["adapters"] = openvino_genai.AdapterConfig()
-    return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens, apply_chat_template=use_chat_template, **kwargs)
+    return model.generate(
+        question,
+        do_sample=False,
+        max_new_tokens=max_new_tokens,
+        apply_chat_template=use_chat_template,
+        num_assistant_tokens=num_assistant_tokens,
+        assistant_confidence_threshold=assistant_confidence_threshold,
+        **kwargs,
+    )
 
 
-def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False):
+def llamacpp_gen_text(model, tokenizer, question, max_new_tokens, skip_question, use_chat_template=False, num_assistant_tokens=0,
+                      assistant_confidence_threshold=0.0):
     if use_chat_template:
         output = model.create_chat_completion(messages=[{"role": "user", "content": question}], max_tokens=max_new_tokens, temperature=0.0)
         text = output["choices"][0]["message"]["content"]
@@ -523,6 +562,14 @@ def create_evaluator(base_model, args):
             gen_answer_fn=gen_answer_fn,
             use_chat_template=use_chat_template,
             long_prompt=args.long_prompt,
+            num_assistant_tokens=(
+                int(args.num_assistant_tokens)
+                if args.num_assistant_tokens is not None else 0
+            ),
+            assistant_confidence_threshold=(
+                float(args.assistant_confidence_threshold)
+                if args.assistant_confidence_threshold is not None else 0.0
+            ),
         )
     elif task == "text-to-image":
         return EvaluatorCLS(
@@ -725,6 +772,14 @@ def main():
         kwargs["embeds_normalize"] = args.embeds_normalize
         kwargs["embeds_padding_side"] = args.embeds_padding_side
 
+    if args.draft_model is not None:
+        kwargs["draft_model"] = args.draft_model
+        kwargs["draft_device"] = args.draft_device
+        draft_cb_config = None
+        if args.draft_cb_config is not None:
+            draft_cb_config = read_cb_config(args.draft_cb_config)
+        kwargs["draft_cb_config"] = draft_cb_config
+
    if args.gt_data and os.path.exists(args.gt_data):
        evaluator = create_evaluator(None, args)
    else: