Skip to content

Commit 6dc9d41

Browse files
committed
Adding ccl_enabled flag during model loading and passing CCL lists during compilation process
Signed-off-by: Vahid Janfaza <[email protected]>
1 parent d437aac commit 6dc9d41

File tree

15 files changed

+80
-53
lines changed

15 files changed

+80
-53
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,7 @@ def __init__(
909909
self,
910910
model: nn.Module,
911911
continuous_batching: bool = False,
912-
ccl_enabled: bool = False,
912+
qaic_config: Optional[dict] = None,
913913
**kwargs,
914914
):
915915
"""
@@ -935,7 +935,7 @@ def __init__(
935935
self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
936936
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
937937
self.continuous_batching = continuous_batching
938-
self.ccl_enabled = ccl_enabled
938+
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
939939
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
940940
self.input_shapes, self.output_names = None, None
941941

@@ -955,7 +955,7 @@ def model_name(self) -> str:
955955
return mname
956956

957957
@classmethod
958-
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
958+
def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs):
959959
"""
960960
Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path.
961961
@@ -980,13 +980,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
980980
logger.warning("Updating low_cpu_mem_usage=False")
981981

982982
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
983-
ccl_enabled = kwargs.pop("ccl_enabled", None)
984983

985984
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
986985
return cls(
987986
model,
988987
pretrained_model_name_or_path=pretrained_model_name_or_path,
989-
ccl_enabled=ccl_enabled,
988+
qaic_config=qaic_config,
990989
**kwargs,
991990
)
992991

@@ -1190,8 +1189,9 @@ def compile(
11901189

11911190
# For supporting VLLM and Disaggregated with CCL
11921191
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
1193-
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1194-
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
1192+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1193+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
1194+
)
11951195

