diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d0af2c681..ba7fcbe47 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.5.2
+    rev: v0.12.7
     hooks:
       # Run the linter.
       - id: ruff
diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index fd7ef03ff..cf9cbcacc 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -10,7 +10,7 @@
 from collections import deque
 from dataclasses import dataclass
 from time import perf_counter
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import transformers
@@ -18,7 +18,9 @@
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils import padding_check_and_fix
+from QEfficient.utils.constants import Constants
 from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.sampler_utils import validate_sampler_inputs
 
 
 @dataclass
@@ -322,6 +324,9 @@ def cloud_ai_100_exec_kv(
     automation=False,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
     is_tlm: bool = False,
+    include_sampler: bool = False,
+    return_pdfs: bool = False,
+    sampling_params: Optional[Dict[str, Any]] = None,
 ):
     """
     This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +347,15 @@ def cloud_ai_100_exec_kv(
         :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
         :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
         :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
+        :include_sampler (bool, default=False): Enable/Disable sampling of next tokens.
+        :return_pdfs (bool, default=False): Return probability distributions along with the sampled
+            next tokens. For a Speculative Decoding Target Language Model this is always True.
+            Set `return_pdfs`=True for a Speculative Decoding Draft Language Model and
+            `return_pdfs`=False for a regular model.
+        :sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
+            The dictionary should contain the following keys:
+            `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
+            `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).
 
     Returns:
         :CloudAI100ExecInfo: Object holding execution output and performance details.
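Reviewer note: a minimal caller-side sketch of the new `sampling_params` contract. The key names and `(batch_size, 1)` array shapes come from the docstring above; the model name and compile settings are borrowed from the tests added later in this PR, and the concrete penalty/temperature values are arbitrary illustrations, not recommended defaults.

```python
import numpy as np

from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.utils import load_hf_tokenizer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # small model used by the new sampler tests
bs = 2  # one row of sampling parameters per sequence in the batch

# Compile a QPC that carries the on-device sampler nodes.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
    model_name,
    continuous_batching=True,
    qaic_config={"include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512},
)
qeff_model.compile(
    prefill_seq_len=32,
    ctx_len=256,
    full_batch_size=bs,
    num_devices=1,
    num_cores=16,
)

# Every sampler input is a (batch_size, 1) numpy array, as described in the docstring.
sampling_params = {
    "repetition_penalties": np.full((bs, 1), 1.2, dtype=np.float32),
    "presence_penalties": np.full((bs, 1), 0.5, dtype=np.float32),
    "temperatures": np.full((bs, 1), 0.7, dtype=np.float32),
    "top_ks": np.full((bs, 1), 50, dtype=np.int32),
    "top_ps": np.full((bs, 1), 0.9, dtype=np.float32),
    "min_ps": np.full((bs, 1), 0.0, dtype=np.float32),
    "random_numbers": np.full((bs, 1), 0.26, dtype=np.float32),
}

# generate() forwards the sampler kwargs down to cloud_ai_100_exec_kv.
exec_info = qeff_model.generate(
    tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=model_name),
    prompts=["Hello"] * bs,
    generation_len=20,
    include_sampler=True,
    return_pdfs=False,
    sampling_params=sampling_params,
)
```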
@@ -372,6 +386,9 @@ def cloud_ai_100_exec_kv( write_io_dir=write_io_dir, full_batch_size=full_batch_size, is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, ) if full_batch_size is None: exec_info = [ @@ -411,14 +428,24 @@ def __init__( enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: Optional[int] = None, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._ctx_len = ctx_len self._write_io_dir = write_io_dir self.is_tlm = is_tlm + self.return_pdfs = return_pdfs + self.sampling_params = sampling_params # Load QPC self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) + # Validate sampler inputs for On-Device Sampling + self.include_sampler = validate_sampler_inputs( + session_inputs=set(self._session.input_names), include_sampler=include_sampler + ) + # Fetch the variables from the QPC self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len() @@ -523,10 +550,17 @@ def _fetch_vocab_size( Returns: vocab_size: The vocabulary size fetched from the session's allowed shapes. """ + key = ( + "probs" + if self.include_sampler and self.return_pdfs + else "next_tokens" + if self.include_sampler + else "logits" + ) if self._session.allowed_shapes: - return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2] + return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2] - return self._session.bindings[self._session.binding_index_map["logits"]].dims[2] + return self._session.bindings[self._session.binding_index_map[key]].dims[2] def _fetch_generation_len(self, generation_len, max_gen_len): """ @@ -574,6 +608,13 @@ def prepare_decode_inputs(self): decode_inputs["position_ids"] = self.decode_pos_ids if self.batch_index is not None: decode_inputs["batch_index"] = self.batch_index + if self.include_sampler: + decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] + for op in Constants.SAMPLER_OPS: + if self.batch_index is not None: + decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()] + else: + decode_inputs[op] = self.sampling_params[op] if self._prompt_to_lora_id_mapping_decode: if self.full_batch_size: @@ -589,21 +630,24 @@ def prepare_decode_inputs(self): def _fetch_next_token_id(self, outputs): """ - Fetches the next token ID from the model's output logits. - The method identifies the token with the highest probability using argmax along the last dimension. + Fetches the next token ID from the model's output. + Args: - outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size). + outputs (dict): A dictionary containing the model's output. Returns: numpy.ndarray: An array of the next token IDs for each sequence in the batch. 
""" - logits = outputs["logits"] - if len(logits.shape) == 2: - logits = np.expand_dims(logits, 1) - - # Get output token - next_token_id = logits.argmax(2) - return next_token_id + if self.include_sampler: + if self.return_pdfs: + return outputs["probs"].argmax(2) + else: + return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1]) + else: + logits = outputs["logits"] + if len(logits.shape) == 2: + logits = np.expand_dims(logits, 1) + return logits.argmax(2) def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length): """ @@ -673,6 +717,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): + """ + Sets the sizes of the output buffers. + + Args: + batch_size (int): The batch size. + """ + if self.include_sampler: + if self.return_pdfs: + probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) + self._session.set_buffers({"probs": probs_out_placeholder}) + next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64) + self._session.set_buffers({"next_tokens": next_tokens_out_placeholder}) + else: + logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) + self._session.set_buffers({"logits": logits_out_placeholder}) + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): """ Runs prefill for a given prompt and generation length. @@ -702,9 +763,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i max_gen_len = self._ctx_len - position_ids.max() generation_len = self._fetch_generation_len(generation_len, max_gen_len) - # Set the prefill logic buffer - logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32) - self._session.set_buffers({"logits": logits_out_placeholder}) + # Set the prefill output buffers + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) @@ -714,6 +774,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["batch_index"] = decode_batch_id if self.is_tlm: inputs["num_logits_to_keep"] = np.zeros((1, 1)) + if self.include_sampler: + inputs["last_accepted_output_tokens"] = inputs["input_ids"] + for op in Constants.SAMPLER_OPS: + if decode_batch_id is not None: + inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + else: + inputs[op] = self.sampling_params[op] if self._prompt_to_lora_id_mapping_prefill: if self.full_batch_size: @@ -732,6 +799,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i chunk_inputs["position_ids"] = inputs["position_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] + if self.include_sampler: + chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] outputs = self._session.run(chunk_inputs) if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) @@ -753,11 +822,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): """ - # Set logits placeholder for decode - 
logits_out_placeholder = np.zeros( - (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 + # Set output placeholders for decode + self._set_output_buffers( + batch_size=self.full_batch_size, + sequence_length=self._decode_seq_len, ) - self._session.set_buffers({"logits": logits_out_placeholder}) + # Generate flag for tracking progress for each batch ID current_decode_ongoing = np.full((self.full_batch_size, 1), True) @@ -775,10 +845,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): outputs = self._session.run(decode_inputs) # Prepare inputs for next iteration - logits = outputs["logits"] - if len(logits.shape) == 2: - logits = np.expand_dims(logits, 1) - next_token_id = logits.argmax(2) + next_token_id = self._fetch_next_token_id(outputs) for decode_batch_id in range(self.full_batch_size): if ( @@ -800,7 +867,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1) generated_id_current_index[decode_batch_id] = 1 - self._session.set_buffers({"logits": logits_out_placeholder}) + self._set_output_buffers( + batch_size=self.full_batch_size, + sequence_length=self._decode_seq_len, + ) decode_pause_time += perf_counter() - start if self._prompt_to_lora_id_mapping_decode: @@ -817,6 +887,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = ( next_token_id[decode_batch_id, -1] ) + if self.include_sampler: + decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] generated_id_current_index[decode_batch_id] += 1 @@ -852,10 +924,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform self._write_io_dir = None # Prepare inputs for next iteration - decode_inputs["input_ids"] = outputs["logits"].argmax(2) + decode_inputs["input_ids"] = self._fetch_next_token_id(outputs) decode_inputs["position_ids"][:, -1] += 1 self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id + if self.include_sampler: + decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] if finished_sequences.all(): break @@ -905,9 +979,22 @@ def __init__( enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: bool = False, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._qaic_model = QEffTextGenerationBase( - tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm + tokenizer=tokenizer, + qpc_path=qpc_path, + full_batch_size=full_batch_size, + ctx_len=ctx_len, + device_id=device_id, + enable_debug_logs=enable_debug_logs, + write_io_dir=write_io_dir, + is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b3d27f3a5..cfb17b64a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1373,6 +1373,7 @@ def __init__( self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs) if self.is_tlm: 
self.model.qaic_config["return_pdfs"] = True + self.hash_params["qaic_config"] = self.model.qaic_config # Explicitly add `qaic_config` to model hash @property def model_name(self) -> str: @@ -1827,6 +1828,7 @@ def generate( device_id=device_id, generation_len=generation_len, is_tlm=self.is_tlm, + **kwargs, ) else: raise NotImplementedError("Only AI_100 runtime is supported right now via generate API") diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 6bcabf29a..96846e712 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -193,6 +193,9 @@ def sampler_forward( batch_size, spec_length, vocab_size = logits.shape logits = logits.reshape(-1, vocab_size) # Reshape tensor to 2D + if batch_index is None: # Regular model execution + batch_index = torch.arange(batch_size).view(-1, 1) + batch_index_reshaped = batch_index.view(-1) # Prefill past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path( diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index cc52658c6..22a11b88f 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -129,6 +129,16 @@ class Constants: MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS + SAMPLER_OPS = { + "repetition_penalties", + "presence_penalties", + "temperatures", + "top_ks", + "top_ps", + "min_ps", + "random_numbers", + } + SAMPLER_INPUTS = SAMPLER_OPS | {"last_accepted_output_tokens"} SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version. SDK_PLATFORM_XML = ( "/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version. diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py new file mode 100644 index 000000000..6fb1b326f --- /dev/null +++ b/QEfficient/utils/sampler_utils.py @@ -0,0 +1,58 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import Optional, Set + +from QEfficient.utils.constants import Constants +from QEfficient.utils.logging_utils import logger + + +def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = None) -> bool: + """ + Validates whether the `QAICInferenceSession` inputs match inputs required for on-device sampling. + + Mandatory Args: + session_inputs (set[str]): Set of input names from `QAICInferenceSession`. + + Optional Args: + include_sampler (bool, default=None): Whether the user explicitly requested sampler support. + + Returns: + True if sampler is supported, False otherwise. + + Raises: + ValueError if partial support is detected or if user intent conflicts with QPC capabilities. 
+ """ + + sampler_inputs = Constants.SAMPLER_INPUTS + count = len(sampler_inputs & session_inputs) + + session_includes_sampler = True + if count == 0: + session_includes_sampler = False + elif count < len(sampler_inputs): + session_includes_sampler = False + raise ValueError( + f"The provided QPC does not have the required number of inputs to run sampling " + f"on the QAIC device (only {count}/{len(sampler_inputs)} inputs provided). Partial " + "sampling support is not available. Please check the QPC and try again." + ) + + # Post-validation consistency checks + if include_sampler and not session_includes_sampler: + logger.warning( + "User entered `include_sampler`=True. But the provided QPC is not compiled " + "to run sampling on the QAIC device. Falling back to the PyTorch backend." + ) + elif (include_sampler is None or not include_sampler) and session_includes_sampler: + raise ValueError( + "The provided QPC is compiled to run sampling on the QAIC device. " + "But the user did not enter `include_sampler`=True. Please make sure the input " + "is specified correctly." + ) + + return session_includes_sampler diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 233fb491a..14cdababe 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -26,7 +26,8 @@ To achieve this, we have 2 levels of APIs, with different levels of abstraction. | [Vision Language Model](QEFFAutoModelForImageTextToText) | Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text_inference.py) for more **details**. | | [Speech Sequence to Sequence Model](QEFFAutoModelForSpeechSeq2Seq) | Provides support for the QEFFAutoModelForSpeechSeq2Seq Facilitates speech-to-text sequence models. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/speech_to_text/run_whisper_speech_to_text.py) for more **details**. | | Support for FP8 Execution | Enables execution with FP8 precision, significantly improving performance and reducing memory usage for computational tasks. | -| Prefill caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. | +| Prefix caching | Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. | +| On Device Sampling | Enables sampling operations to be executed directly on the QAIC device rather than the host CPU for QEffForCausalLM models. This enhancement significantly reduces host-device communication overhead and improves inference throughput and scalability. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/on_device_sampling.py) for more **details**. | |Prompt-Lookup Decoding | Speeds up text generation by using overlapping parts of the input prompt and the generated text, making the process faster without losing quality. Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/pld_spd_inference.py) for more **details**.| | [PEFT LoRA support](QEffAutoPeftModelForCausalLM) | Enables parameter-efficient fine-tuning using low-rank adaptation techniques, reducing the computational and memory requirements for fine-tuning large models. 
Refer [sample script](https://github.com/quic/efficient-transformers/blob/main/examples/peft_models.py) for more **details**. | | [QNN support](#qnn-compilation) | Enables compilation using QNN SDK, making Qeff adaptable for various backends in the future. | diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py new file mode 100644 index 000000000..00d8c2430 --- /dev/null +++ b/examples/on_device_sampling.py @@ -0,0 +1,267 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import argparse +import re +from pprint import pprint + +import numpy as np + +from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM +from QEfficient.utils import load_hf_tokenizer + + +def main(args, **kwargs): + pprint(args.__dict__) + + # Get sampling inputs + include_sampler = None + return_pdfs = None + max_top_k_ids = None + sampling_params = None + bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size + if args.override_qaic_config is not None: + include_sampler = args.override_qaic_config.get("aic_include_sampler", None) == "true" + if include_sampler is not None: + return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" + max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) + sampling_params = { + "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), + "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), + # "frequency_penalties": np.array(args.frequency_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), + "temperatures": np.array(args.temperature, dtype=np.float32).repeat(bs).reshape(-1, 1), + "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1), + "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1), + "min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1), + "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1), + } + qaic_config = { + k: v + for k, v in { + "include_sampler": include_sampler, + "return_pdfs": return_pdfs, + "max_top_k_ids": max_top_k_ids, + }.items() + if v is not None + } + print("qaic_config:") + pprint(qaic_config) + print("sampling_params:") + pprint(sampling_params) + + # Load model with On Device Sampler enabled + qeff_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=args.model_name, + continuous_batching=args.full_batch_size is not None, + qaic_config=qaic_config, + ) + print(f"{args.model_name} optimized for AI 100 \n", qeff_model) + + # Compile the model for inference + generated_qpc_path = qeff_model.compile( + prefill_seq_len=args.prompt_len, + ctx_len=args.ctx_len, + batch_size=args.batch_size, + full_batch_size=args.full_batch_size, + num_cores=args.num_cores, + num_devices=(0 if args.device_group is None else len(args.device_group)), + mxfp6_matmul=args.mxfp6, + mxint8_kv_cache=args.mxint8, + num_speculative_tokens=0, + **kwargs, + ) + print(f"Generated QPC file path: {generated_qpc_path}") + + # Generate texts from prompts + if not args.prompt: + args.prompt = [ + "Hi", + ] * bs + qeff_model.generate( + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=args.model_name), + prompts=args.prompt, + 
prompts_txt_file_path=args.prompts_txt_file_path, + device_id=args.device_group, + generation_len=args.generation_len, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, + ) + + +if __name__ == "__main__": + """ + Example usage: + 1. For continuous batching: + python3.10 examples/on_device_sampling.py \ + --model-name 'meta-llama/Llama-3.1-8B' \ + --prompt-len 128 \ + --ctx-len 256 \ + --generation-len 20 \ + --full-batch-size 2 \ + --device-group [0,1,2,3] \ + --num-cores 16 \ + --mxint8-kv-cache \ + --mxfp6-matmul \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --repetition-penalty 1.9 \ + --presence-penalty 0.8 \ + --temperature 0.67 \ + --top-k 54720 \ + --top-p 0.89 \ + --min-p 0.6 \ + --random-number 0.26 + + 2. For non-continuous batching: + python3.10 examples/on_device_sampling.py \ + --model-name 'meta-llama/Llama-3.1-8B' \ + --prompt-len 128 \ + --ctx-len 256 \ + --generation-len 20 \ + --batch-size 2 \ + --device-group [0,1,2,3] \ + --num-cores 16 \ + --mxint8-kv-cache \ + --mxfp6-matmul \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --repetition-penalty 1.9 \ + --presence-penalty 0.8 \ + --temperature 0.67 \ + --top-k 54720 \ + --top-p 0.89 \ + --min-p 0.6 \ + --random-number 0.26 + """ + + parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling") + parser.add_argument( + "--model-name", "--model_name", required=True, default="meta-llama/Llama-3.1-8B", help="HF Model card name/id" + ) + parser.add_argument("--batch-size", "--batch_size", type=int, default=1, help="Batch size for text generation") + parser.add_argument( + "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation." + ) + parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.") + parser.add_argument( + "--mxfp6", + "--mxfp6_matmul", + "--mxfp6-matmul", + action="store_true", + help="Compress constant MatMul weights to MXFP6 E2M3, default is no compression", + ) + parser.add_argument( + "--mxint8", + "--mxint8_kv_cache", + "--mxint8-kv-cache", + action="store_true", + help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False", + ) + parser.add_argument( + "--num_cores", "--num-cores", type=int, required=True, help="Number of cores to compile on Cloud AI 100" + ) + parser.add_argument( + "--device_group", + "--device-group", + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + help="Cloud AI 100 device ids (comma-separated) e.g. 
[0,1]", + ) + parser.add_argument( + "--prompt", + type=lambda prompt: prompt.split("|"), + help="Input prompt, if executing for batch size>1, pass input prompts in single string but separate with pipe (|) symbol", + ) + parser.add_argument( + "--prompts_txt_file_path", + "--prompts-txt-file-path", + type=str, + help="File path for taking input prompts from txt file, sample prompts.txt file present in examples folder", + ) + parser.add_argument("--generation_len", "--generation-len", type=int, help="Number of tokens to generate") + + parser.add_argument( + "--full_batch_size", + "--full-batch-size", + type=int, + default=None, + help="Set full batch size to enable continuous batching mode, default is None", + ) + parser.add_argument( + "--override-qaic-config", + type=lambda configs: { + str(value[0]): value[1] if len(value) > 1 else True + for value in (re.split(r"[:=]", config.strip()) for config in re.split(r"[ ]+", configs.strip())) + }, + default=None, + help="override or set qaic device configuration.", + ) + + # ---On Device Sampling--- + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument( + "--repetition-penalty", + type=float, + default=None, + help="Sampling parameter that penalizes new tokens based on whether they appear in the " + "prompt and the generated text so far. Values > 1 encourage the model to use new tokens, " + "while values < 1 encourage the model to repeat tokens.", + ) + sampling_group.add_argument( + "--presence-penalty", + type=float, + default=None, + help="Sampling parameter that penalizes new tokens based on whether they appear in the " + "generated text so far. Values > 0 encourage the model to use new tokens, while values < " + "0 encourage the model to repeat tokens.", + ) + sampling_group.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling parameter that controls the randomness of the sampling. Lower" + "values make the model more deterministic, while higher values make" + "the model more random. Zero means greedy sampling.", + ) + sampling_group.add_argument( + "--top-k", + type=int, + default=None, + help="Sampling parameter that controls the number of top tokens to consider. Set to -1 to consider all tokens.", + ) + sampling_group.add_argument( + "--top-p", + type=float, + default=None, + help="Sampling parameter that controls the cumulative probability of the top tokens to " + "consider. Must be in (0, 1]. Set to 1.0 to consider all tokens.", + ) + sampling_group.add_argument( + "--min-p", + type=float, + default=None, + help="Sampling parameter that represents the minumum probability for a token to be " + "considered, relative to the probability of the most likely token. Must be in [0, 1]. " + "Set to 0.0 to disable this.", + ) + sampling_group.add_argument( + "--random-number", + type=float, + default=None, + help="Sampling parameter that represents the random seed to use for random sampling. 
Must be in [-1, 1].", + ) + args, compiler_options = parser.parse_known_args() + + compiler_options_dict = {} + for i in range(0, len(compiler_options)): + if compiler_options[i].startswith("--"): + key = compiler_options[i].lstrip("-").replace("-", "_") + value = ( + compiler_options[i + 1] + if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-") + else True + ) + compiler_options_dict[key] = value + + main(args, **compiler_options_dict) diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py new file mode 100644 index 000000000..3173af126 --- /dev/null +++ b/tests/transformers/sampler/test_sampler.py @@ -0,0 +1,360 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List + +import numpy as np +import pytest + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants + +configs = [ + pytest.param( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # model + Constants.INPUT_STR * 4, # prompts + 32, # prefill_seq_len + 256, # ctx_len + 20, # generation_len + 4, # full_batch_size + 1, # spec_length + ), +] + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + configs, +) +def test_sampler_transform( + model: str, + prompts: List[str], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + spec_length: int, +): + """ + Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the + sampling of next tokens at the device (instead of the host) and returns the + next tokens and/or probability distributions. 
+ """ + # Export and compile QEfficient models + model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + }, + ) + model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": False, + "return_pdfs": False, + }, + ) + model_w_sampler_qpc_path: str = model_w_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_wo_sampler_qpc_path: str = model_wo_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Init qaic session + model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) + model_wo_sampler_session = QAICInferenceSession(model_wo_sampler_qpc_path) + + # Skip inputs/outputs buffers + model_w_sampler_session.skip_buffers(set([x for x in model_w_sampler_session.input_names if x.startswith("past_")])) + model_w_sampler_session.skip_buffers( + set([x for x in model_w_sampler_session.output_names if x.endswith("_RetainedState")]) + ) + model_wo_sampler_session.skip_buffers( + set([x for x in model_wo_sampler_session.input_names if x.startswith("past_")]) + ) + model_wo_sampler_session.skip_buffers( + set([x for x in model_wo_sampler_session.output_names if x.endswith("_RetainedState")]) + ) + + # Validate sampler inputs + sampler_inputs = Constants.SAMPLER_INPUTS + for input_name in sampler_inputs: + assert input_name in model_w_sampler_session.input_names, ( + f"Sampler input {input_name} not found in QPC compiled with On Device Sampler" + ) + assert input_name not in model_wo_sampler_session.input_names, ( + f"Sampler input {input_name} found in QPC compiled without On Device Sampler" + ) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + configs, +) +def test_greedy_sampling( + model: str, + prompts: List[str], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + spec_length: int, +): + """ + Test greedy sampling with QPC compiled with and without On Device Sampling. 
+ """ + # Export and compile QEfficient models + model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + }, + ) + model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": False, + "return_pdfs": False, + }, + ) + model_w_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_wo_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Generate texts from prompts + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + model_w_sampler_exec_info = model_w_sampler.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + sampling_params={ + "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + }, + ) + model_wo_sampler_exec_info = model_wo_sampler.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=False, + return_pdfs=False, + sampling_params=None, + ) + + # Compare generated texts and ids + assert model_w_sampler_exec_info.generated_texts == model_wo_sampler_exec_info.generated_texts, ( + "Generated texts do not match" + ) + assert (model_w_sampler_exec_info.generated_ids == model_wo_sampler_exec_info.generated_ids).all(), ( + "Generated ids do not match" + ) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + configs, +) +def test_random_sampling( + model: str, + prompts: List[str], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + spec_length: int, +): + """ + Test random sampling with QPC compiled with and without On Device Sampling. 
+ """ + # Export and compile QEfficient models + model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + }, + ) + model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": False, + "return_pdfs": False, + }, + ) + model_w_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_wo_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length - 1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Generate texts from prompts + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + model_w_sampler_exec_info = model_w_sampler.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + sampling_params={ + "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + }, + ) + model_wo_sampler_exec_info = model_wo_sampler.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=False, + return_pdfs=False, + sampling_params=None, + ) + + # Compare generated texts + golden_texts = { + "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have either as", + "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ", + } + golden_ids = { + "w_sampler": [ + [ + 21380, + 322, + 590, + 25448, + 2927, + 29892, + 19963, + 2654, + 29879, + 470, + 3708, + 2701, + 313, + 29902, + 508, + 30010, + 29873, + 505, + 2845, + 408, + ] + ], + "wo_sampler": [ + [ + 2259, + 7075, + 322, + 306, + 626, + 263, + 7047, + 22055, + 29889, + 306, + 505, + 1063, + 1985, + 297, + 278, + 13661, + 363, + 278, + 4940, + 29871, + ] + ], + } + for i in range(full_batch_size): + assert ( + tokenizer.decode(model_w_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["w_sampler"] + ), "Sampler generated texts does not match" + assert (model_w_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["w_sampler"]).all(), ( + "Sampler generated ids do not match" + ) + assert ( + tokenizer.decode(model_wo_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["wo_sampler"] + ), "Without sampler generated texts does not match" + assert (model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"]).all(), ( + "Without sampler generated ids do not match" + )