
Unit Tests for On Device Sampling #463


Open
quic-sanising wants to merge 28 commits into main from ods-unit-tests
Commits (28)
8417d8f
Add sampler transform test
quic-sanising Jun 18, 2025
27d8dd5
Merge branch 'main' into ods-unit-tests
Jun 30, 2025
067f9b5
Add example script
Jun 30, 2025
931860f
Update docs
Jun 30, 2025
79b6c95
Enable On Device Sampling for _continuous_batching_execution()
Jun 30, 2025
75eac30
Disable On Device Sampling for _regular_model_execution()
Jul 1, 2025
eb6e2eb
Use same sampling parameters for each sequence in a batch
Jul 1, 2025
48b35e3
Enable On Device Sampling for _regular_model_execution()
Jul 2, 2025
c83a631
Add test for greedy sampling
Jul 3, 2025
f698a24
Add test for random sampling
Jul 3, 2025
7b34a07
Remove else block
Jul 3, 2025
5fa7269
Merge branch 'main' into ods-unit-tests
Jul 3, 2025
0ee201a
Reformat code
Jul 3, 2025
c074768
Merge branch 'quic:main' into ods-unit-tests
quic-sanising Jul 24, 2025
115505e
Move sampling operations, inputs, and validation functions to utils
Aug 4, 2025
3ac7503
Change model to TinyLlama
Aug 4, 2025
02669e0
Add header
Aug 4, 2025
137cc4a
Reformat code
Aug 4, 2025
54a926a
Merge branch 'quic:main' into ods-unit-tests
quic-sanising Aug 4, 2025
6acf446
Update linter
Aug 4, 2025
6083f5b
Merge branch 'quic:main' into ods-unit-tests
quic-sanising Aug 6, 2025
c2d7e83
Remove device_id
Aug 6, 2025
1069109
Remove redundant line
Aug 6, 2025
7d67132
Merge branch 'quic:main' into ods-unit-tests
quic-sanising Aug 19, 2025
0e3f257
Merge branch 'main' into ods-unit-tests
Aug 20, 2025
908e67e
Remove redundant reinitialization of output buffers
Aug 20, 2025
a8e55da
Merge branch 'main' into ods-unit-tests
Aug 22, 2025
f3f89d3
Add qaic_config to model hash
Aug 22, 2025
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.2
rev: v0.12.7
hooks:
# Run the linter.
- id: ruff
141 changes: 114 additions & 27 deletions QEfficient/generation/text_generation_inference.py
@@ -10,15 +10,17 @@
from collections import deque
from dataclasses import dataclass
from time import perf_counter
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import transformers
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import padding_check_and_fix
from QEfficient.utils.constants import Constants
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.sampler_utils import validate_sampler_inputs


@dataclass
@@ -322,6 +324,9 @@ def cloud_ai_100_exec_kv(
automation=False,
prompt_to_lora_id_mapping: Optional[List[int]] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
):
"""
This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +347,15 @@
:Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
:automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
:prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
:include_sampler (bool, default=False): Enable/Disable sampling of next tokens.
:return_pdfs (bool, default=False): Return probability distributions along with the sampled
next tokens. For a Speculative Decoding Target Language Model, `return_pdfs` is always True.
Otherwise, set it to True for a Speculative Decoding Draft Language Model and to False for a
regular model.
:sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
The dictionary should contain the following keys:
`repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
`min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).

Returns:
:CloudAI100ExecInfo: Object holding execution output and performance details.
@@ -372,6 +386,9 @@ def cloud_ai_100_exec_kv(
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
sampling_params=sampling_params,
)
if full_batch_size is None:
exec_info = [
@@ -411,14 +428,24 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: Optional[int] = None,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._ctx_len = ctx_len
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
self.sampling_params = sampling_params

# Load QPC
self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

# Validate sampler inputs for On-Device Sampling
self.include_sampler = validate_sampler_inputs(
session_inputs=set(self._session.input_names), include_sampler=include_sampler
)

# Fetch the variables from the QPC
self._vocab_size = self._fetch_vocab_size() # Fetch Vocab size
self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
@@ -523,10 +550,17 @@ def _fetch_vocab_size(
Returns:
vocab_size: The vocabulary size fetched from the session's allowed shapes.
"""
key = (
"probs"
if self.include_sampler and self.return_pdfs
else "next_tokens"
if self.include_sampler
else "logits"
)
if self._session.allowed_shapes:
return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2]
return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2]

return self._session.bindings[self._session.binding_index_map["logits"]].dims[2]
return self._session.bindings[self._session.binding_index_map[key]].dims[2]

def _fetch_generation_len(self, generation_len, max_gen_len):
"""
@@ -574,6 +608,13 @@ def prepare_decode_inputs(self):
decode_inputs["position_ids"] = self.decode_pos_ids
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
if self.batch_index is not None:
decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
else:
decode_inputs[op] = self.sampling_params[op]

if self._prompt_to_lora_id_mapping_decode:
if self.full_batch_size:
@@ -589,21 +630,24 @@

def _fetch_next_token_id(self, outputs):
"""
Fetches the next token ID from the model's output logits.
The method identifies the token with the highest probability using argmax along the last dimension.
Fetches the next token ID from the model's output.

Args:
outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size).
outputs (dict): A dictionary containing the model's output.

Returns:
numpy.ndarray: An array of the next token IDs for each sequence in the batch.
"""
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)

# Get output token
next_token_id = logits.argmax(2)
return next_token_id
if self.include_sampler:
if self.return_pdfs:
return outputs["probs"].argmax(2)
else:
return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])
else:
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
return logits.argmax(2)

def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length):
"""
@@ -673,6 +717,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):

_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1):
"""
Sets the sizes of the output buffers.

Args:
batch_size (int): The batch size.
sequence_length (int): The sequence length (number of output token positions) per sequence.
"""
if self.include_sampler:
if self.return_pdfs:
probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"probs": probs_out_placeholder})
next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)
self._session.set_buffers({"next_tokens": next_tokens_out_placeholder})
else:
logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"logits": logits_out_placeholder})

def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
"""
Runs prefill for a given prompt and generation length.
@@ -702,9 +763,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
max_gen_len = self._ctx_len - position_ids.max()
generation_len = self._fetch_generation_len(generation_len, max_gen_len)

# Set the prefill logic buffer
logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"logits": logits_out_placeholder})
# Set the prefill output buffers
self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)

inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
@@ -714,6 +774,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
inputs["batch_index"] = decode_batch_id
if self.is_tlm:
inputs["num_logits_to_keep"] = np.zeros((1, 1))
if self.include_sampler:
inputs["last_accepted_output_tokens"] = inputs["input_ids"]
for op in Constants.SAMPLER_OPS:
if decode_batch_id is not None:
inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
else:
inputs[op] = self.sampling_params[op]

if self._prompt_to_lora_id_mapping_prefill:
if self.full_batch_size:
@@ -732,6 +799,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
chunk_inputs["position_ids"] = inputs["position_ids"][
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
]
if self.include_sampler:
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
outputs = self._session.run(chunk_inputs)
if self._write_io_dir is not None:
write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -753,11 +822,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

"""

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
# Set output placeholders for decode
self._set_output_buffers(
batch_size=self.full_batch_size,
sequence_length=self._decode_seq_len,
)
self._session.set_buffers({"logits": logits_out_placeholder})

# Generate flag for tracking progress for each batch ID
current_decode_ongoing = np.full((self.full_batch_size, 1), True)

@@ -775,10 +845,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
outputs = self._session.run(decode_inputs)

# Prepare inputs for next iteration
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)
next_token_id = self._fetch_next_token_id(outputs)

for decode_batch_id in range(self.full_batch_size):
if (
@@ -800,7 +867,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1)
generated_id_current_index[decode_batch_id] = 1

self._session.set_buffers({"logits": logits_out_placeholder})
self._set_output_buffers(
batch_size=self.full_batch_size,
sequence_length=self._decode_seq_len,
)
decode_pause_time += perf_counter() - start

if self._prompt_to_lora_id_mapping_decode:
@@ -817,6 +887,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
next_token_id[decode_batch_id, -1]
)
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

generated_id_current_index[decode_batch_id] += 1

@@ -852,10 +924,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
self._write_io_dir = None

# Prepare inputs for next iteration
decode_inputs["input_ids"] = outputs["logits"].argmax(2)
decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
decode_inputs["position_ids"][:, -1] += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

if finished_sequences.all():
break
@@ -905,9 +979,22 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._qaic_model = QEffTextGenerationBase(
tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
tokenizer=tokenizer,
qpc_path=qpc_path,
full_batch_size=full_batch_size,
ctx_len=ctx_len,
device_id=device_id,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
is_tlm=is_tlm,
include_sampler=include_sampler,
return_pdfs=return_pdfs,
sampling_params=sampling_params,
)
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
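For reference, a minimal sketch of how the new sampling arguments could be passed to `cloud_ai_100_exec_kv`. The sampler keys and the (batch_size, 1) shapes follow the docstring above; the tokenizer/QPC paths, dtypes, and the non-sampling arguments are illustrative assumptions, not part of this PR.

# Sketch only: enabling On Device Sampling through cloud_ai_100_exec_kv.
# Values, dtypes, paths, and the non-sampling arguments are placeholders.
import numpy as np
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

batch_size = 1
sampling_params = {
    # One value per sequence slot, shape (batch_size, 1), as the docstring requires.
    "repetition_penalties": np.full((batch_size, 1), 1.1, dtype=np.float32),
    "presence_penalties": np.full((batch_size, 1), 0.0, dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 50, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.95, dtype=np.float32),
    "min_ps": np.full((batch_size, 1), 0.0, dtype=np.float32),
    "random_numbers": np.random.rand(batch_size, 1).astype(np.float32),
}

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # model name assumed
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/qpc",         # placeholder; the QPC must be compiled with the sampler
    prompt=["My name is"],
    generation_len=32,
    include_sampler=True,
    return_pdfs=False,               # regular (non-speculative) model
    sampling_params=sampling_params,
)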
2 changes: 2 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -1373,6 +1373,7 @@ def __init__(
self.model, transformed = SamplerTransform.apply(self.model, qaic_config, **kwargs)
if self.is_tlm:
self.model.qaic_config["return_pdfs"] = True
self.hash_params["qaic_config"] = self.model.qaic_config # Explicitly add `qaic_config` to model hash

@property
def model_name(self) -> str:
@@ -1827,6 +1828,7 @@ def generate(
device_id=device_id,
generation_len=generation_len,
is_tlm=self.is_tlm,
**kwargs,
)
else:
raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")
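For context, a hedged sketch of how these changes are exercised end to end: `qaic_config` requests the sampler transform at load time (as in `__init__` above), and `generate` now forwards extra keyword arguments such as `sampling_params` to the Cloud AI 100 execution path via `**kwargs`. Argument names for `from_pretrained` and `compile` outside this diff are assumptions.

# Sketch only: loading a model with On Device Sampling requested via qaic_config.
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

model_card = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # model used by the new tests (exact HF id assumed)
tokenizer = AutoTokenizer.from_pretrained(model_card)

model = QEFFAutoModelForCausalLM.from_pretrained(
    model_card,
    continuous_batching=True,
    qaic_config={
        "include_sampler": True,  # applies SamplerTransform in __init__
        "return_pdfs": False,     # forced to True when the model is a TLM (see diff above)
    },
)
model.compile(prefill_seq_len=32, ctx_len=128, full_batch_size=2, num_cores=16)

# generate() now forwards extra kwargs through **kwargs, so include_sampler, return_pdfs,
# and sampling_params (built as in the previous sketch) can be passed straight through:
#   model.generate(tokenizer=tokenizer, prompts=["Hello"], include_sampler=True,
#                  return_pdfs=False, sampling_params=sampling_params)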
3 changes: 3 additions & 0 deletions QEfficient/transformers/sampler/sampler.py
@@ -193,6 +193,9 @@ def sampler_forward(
batch_size, spec_length, vocab_size = logits.shape
logits = logits.reshape(-1, vocab_size) # Reshape tensor to 2D

if batch_index is None: # Regular model execution
batch_index = torch.arange(batch_size).view(-1, 1)

batch_index_reshaped = batch_index.view(-1)
# Prefill
past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path(
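To make the new fallback concrete: when `batch_index` is not provided (regular, non-continuous-batching execution), the sampler builds one index per row of the batch so each sequence addresses its own slot in the penalty buffers. A tiny sketch of what that produces:

# Sketch: the default batch_index constructed by sampler_forward when none is given.
import torch

batch_size = 3
batch_index = torch.arange(batch_size).view(-1, 1)  # tensor([[0], [1], [2]])
batch_index_reshaped = batch_index.view(-1)          # tensor([0, 1, 2])
# These indices then select each sequence's row in the repetition/presence
# penalty buffers, mirroring the continuous-batching path.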
10 changes: 10 additions & 0 deletions QEfficient/utils/constants.py
@@ -129,6 +129,16 @@ class Constants:
MAX_RETRIES = 10 # This constant will be used to set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
NUM_SPECULATIVE_TOKENS = 2
MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS
SAMPLER_OPS = {
"repetition_penalties",
"presence_penalties",
"temperatures",
"top_ks",
"top_ps",
"min_ps",
"random_numbers",
}
SAMPLER_INPUTS = SAMPLER_OPS | {"last_accepted_output_tokens"}
SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version.
SDK_PLATFORM_XML = (
"/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version.
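As a quick illustration of the contract implied by `SAMPLER_OPS`: each sampler input is a (batch_size, 1) array, and the generation code slices it per decode slot with `sampling_params[op][batch_index.flatten()]`. A small sketch (values are placeholders):

# Sketch: per-slot slicing of a (batch_size, 1) sampling parameter, matching
# sampling_params[op][batch_index.flatten()] in text_generation_inference.py.
import numpy as np

full_batch_size = 4
temperatures = np.full((full_batch_size, 1), 0.7, dtype=np.float32)  # one value per slot

batch_index = np.array([[2]])                   # slot currently being prefilled
per_slot = temperatures[batch_index.flatten()]  # shape (1, 1), fed as the "temperatures" input
assert per_slot.shape == (1, 1)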