@@ -909,7 +909,7 @@ def __init__(
         self,
         model: nn.Module,
         continuous_batching: bool = False,
-        ccl_enabled: bool = False,
+        qaic_config: Optional[dict] = None,
         **kwargs,
     ):
         """
@@ -935,7 +935,7 @@ def __init__(
         self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
         self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
         self.continuous_batching = continuous_batching
-        self.ccl_enabled = ccl_enabled
+        self.ccl_enabled = qaic_config.get("ccl_enabled", False)
         self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
         self.input_shapes, self.output_names = None, None

@@ -955,7 +955,7 @@ def model_name(self) -> str:
         return mname

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs):
         """
         Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path.

@@ -980,13 +980,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
             logger.warning("Updating low_cpu_mem_usage=False")

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-        ccl_enabled = kwargs.pop("ccl_enabled", None)

         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(
             model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
-            ccl_enabled=ccl_enabled,
+            qaic_config=qaic_config,
             **kwargs,
         )

@@ -1190,8 +1189,9 @@ def compile(

         # For supporting VLLM and Disaggregated with CCL
         if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
-            self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
-            self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+                comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+            )

         specializations, compiler_options = self.model.get_specializations(
             batch_size=batch_size,
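
Note: with this hunk, compile() no longer stores the raw comp_ctx_lengths_prefill / comp_ctx_lengths_decode lists; they are first reconciled against ctx_len and prefill_seq_len by process_ccl_specializations before specializations are built. The snippet below is only a hypothetical, self-contained sketch of the kind of normalization such a helper could perform; the real implementation in QEfficient may differ, and only the argument names are taken from this diff.

# Hypothetical sketch -- NOT the actual process_ccl_specializations implementation.
from typing import List, Optional, Tuple


def _sketch_process_ccl(
    ccl_prefill: Optional[List[int]],
    ccl_decode: Optional[List[int]],
    ctx_len: int,
    prefill_seq_len: int,
) -> Tuple[List[int], List[int]]:
    # Cap every compute-context-length bucket at ctx_len, drop duplicates,
    # and keep the lists sorted so specializations go smallest-first.
    prefill = sorted({min(c, ctx_len) for c in (ccl_prefill or [ctx_len])})
    decode = sorted({min(c, ctx_len) for c in (ccl_decode or [ctx_len])})
    # A prefill bucket must at least cover one prefill chunk.
    prefill = [max(c, prefill_seq_len) for c in prefill]
    return prefill, decode


print(_sketch_process_ccl([256, 1024, 8192], [1024, 8192], ctx_len=4096, prefill_seq_len=128))
# -> ([256, 1024, 4096], [1024, 4096])
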
@@ -1614,7 +1614,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
     def __init__(
         self,
         model: nn.Module,
-        ccl_enabled: bool = False,
+        qaic_config: Optional[dict] = None,
         **kwargs,
     ):
         """
@@ -1647,13 +1647,14 @@ def __init__(
         else:
             self.model.config.use_cache = True
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
-        self.ccl_enabled = ccl_enabled
+        self.ccl_enabled = qaic_config.get("ccl_enabled", False)
         self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None

     @classmethod
     def from_pretrained(
         cls,
         pretrained_model_name_or_path,
+        qaic_config: Optional[dict] = None,
         *args,
         **kwargs,
     ):
@@ -1684,7 +1685,6 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-        ccl_enabled = kwargs.pop("ccl_enabled", None)

         from transformers import AutoConfig

@@ -1696,7 +1696,7 @@ def from_pretrained(
         return cls(
             model,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
-            ccl_enabled=ccl_enabled,
+            qaic_config=qaic_config,
             **kwargs,
         )

@@ -1823,8 +1823,9 @@ def compile(

         # For supporting VLLM and Disaggregated with CCL
         if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
-            self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
-            self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+                comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+            )

         # Get specializations from modelling file
         # TODO: expose this via the auto class as well
@@ -2207,7 +2208,7 @@ def __new__(
         model: nn.Module,
         kv_offload: Optional[bool] = True,
         continuous_batching: bool = False,
-        ccl_enabled: bool = False,
+        qaic_config: Optional[dict] = None,
         **kwargs,
     ):
         """
@@ -2231,10 +2232,10 @@ def __new__(
         """
         if kv_offload:
             return _QEffAutoModelForImageTextToTextDualQPC(
-                model, continuous_batching, ccl_enabled=ccl_enabled, **kwargs
+                model, continuous_batching, qaic_config=qaic_config, **kwargs
             )
         else:
-            return _QEFFAutoModelForImageTextToTextSingleQPC(model, ccl_enabled=ccl_enabled, **kwargs)
+            return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs)

     @classmethod
     @with_replaced_quantizers
@@ -2284,15 +2285,14 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-        ccl_enabled = kwargs.pop("ccl_enabled", None)

         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(
             model,
             kv_offload=kv_offload,
             continuous_batching=continuous_batching,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
-            ccl_enabled=ccl_enabled,
+            qaic_config=qaic_config,
             **kwargs,
         )

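
For callers, the visible change in these hunks is that the old ccl_enabled keyword on QEFFAutoModelForImageTextToText is gone and the flag now travels inside qaic_config, from which __init__ reads it via qaic_config.get("ccl_enabled", False). A hedged usage sketch follows; the checkpoint name is an illustrative assumption, while kv_offload, qaic_config, and the "ccl_enabled" key come from this diff.

from QEfficient import QEFFAutoModelForImageTextToText

# Before this change (old keyword, now removed):
#   model = QEFFAutoModelForImageTextToText.from_pretrained(ckpt, kv_offload=True, ccl_enabled=True)
# After this change:
model = QEFFAutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",          # example checkpoint (assumption)
    kv_offload=True,                     # routes to the dual-QPC class shown above
    qaic_config={"ccl_enabled": True},   # replaces ccl_enabled=True
)
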
@@ -2345,7 +2345,6 @@ def __init__(
         model: nn.Module,
         continuous_batching: bool = False,
         qaic_config: Optional[dict] = None,
-        ccl_enabled: bool = False,
         **kwargs,
     ):
         """
@@ -2400,7 +2399,7 @@ def __init__(
         self.is_tlm = transformed

         self.hash_params["qeff_auto_class"] = self.__class__.__name__
-        self.ccl_enabled = ccl_enabled
+        self.ccl_enabled = qaic_config.get("ccl_enabled", False)
         self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None

         # ---Sampling---
@@ -2494,7 +2493,6 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")

         kv_offload = kwargs.pop("kv_offload", None)
-        ccl_enabled = kwargs.pop("ccl_enabled", None)

         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
@@ -2508,15 +2506,14 @@ def from_pretrained(
             model,
             kv_offload=kv_offload,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
-            ccl_enabled=ccl_enabled,
+            qaic_config=qaic_config,
             **kwargs,
         )
         return cls(
             model,
             continuous_batching=continuous_batching,
             qaic_config=qaic_config,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
-            ccl_enabled=ccl_enabled,
             **kwargs,
         )

@@ -2964,6 +2961,9 @@ def compile(
         self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
         self.comp_ctx_lengths_decode = comp_ctx_lengths_decode

+        self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+            self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
+        )
         # --- Validation ---
         if prefill_only is not None and not isinstance(prefill_only, bool):
             raise TypeError("`prefill_only` must be a boolean.")
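
The same pattern applies to QEFFAutoModelForCausalLM, whose compile() now also routes its stored CCL lists through process_ccl_specializations. A hedged end-to-end sketch is below; the checkpoint, bucket values, and compiler settings are assumptions, while qaic_config, comp_ctx_lengths_prefill, comp_ctx_lengths_decode, prefill_seq_len, and ctx_len are the keyword names used in this diff.

from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

model_name = "gpt2"  # example checkpoint (assumption)
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
    model_name,
    qaic_config={"ccl_enabled": True},    # CCL flag now lives in qaic_config
)

qeff_model.compile(
    prefill_seq_len=128,
    ctx_len=1024,
    comp_ctx_lengths_prefill=[256, 1024],  # example CCL buckets (assumption)
    comp_ctx_lengths_decode=[512, 1024],   # example CCL buckets (assumption)
    num_cores=16,                          # example compiler setting (assumption)
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
qeff_model.generate(prompts=["Hello"], tokenizer=tokenizer)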