11961196
specializations, compiler_options = self.model.get_specializations(
11971197
batch_size=batch_size,
@@ -1614,7 +1614,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
16141614
def __init__(
16151615
self,
16161616
model: nn.Module,
1617-
ccl_enabled: bool = False,
1617+
qaic_config: Optional[dict] = None,
16181618
**kwargs,
16191619
):
16201620
"""
@@ -1647,13 +1647,14 @@ def __init__(
16471647
else:
16481648
self.model.config.use_cache = True
16491649
self.hash_params["qeff_auto_class"] = self.__class__.__name__
1650-
self.ccl_enabled = ccl_enabled
1650+
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
16511651
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
16521652

16531653
@classmethod
16541654
def from_pretrained(
16551655
cls,
16561656
pretrained_model_name_or_path,
1657+
qaic_config: Optional[dict] = None,
16571658
*args,
16581659
**kwargs,
16591660
):
@@ -1684,7 +1685,6 @@ def from_pretrained(
16841685
logger.warning("Updating low_cpu_mem_usage=False")
16851686

16861687
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
1687-
ccl_enabled = kwargs.pop("ccl_enabled", None)
16881688

16891689
from transformers import AutoConfig
16901690

@@ -1696,7 +1696,7 @@ def from_pretrained(
16961696
return cls(
16971697
model,
16981698
pretrained_model_name_or_path=pretrained_model_name_or_path,
1699-
ccl_enabled=ccl_enabled,
1699+
qaic_config=qaic_config,
17001700
**kwargs,
17011701
)
17021702

@@ -1823,8 +1823,9 @@ def compile(
18231823

18241824
# For supporting VLLM and Disaggregated with CCL
18251825
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
1826-
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1827-
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
1826+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1827+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
1828+
)
18281829

18291830
# Get specializations from modelling file
18301831
# TODO: expose this via the auto class as well
@@ -2207,7 +2208,7 @@ def __new__(
22072208
model: nn.Module,
22082209
kv_offload: Optional[bool] = True,
22092210
continuous_batching: bool = False,
2210-
ccl_enabled: bool = False,
2211+
qaic_config: Optional[dict] = None,
22112212
**kwargs,
22122213
):
22132214
"""
@@ -2231,10 +2232,10 @@ def __new__(
22312232
"""
22322233
if kv_offload:
22332234
return _QEffAutoModelForImageTextToTextDualQPC(
2234-
model, continuous_batching, ccl_enabled=ccl_enabled, **kwargs
2235+
model, continuous_batching, qaic_config=qaic_config, **kwargs
22352236
)
22362237
else:
2237-
return _QEFFAutoModelForImageTextToTextSingleQPC(model, ccl_enabled=ccl_enabled, **kwargs)
2238+
return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs)
22382239

22392240
@classmethod
22402241
@with_replaced_quantizers
@@ -2284,15 +2285,14 @@ def from_pretrained(
22842285
logger.warning("Updating low_cpu_mem_usage=False")
22852286

22862287
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
2287-
ccl_enabled = kwargs.pop("ccl_enabled", None)
22882288

22892289
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
22902290
return cls(
22912291
model,
22922292
kv_offload=kv_offload,
22932293
continuous_batching=continuous_batching,
22942294
pretrained_model_name_or_path=pretrained_model_name_or_path,
2295-
ccl_enabled=ccl_enabled,
2295+
qaic_config=qaic_config,
22962296
**kwargs,
22972297
)
22982298

@@ -2345,7 +2345,6 @@ def __init__(
23452345
model: nn.Module,
23462346
continuous_batching: bool = False,
23472347
qaic_config: Optional[dict] = None,
2348-
ccl_enabled: bool = False,
23492348
**kwargs,
23502349
):
23512350
"""
@@ -2400,7 +2399,7 @@ def __init__(
24002399
self.is_tlm = transformed
24012400

24022401
self.hash_params["qeff_auto_class"] = self.__class__.__name__
2403-
self.ccl_enabled = ccl_enabled
2402+
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
24042403
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
24052404

24062405
# ---Sampling---
@@ -2494,7 +2493,6 @@ def from_pretrained(
24942493
logger.warning("Updating low_cpu_mem_usage=False")
24952494

24962495
kv_offload = kwargs.pop("kv_offload", None)
2497-
ccl_enabled = kwargs.pop("ccl_enabled", None)
24982496

24992497
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
25002498
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
@@ -2508,15 +2506,14 @@ def from_pretrained(
25082506
model,
25092507
kv_offload=kv_offload,
25102508
pretrained_model_name_or_path=pretrained_model_name_or_path,
2511-
ccl_enabled=ccl_enabled,
2509+
qaic_config=qaic_config,
25122510
**kwargs,
25132511
)
25142512
return cls(
25152513
model,
25162514
continuous_batching=continuous_batching,
25172515
qaic_config=qaic_config,
25182516
pretrained_model_name_or_path=pretrained_model_name_or_path,
2519-
ccl_enabled=ccl_enabled,
25202517
**kwargs,
25212518
)
25222519

@@ -2964,6 +2961,9 @@ def compile(
29642961
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
29652962
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
29662963

2964+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
2965+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
2966+
)
29672967
# --- Validation ---
29682968
if prefill_only is not None and not isinstance(prefill_only, bool):
29692969
raise TypeError("`prefill_only` must be a boolean.")

examples/performance/compute_context_length/basic_inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def main():
117117
model = QEFFAutoModelForCausalLM.from_pretrained(
118118
args.model_name,
119119
continuous_batching=args.continuous_batching,
120-
ccl_enabled=args.ccl_enabled,
120+
qaic_config={
121+
"ccl_enabled":args.ccl_enabled,
122+
},
121123
)
122124

123125
# Compile the model

examples/performance/compute_context_length/gemma3.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@
3838
model_id,
3939
config=config,
4040
attn_implementation="eager",
41-
kv_offload=True,
42-
ccl_enabled=True,
41+
kv_offload=False,
42+
qaic_config={
43+
"ccl_enabled":True,
44+
},
4345
)
4446

4547
### Set skip_vision=True to run text only; otherwise set it to False ###
@@ -58,7 +60,7 @@
5860
aic_enable_depth_first=True,
5961
skip_vision=True,
6062
mos=1,
61-
node_precision_info="examples/performance/compute_context_length/gemma3/fp32_nodes_gemma3_4b.yaml",
63+
node_precision_info="examples/performance/compute_context_length/fp32_nodes_gemma3_4b.yaml",
6264
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
6365
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
6466
)
@@ -96,7 +98,7 @@
9698
mxint8_kv_cache=False,
9799
aic_enable_depth_first=True,
98100
mos=1,
99-
node_precision_info="examples/performance/compute_context_length/gemma3/fp32_nodes_gemma3_4b.yaml",
101+
node_precision_info="examples/performance/compute_context_length/fp32_nodes_gemma3_4b.yaml",
100102
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
101103
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
102104
)

examples/performance/compute_context_length/gpt_oss.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121

2222
ctx_len = 4096
2323
# In MoE models like gpt-oss, prefill_seq_len=1, so comp_ctx_lengths_prefill and comp_ctx_lengths_decode can share the same list.
24-
# Set the list of ccl during prefilling process
25-
comp_ctx_lengths_prefill = [512, ctx_len]
26-
# Set the list of ccl during decoding process
27-
comp_ctx_lengths_decode = [512, ctx_len]
28-
24+
# Set the list of CCL values used for both the prefill and decode phases
25+
comp_ctx_lengths_prefill = comp_ctx_lengths_decode = [1024, ctx_len]
2926

3027
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
3128
model_id,
32-
ccl_enabled=True,
29+
qaic_config={
30+
"ccl_enabled":True,
31+
},
3332
)
3433
tokenizer = AutoTokenizer.from_pretrained(model_id)
3534

