Skip to content

Commit f48fbfc

Browse files
committed
Adding ccl_enabled flag during model loading and passing CCL lists during the compilation process
Signed-off-by: Vahid Janfaza <[email protected]>
1 parent f25cd52 commit f48fbfc

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -936,6 +936,7 @@ def __init__(
936936
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
937937
self.continuous_batching = continuous_batching
938938
self.ccl_enabled = ccl_enabled
939+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
939940
self.input_shapes, self.output_names = None, None
940941

941942
@property
@@ -1178,7 +1179,6 @@ def compile(
11781179
output_names = self.model.get_output_names(kv_offload=True)
11791180

11801181
# if ccl_enabled is True read Compute-Context-Length lists
1181-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
11821182
if self.ccl_enabled:
11831183
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
11841184
logger.warning(
@@ -1648,6 +1648,7 @@ def __init__(
16481648
self.model.config.use_cache = True
16491649
self.hash_params["qeff_auto_class"] = self.__class__.__name__
16501650
self.ccl_enabled = ccl_enabled
1651+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
16511652

16521653
@classmethod
16531654
def from_pretrained(
@@ -1811,7 +1812,6 @@ def compile(
18111812
output_names = self.model.get_output_names()
18121813

18131814
# if ccl_enabled is True read Compute-Context-Length lists
1814-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
18151815
if self.ccl_enabled:
18161816
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
18171817
logger.warning(
@@ -2401,6 +2401,7 @@ def __init__(
24012401

24022402
self.hash_params["qeff_auto_class"] = self.__class__.__name__
24032403
self.ccl_enabled = ccl_enabled
2404+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
24042405

24052406
# ---Sampling---
24062407
# Note: SamplerTransform should be applied after all other transforms
@@ -2938,7 +2939,6 @@ def compile(
29382939
"""
29392940

29402941
# if ccl_enabled is True read Compute-Context-Length lists
2941-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
29422942
if self.ccl_enabled:
29432943
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
29442944
logger.warning(

0 commit comments

Comments (0)