Unit Tests for On Device Sampling #463
Merged
Commits (39 total; this view shows changes from 14 commits)
- 8417d8f  Add sampler transform test
- quic-sanising 27d8dd5  Merge branch 'main' into ods-unit-tests
- 067f9b5  Add example script
- 931860f  Update docs
- 79b6c95  Enable On Device Sampling for _continuous_batching_execution()
- 75eac30  Disable On Device Sampling for _regular_model_execution()
- eb6e2eb  Use same sampling parameters for each sequence in a batch
- 48b35e3  Enable On Device Sampling for _regular_model_execution()
- c83a631  Add test for greedy sampling
- f698a24  Add test for random sampling
- 7b34a07  Remove else block
- 5fa7269  Merge branch 'main' into ods-unit-tests
- 0ee201a  Reformat code
- c074768  Merge branch 'quic:main' into ods-unit-tests
- quic-sanising 115505e  Move sampling operations, inputs, and validation functions to utils
- 3ac7503  Change model to TinyLlama
- 02669e0  Add header
- 137cc4a  Reformat code
- 54a926a  Merge branch 'quic:main' into ods-unit-tests
- quic-sanising 6acf446  Update linter
- 6083f5b  Merge branch 'quic:main' into ods-unit-tests
- quic-sanising c2d7e83  Remove device_id
- 1069109  Remove redundant line
- 7d67132  Merge branch 'quic:main' into ods-unit-tests
- quic-sanising 0e3f257  Merge branch 'main' into ods-unit-tests
- 908e67e  Remove redundant reinitialization of output buffers
- a8e55da  Merge branch 'main' into ods-unit-tests
- f3f89d3  Add qaic_config to model hash
- 81ae15a  Merge branch 'main' into ods-unit-tests
- c485bfd  Change config
- 0e3b383  Remove pretrained_model_name_or_path from qaic_config
- 7d91470  Revert changes to model hash
- e36add0  Added qaic_config to hash parameters via inclusion list.
- quic-dhirajku 127ec74  Added qaic_config in manual hash tests for causal_lm dummy models.
- quic-dhirajku dad96ca  Use different config for each test
- 126cbb0  Duplicate prompt arg fix in infer
- quic-rishinr 30d025d  Merge branch 'main' into ods-unit-tests
- 0b4575b  Change config of greedy tests
- 7f1d5f4  Merge branch 'main' into ods-unit-tests
@@ -10,7 +10,7 @@
 from collections import deque
 from dataclasses import dataclass
 from time import perf_counter
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import transformers
@@ -322,6 +322,9 @@ def cloud_ai_100_exec_kv(
     automation=False,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
     is_tlm: bool = False,
+    include_sampler: bool = False,
+    return_pdfs: bool = False,
+    sampling_params: Optional[Dict[str, Any]] = None,
 ):
     """
     This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +345,15 @@ def cloud_ai_100_exec_kv(
         :Write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
         :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
         :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
+        :include_sampler (bool): Enable/Disable sampling of next tokens.
+        :return_pdfs (bool): Return probability distributions along with sampled
+            next tokens. For Speculative Decoding Target Language Model,
+            `return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative
+            Decoding Draft Language Model and `return_pdfs`=False for regular model.
+        sampling_params (Dict[str, Any]): A dictionary of sampling parameters supported by the QAIC backend.
+            The dictionary should contain the following keys:
+            `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
+            `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).

     Returns:
         :CloudAI100ExecInfo: Object holding execution output and performance details.
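As a companion to the docstring above, here is a minimal sketch of how a caller might assemble `sampling_params`. Only the key names and the (batch_size, 1) numpy shape come from the docstring; the dtypes and concrete values below are illustrative assumptions.

import numpy as np

batch_size = 2  # illustrative batch size

# One row per sequence in the batch; keys follow the docstring above.
sampling_params = {
    "repetition_penalties": np.full((batch_size, 1), 1.1, dtype=np.float32),
    "presence_penalties": np.zeros((batch_size, 1), dtype=np.float32),
    "temperatures": np.full((batch_size, 1), 0.8, dtype=np.float32),
    "top_ks": np.full((batch_size, 1), 50, dtype=np.int32),
    "top_ps": np.full((batch_size, 1), 0.95, dtype=np.float32),
    "min_ps": np.zeros((batch_size, 1), dtype=np.float32),
    "random_numbers": np.random.uniform(size=(batch_size, 1)).astype(np.float32),  # per-sequence RNG draws; dtype assumed
}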
@@ -372,6 +384,9 @@ def cloud_ai_100_exec_kv(
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
         is_tlm=is_tlm,
+        include_sampler=include_sampler,
+        return_pdfs=return_pdfs,
+        sampling_params=sampling_params,
     )
     if full_batch_size is None:
         exec_info = [
@@ -411,14 +426,59 @@ def __init__(
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: Optional[int] = None,
+        include_sampler: bool = False,
+        return_pdfs: bool = False,
+        sampling_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self._ctx_len = ctx_len
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
+        self.include_sampler = include_sampler
+        self.return_pdfs = return_pdfs
+        self.sampling_params = sampling_params

         # Load QPC
         self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

+        # Validate sampler inputs for On-Device Sampling
+        sampler_inputs = [
+            "last_accepted_output_tokens",
+            "repetition_penalties",
+            "presence_penalties",
+            "temperatures",
+            "top_ks",
+            "top_ps",
+            "min_ps",
+            "random_numbers",
+        ]
+        count = 0
+        for session_input_name in self._session.input_names:
+            if session_input_name in sampler_inputs:
+                count += 1
+                if count == len(sampler_inputs):
+                    self.include_sampler = True
+                    break
+        if count == 0:
+            self.include_sampler = False
+        elif count < len(sampler_inputs):
+            raise ValueError(
+                "The provided QPC does not have the required number of inputs to run sampling "
+                f"on the QAIC device (only {count}/{len(sampler_inputs)} inputs provided). Partial "
+                "sampling support is not available. Please check the QPC and try again."
+            )
+
+        if include_sampler and not self.include_sampler:
+            logger.warning_once(
+                "User entered `include_sampler`=True. But the provided QPC is not compiled "
+                "to run sampling on the QAIC device. Falling back to the PyTorch backend."
+            )
+        elif (include_sampler is None or not include_sampler) and self.include_sampler:
+            raise ValueError(
+                "The provided QPC is compiled to run sampling on the QAIC device. "
+                "But the user did not enter `include_sampler`=True. Please make sure the input "
+                "is specified correctly."
+            )
+
         # Fetch the variables from the QPC
         self._vocab_size = self._fetch_vocab_size()  # Fetch Vocab size
         self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
@@ -523,10 +583,17 @@ def _fetch_vocab_size(
         Returns:
             vocab_size: The vocabulary size fetched from the session's allowed shapes.
         """
+        key = (
+            "probs"
+            if self.include_sampler and self.return_pdfs
+            else "next_tokens"
+            if self.include_sampler
+            else "logits"
+        )
         if self._session.allowed_shapes:
-            return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2]
+            return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2]

-        return self._session.bindings[self._session.binding_index_map["logits"]].dims[2]
+        return self._session.bindings[self._session.binding_index_map[key]].dims[2]

     def _fetch_generation_len(self, generation_len, max_gen_len):
         """
@@ -574,6 +641,21 @@ def prepare_decode_inputs(self):
         decode_inputs["position_ids"] = self.decode_pos_ids
         if self.batch_index is not None:
             decode_inputs["batch_index"] = self.batch_index
+        if self.include_sampler:
+            decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
+            for op in [
+                "repetition_penalties",
+                "presence_penalties",
+                "temperatures",
+                "top_ks",
+                "top_ps",
+                "min_ps",
+                "random_numbers",
+            ]:
+                if self.batch_index is not None:
+                    decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
+                else:
+                    decode_inputs[op] = self.sampling_params[op]

         if self._prompt_to_lora_id_mapping_decode:
             if self.full_batch_size:
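The per-sequence lookup in `prepare_decode_inputs` above is plain numpy row indexing; a small standalone sketch (with made-up values) of what `self.sampling_params[op][self.batch_index.flatten()]` selects:

import numpy as np

temperatures = np.array([[0.7], [1.0], [0.2]])  # full-batch values, shape (full_batch_size, 1)
batch_index = np.array([[2], [0]])              # slots currently being decoded

# Flattened batch indices pick the matching rows, so each active slot
# keeps its own sampling settings.
print(temperatures[batch_index.flatten()])      # [[0.2] [0.7]]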
@@ -589,21 +671,24 @@ def prepare_decode_inputs(self):

     def _fetch_next_token_id(self, outputs):
         """
-        Fetches the next token ID from the model's output logits.
-        The method identifies the token with the highest probability using argmax along the last dimension.
+        Fetches the next token ID from the model's output.

         Args:
-            outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size).
+            outputs (dict): A dictionary containing the model's output.

         Returns:
             numpy.ndarray: An array of the next token IDs for each sequence in the batch.
         """
-        logits = outputs["logits"]
-        if len(logits.shape) == 2:
-            logits = np.expand_dims(logits, 1)
-        # Get output token
-        next_token_id = logits.argmax(2)
-        return next_token_id
+        if self.include_sampler:
+            if self.return_pdfs:
+                return outputs["probs"].argmax(2)
+            else:
+                return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])
+        else:
+            logits = outputs["logits"]
+            if len(logits.shape) == 2:
+                logits = np.expand_dims(logits, 1)
+            return logits.argmax(2)

     def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length):
         """
@@ -673,6 +758,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):

             _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

+    def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1):
+        """
+        Sets the sizes of the output buffers.
+
+        Args:
+            batch_size (int): The batch size.
+        """
+        if self.include_sampler:
+            if self.return_pdfs:
+                probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
+                self._session.set_buffers({"probs": probs_out_placeholder})
+            next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)
+            self._session.set_buffers({"next_tokens": next_tokens_out_placeholder})
+        else:
+            logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
+            self._session.set_buffers({"logits": logits_out_placeholder})
+
     def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
         """
         Runs prefill for a given prompt and generation length.
@@ -702,9 +804,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
         max_gen_len = self._ctx_len - position_ids.max()
         generation_len = self._fetch_generation_len(generation_len, max_gen_len)

-        # Set the prefill logic buffer
-        logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32)
-        self._session.set_buffers({"logits": logits_out_placeholder})
+        # Set the prefill output buffers
+        self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)

         inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
         inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
@@ -714,6 +815,21 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             inputs["batch_index"] = decode_batch_id
         if self.is_tlm:
             inputs["num_logits_to_keep"] = np.zeros((1, 1))
+        if self.include_sampler:
+            inputs["last_accepted_output_tokens"] = inputs["input_ids"]
+            for op in [
+                "repetition_penalties",
+                "presence_penalties",
+                "temperatures",
+                "top_ks",
+                "top_ps",
+                "min_ps",
+                "random_numbers",
+            ]:
+                if decode_batch_id is not None:
+                    inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
+                else:
+                    inputs[op] = self.sampling_params[op]

         if self._prompt_to_lora_id_mapping_prefill:
             if self.full_batch_size:
@@ -732,6 +848,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             chunk_inputs["position_ids"] = inputs["position_ids"][
                 :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
             ]
+            if self.include_sampler:
+                chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
             outputs = self._session.run(chunk_inputs)
             if self._write_io_dir is not None:
                 write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -753,11 +871,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

         """

-        # Set logits placeholder for decode
-        logits_out_placeholder = np.zeros(
-            (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
+        # Set output placeholders for decode
+        self._set_output_buffers(
+            batch_size=self.full_batch_size,
+            sequence_length=self._decode_seq_len,
         )
-        self._session.set_buffers({"logits": logits_out_placeholder})

         # Generate flag for tracking progress for each batch ID
         current_decode_ongoing = np.full((self.full_batch_size, 1), True)
@@ -775,10 +894,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
            outputs = self._session.run(decode_inputs)

            # Prepare inputs for next iteration
-           logits = outputs["logits"]
-           if len(logits.shape) == 2:
-               logits = np.expand_dims(logits, 1)
-           next_token_id = logits.argmax(2)
+           next_token_id = self._fetch_next_token_id(outputs)

            for decode_batch_id in range(self.full_batch_size):
                if (
@@ -800,7 +916,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                        self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1)
                        generated_id_current_index[decode_batch_id] = 1

-                       self._session.set_buffers({"logits": logits_out_placeholder})
+                       self._set_output_buffers(
+                           batch_size=self.full_batch_size,
+                           sequence_length=self._decode_seq_len,
+                       )
                        decode_pause_time += perf_counter() - start

                        if self._prompt_to_lora_id_mapping_decode:
@@ -817,6 +936,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                    self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
                        next_token_id[decode_batch_id, -1]
                    )
+                   if self.include_sampler:
+                       decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

                    generated_id_current_index[decode_batch_id] += 1

@@ -840,6 +961,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
                (self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
            )
            self._session.set_buffers({"logits": logits_out_placeholder})
+       else:
+           self._set_output_buffers(
+               batch_size=self.batch_size,
+               sequence_length=self._decode_seq_len,
+           )
        finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
        num_token = 0
        for num_token in range(1, generation_len):
@@ -852,10 +978,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
                self._write_io_dir = None

            # Prepare inputs for next iteration
-           decode_inputs["input_ids"] = outputs["logits"].argmax(2)
+           decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
            decode_inputs["position_ids"][:, -1] += 1
            self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
            finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
+           if self.include_sampler:
+               decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

            if finished_sequences.all():
                break
@@ -905,9 +1033,22 @@ def __init__(
        enable_debug_logs: bool = False,
        write_io_dir: Optional[str] = None,
        is_tlm: bool = False,
+       include_sampler: bool = False,
+       return_pdfs: bool = False,
+       sampling_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._qaic_model = QEffTextGenerationBase(
-           tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
+           tokenizer=tokenizer,
+           qpc_path=qpc_path,
+           full_batch_size=full_batch_size,
+           ctx_len=ctx_len,
+           device_id=device_id,
+           enable_debug_logs=enable_debug_logs,
+           write_io_dir=write_io_dir,
+           is_tlm=is_tlm,
+           include_sampler=include_sampler,
+           return_pdfs=return_pdfs,
+           sampling_params=sampling_params,
        )
        self._full_batch_size = self._qaic_model.full_batch_size
        self._tokenizer = self._qaic_model.tokenizer
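Putting the pieces together, a hedged usage sketch of the updated entry point. The leading `tokenizer`, `qpc_path`, and `prompt` arguments, the import path, and the model choice are assumptions based on the existing QEfficient API and are not shown in this diff; only the sampler keywords come from the changes above.

from transformers import AutoTokenizer

# Import path assumed for illustration; the diff does not name the module.
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

exec_info = cloud_ai_100_exec_kv(
    tokenizer,
    qpc_path,                         # path to a QPC compiled with the sampler inputs listed above
    prompt=["My name is"],
    include_sampler=True,             # must match how the QPC was compiled
    return_pdfs=False,                # next_tokens only; True also returns probability distributions
    sampling_params=sampling_params,  # dict of (batch_size, 1) numpy arrays, as sketched earlier
)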