examples/performance/compute_context_length/granite_vision.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def run_model(
4141
model_name,
4242
token=token,
4343
kv_offload=kv_offload,
44-
ccl_enabled=ccl_enabled,
44+
qaic_config={
45+
"ccl_enabled":ccl_enabled,
46+
},
4547
)
4648

4749
## STEP - 2 Export & Compile the Model

examples/performance/compute_context_length/internvl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ def run_intern_on_aic(
188188
model_name,
189189
kv_offload=kv_offload,
190190
trust_remote_code=True,
191-
ccl_enabled=ccl_enabled,
191+
qaic_config={
192+
"ccl_enabled":ccl_enabled,
193+
},
192194
)
193195

194196
## STEP 2 -- EXPORT & COMPILE THE MODEL

examples/performance/compute_context_length/llama4.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
attn_implementation="eager",
3737
kv_offload=True,
3838
config=config,
39-
ccl_enabled=True,
39+
qaic_config={
40+
"ccl_enabled":True,
41+
},
4042
)
4143
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
4244
processor = AutoProcessor.from_pretrained(model_id)

examples/performance/compute_context_length/llama4_cb.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
kv_offload=True,
4242
config=config,
4343
continuous_batching=True,
44-
ccl_enabled=True,
44+
qaic_config={
45+
"ccl_enabled":True,
46+
},
4547
)
4648

4749
qeff_model.compile(
@@ -66,7 +68,9 @@
6668
attn_implementation="eager",
6769
kv_offload=True,
6870
config=config,
69-
ccl_enabled=True,
71+
qaic_config={
72+
"ccl_enabled":True,
73+
},
7074
)
7175

7276
qeff_model.compile(

examples/performance/compute_context_length/llama4_multi_image.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
attn_implementation="eager",
3737
kv_offload=True,
3838
config=config,
39-
ccl_enabled=True,
39+
qaic_config={
40+
"ccl_enabled":True,
41+
},
4042
)
4143
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
4244
processor = AutoProcessor.from_pretrained(model_id)

examples/performance/compute_context_length/mistral3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def run_model(
4646
model_name,
4747
kv_offload=kv_offload,
4848
config=config,
49-
ccl_enabled=ccl_enabled,
49+
qaic_config={
50+
"ccl_enabled":ccl_enabled,
51+
},
5052
)
5153

5254
## STEP - 2 Export & Compile the Model

0 commit comments

Comments
 (0)