Skip to content

Commit 40af171

Browse files
committed
Remove gated Llama-3.3-70B model from tests + lint/format
Signed-off-by: Vaibhav Verma <[email protected]>
1 parent a1915c9 commit 40af171

File tree

5 files changed

+22
-14
lines changed

5 files changed

+22
-14
lines changed

QEfficient/customop/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFuncBlockedKV, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D
8+
from QEfficient.customop.ctx_scatter_gather import (
9+
CtxGatherFunc,
10+
CtxGatherFunc3D,
11+
CtxGatherFuncBlockedKV,
12+
CtxScatterFunc,
13+
CtxScatterFunc3D,
14+
)
915
from QEfficient.customop.ctx_scatter_gather_cb import (
10-
CtxGatherFuncCB,
1116
CtxGatherFuncBlockedKVCB,
17+
CtxGatherFuncCB,
1218
CtxGatherFuncCB3D,
1319
CtxScatterFuncCB,
1420
CtxScatterFuncCB3D,

QEfficient/customop/ctx_scatter_gather.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def setup_context(ctx, inputs, outputs):
146146
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
147147
return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
148148

149+
149150
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
150151
def CtxGatherBlockedKV(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
151152
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])

QEfficient/transformers/cache_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414

1515
from QEfficient.customop import (
1616
CtxGatherFunc,
17-
CtxGatherFuncBlockedKV,
1817
CtxGatherFunc3D,
19-
CtxGatherFuncCB,
18+
CtxGatherFuncBlockedKV,
2019
CtxGatherFuncBlockedKVCB,
20+
CtxGatherFuncCB,
2121
CtxGatherFuncCB3D,
2222
CtxScatterFunc,
2323
CtxScatterFunc3D,

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#
66
# -----------------------------------------------------------------------------
77

8-
from functools import partial
98
import warnings
9+
from functools import partial
1010
from types import MethodType
1111
from typing import Callable, Optional, Tuple, Union
1212

tests/transformers/models/test_causal_lm_models.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@
6868
]
6969

7070
test_models_blockedKV = [
71-
"meta-llama/Llama-3.3-70B-Instruct",
71+
# "meta-llama/Llama-3.3-70B-Instruct",
72+
"meta-llama/Llama-3.2-1B",
7273
]
7374

7475

@@ -248,7 +249,11 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
248249
pytorch_hf_tokens = [pytorch_hf_tokens for _ in range(full_batch_size)]
249250

250251
qeff_model = QEFFAutoModelForCausalLM(
251-
model_hf, continuous_batching=True, is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config
252+
model_hf,
253+
continuous_batching=True,
254+
is_tlm=is_tlm,
255+
pretrained_model_name_or_path=model_name,
256+
qaic_config=qaic_config,
252257
)
253258
onnx_model_path = qeff_model.export()
254259

@@ -505,9 +510,8 @@ def test_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
505510
n_layer = get_custom_n_layers(model_name)
506511

507512
qaic_config = dict(num_kv_blocks=Constants.NUM_KV_BLOCKS)
508-
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
509-
model_name=model_name, n_layer=n_layer, qaic_config=qaic_config
510-
)
513+
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config)
514+
511515

512516
@pytest.mark.parametrize("model_name", test_models_blockedKV)
513517
def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
@@ -518,7 +522,4 @@ def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
518522
"""
519523
n_layer = get_custom_n_layers(model_name)
520524

521-
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
522-
model_name=model_name, n_layer=n_layer
523-
)
524-
525+
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)

0 commit comments

Comments (0)