@@ -144,35 +144,62 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
144144 if vector_stores_config is None :
145145 return
146146
147+ # Validate default embedding model
147148 default_embedding_model = vector_stores_config .default_embedding_model
148- if default_embedding_model is None :
149- return
149+ if default_embedding_model is not None :
150+ provider_id = default_embedding_model .provider_id
151+ model_id = default_embedding_model .model_id
152+ default_model_id = f"{ provider_id } /{ model_id } "
150153
151- provider_id = default_embedding_model .provider_id
152- model_id = default_embedding_model .model_id
153- default_model_id = f"{ provider_id } /{ model_id } "
154+ if Api .models not in impls :
155+ raise ValueError (
156+ f"Models API is not available but vector_stores config requires model '{ default_model_id } '"
157+ )
154158
155- if Api .models not in impls :
156- raise ValueError (f"Models API is not available but vector_stores config requires model '{ default_model_id } '" )
159+ models_impl = impls [Api .models ]
160+ response = await models_impl .list_models ()
161+ models_list = {m .identifier : m for m in response .data if m .model_type == "embedding" }
157162
158- models_impl = impls [Api .models ]
159- response = await models_impl .list_models ()
160- models_list = {m .identifier : m for m in response .data if m .model_type == "embedding" }
163+ default_model = models_list .get (default_model_id )
164+ if default_model is None :
165+ raise ValueError (
166+ f"Embedding model '{ default_model_id } ' not found. Available embedding models: { models_list } "
167+ )
161168
162- default_model = models_list . get (default_model_id )
163- if default_model is None :
164- raise ValueError (f"Embedding model '{ default_model_id } ' not found. Available embedding models: { models_list } " )
169+ embedding_dimension = default_model . metadata . get ("embedding_dimension" )
170+ if embedding_dimension is None :
171+ raise ValueError (f"Embedding model '{ default_model_id } ' is missing 'embedding_dimension' in metadata " )
165172
166- embedding_dimension = default_model .metadata .get ("embedding_dimension" )
167- if embedding_dimension is None :
168- raise ValueError (f"Embedding model '{ default_model_id } ' is missing 'embedding_dimension' in metadata" )
173+ try :
174+ int (embedding_dimension )
175+ except ValueError as err :
176+ raise ValueError (f"Embedding dimension '{ embedding_dimension } ' cannot be converted to an integer" ) from err
169177
170- try :
171- int (embedding_dimension )
172- except ValueError as err :
173- raise ValueError (f"Embedding dimension '{ embedding_dimension } ' cannot be converted to an integer" ) from err
178+ logger .debug (f"Validated default embedding model: { default_model_id } (dimension: { embedding_dimension } )" )
179+
180+ # Validate default query expansion model
181+ default_query_expansion_model = vector_stores_config .default_query_expansion_model
182+ if default_query_expansion_model is not None :
183+ provider_id = default_query_expansion_model .provider_id
184+ model_id = default_query_expansion_model .model_id
185+ query_model_id = f"{ provider_id } /{ model_id } "
186+
187+ if Api .models not in impls :
188+ raise ValueError (
189+ f"Models API is not available but vector_stores config requires query expansion model '{ query_model_id } '"
190+ )
191+
192+ models_impl = impls [Api .models ]
193+ response = await models_impl .list_models ()
194+ llm_models_list = {m .identifier : m for m in response .data if m .model_type == "llm" }
174195
175- logger .debug (f"Validated default embedding model: { default_model_id } (dimension: { embedding_dimension } )" )
196+ query_expansion_model = llm_models_list .get (query_model_id )
197+ if query_expansion_model is None :
198+ raise ValueError (
199+ f"Query expansion model '{ query_model_id } ' not found. Available LLM models: { list (llm_models_list .keys ())} "
200+ )
201+
202+ logger .debug (f"Validated default query expansion model: { query_model_id } " )
176203
177204
178205async def validate_safety_config (safety_config : SafetyConfig | None , impls : dict [Api , Any ]):
@@ -437,6 +464,12 @@ async def initialize(self):
437464 await refresh_registry_once (impls )
438465 await validate_vector_stores_config (self .run_config .vector_stores , impls )
439466 await validate_safety_config (self .run_config .safety , impls )
467+
468+ # Set global query expansion configuration from stack config
469+ from llama_stack .providers .utils .memory .query_expansion_config import set_default_query_expansion_config
470+
471+ set_default_query_expansion_config (self .run_config .vector_stores )
472+
440473 self .impls = impls
441474
442475 def create_registry_refresh_task (self ):
0 commit comments