
Commit 337757c

[BUGFIX] Patch for issues with export via replicate_kv_heads script CLI for CB (#646)

InputHandler now creates position_ids based on the CB (continuous batching) batch size.

Signed-off-by: Dhiraj Kumar Sah <[email protected]>

1 parent 529b530 · commit 337757c

2 files changed (+3 −2 lines)

QEfficient/exporter/export_hf_to_cloud_ai_100.py

Lines changed: 1 addition & 1 deletion

@@ -202,7 +202,7 @@ def export_kvstyle_transformed_model_to_onnx(
         batch_size=len(Constants.INPUT_STR),
         tokenizer=tokenizer,
         config=transformed_model.config,
-        prompt=Constants.INPUT_STR,
+        prompt=Constants.INPUT_STR * (full_batch_size if full_batch_size else 1),
         prompt_len=Constants.PROMPT_LEN,
         ctx_len=seq_len,
         full_batch_size=full_batch_size,
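
Why the prompt change works: multiplying a Python list by an integer repeats its elements, so the export path now builds one copy of the prompt per continuous-batching slot, with None or 0 falling back to a single copy. A minimal sketch of the expression, using illustrative stand-ins rather than the repo's actual Constants.INPUT_STR:

# Illustrative stand-in; the real Constants.INPUT_STR may hold different strings.
INPUT_STR = ["My name is"]
full_batch_size = 4  # hypothetical CB batch size; None would mean no continuous batching

# Same pattern as the diff: replicate the prompt list once per CB slot.
prompt = INPUT_STR * (full_batch_size if full_batch_size else 1)

print(prompt)       # ['My name is', 'My name is', 'My name is', 'My name is']
print(len(prompt))  # 4 -> one prompt per continuous-batching slot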

QEfficient/utils/generate_inputs.py

Lines changed: 2 additions & 1 deletion

@@ -68,7 +68,8 @@ def prepare_pytorch_inputs(self):
         batch_size, input_len = input_ids.shape
         inputs.pop("attention_mask")
         inputs.pop("token_type_ids", None)
-        position_ids = torch.arange(input_len).view(1, -1)
+        usable_bs = self.full_batch_size if self.full_batch_size else 1
+        position_ids = torch.arange(input_len).view(1, input_len).repeat(usable_bs, 1)
         inputs["input_ids"] = torch.concat(
             [
                 input_ids,
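
For context on the position_ids change: the old code produced a single row of positions with shape (1, input_len), while the new code repeats that row once per continuous-batching slot, giving shape (usable_bs, input_len). A minimal, self-contained sketch assuming torch is installed and using illustrative values:

import torch

input_len = 5        # hypothetical prompt length after tokenization
full_batch_size = 3  # hypothetical CB batch size; None would fall back to 1

usable_bs = full_batch_size if full_batch_size else 1

# Old behavior: one row of positions, shape (1, input_len).
old_position_ids = torch.arange(input_len).view(1, -1)

# New behavior: the same row repeated once per CB slot, shape (usable_bs, input_len).
position_ids = torch.arange(input_len).view(1, input_len).repeat(usable_bs, 1)

print(old_position_ids.shape)  # torch.Size([1, 5])
print(position_ids.shape)      # torch.Size([3, 5])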
