|
14 | 14 | import yaml |
15 | 15 |
|
16 | 16 | from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl |
17 | | -from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig |
| 17 | +from llama_stack.core.datatypes import Provider, QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig |
18 | 18 | from llama_stack.core.distribution import get_provider_registry |
19 | 19 | from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl |
20 | 20 | from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl |
@@ -145,61 +145,67 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig |
145 | 145 | return |
146 | 146 |
|
147 | 147 | # Validate default embedding model |
148 | | - default_embedding_model = vector_stores_config.default_embedding_model |
149 | | - if default_embedding_model is not None: |
150 | | - provider_id = default_embedding_model.provider_id |
151 | | - model_id = default_embedding_model.model_id |
152 | | - default_model_id = f"{provider_id}/{model_id}" |
| 148 | + if vector_stores_config.default_embedding_model is not None: |
| 149 | + await _validate_embedding_model(vector_stores_config.default_embedding_model, impls) |
153 | 150 |
|
154 | | - if Api.models not in impls: |
155 | | - raise ValueError( |
156 | | - f"Models API is not available but vector_stores config requires model '{default_model_id}'" |
157 | | - ) |
| 151 | + # Validate default rewrite query model |
| 152 | + if vector_stores_config.rewrite_query_params and vector_stores_config.rewrite_query_params.model: |
| 153 | + await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls) |
158 | 154 |
|
159 | | - models_impl = impls[Api.models] |
160 | | - response = await models_impl.list_models() |
161 | | - models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"} |
162 | 155 |
|
163 | | - default_model = models_list.get(default_model_id) |
164 | | - if default_model is None: |
165 | | - raise ValueError( |
166 | | - f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}" |
167 | | - ) |
async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
    """Validate that an embedding model exists and has required metadata.

    Checks that the Models API is available, that the qualified model is
    registered with ``model_type == "embedding"``, and that its metadata
    carries an integer-convertible ``embedding_dimension``.

    Args:
        embedding_model: Qualified (provider_id, model_id) reference from the
            vector_stores config.
        impls: Mapping of available API implementations.

    Raises:
        ValueError: If the Models API is unavailable, the model is not
            registered as an embedding model, or its metadata lacks a usable
            'embedding_dimension'.
    """
    # Models are registered under the fully-qualified "<provider>/<model>" identifier.
    provider_id = embedding_model.provider_id
    model_id = embedding_model.model_id
    model_identifier = f"{provider_id}/{model_id}"

    if Api.models not in impls:
        raise ValueError(f"Models API is not available but vector_stores config requires model '{model_identifier}'")

    models_impl = impls[Api.models]
    response = await models_impl.list_models()
    models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}

    model = models_list.get(model_identifier)
    if model is None:
        raise ValueError(
            f"Embedding model '{model_identifier}' not found. Available embedding models: {list(models_list.keys())}"
        )

    embedding_dimension = model.metadata.get("embedding_dimension")
    if embedding_dimension is None:
        raise ValueError(f"Embedding model '{model_identifier}' is missing 'embedding_dimension' in metadata")

    # int() raises TypeError (not ValueError) for non-numeric non-string values
    # (e.g. a list or dict in metadata), so catch both to keep the error uniform.
    try:
        int(embedding_dimension)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err

    logger.debug(f"Validated embedding model: {model_identifier} (dimension: {embedding_dimension})")
195 | 185 |
|
196 | | - query_expansion_model = llm_models_list.get(query_model_id) |
197 | | - if query_expansion_model is None: |
198 | | - raise ValueError( |
199 | | - f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}" |
200 | | - ) |
201 | 186 |
|
202 | | - logger.debug(f"Validated default query expansion model: {query_model_id}") |
async def _validate_rewrite_query_model(rewrite_query_model: QualifiedModel, impls: dict[Api, Any]) -> None:
    """Validate that a rewrite query model exists and is accessible.

    Checks that the Models API is available and that the qualified model is
    registered with ``model_type == "llm"`` (query rewriting requires a
    generative model, not an embedding model).

    Args:
        rewrite_query_model: Qualified (provider_id, model_id) reference from
            the vector_stores rewrite_query_params config.
        impls: Mapping of available API implementations.

    Raises:
        ValueError: If the Models API is unavailable or the model is not
            registered as an LLM.
    """
    # Models are registered under the fully-qualified "<provider>/<model>" identifier.
    provider_id = rewrite_query_model.provider_id
    model_id = rewrite_query_model.model_id
    model_identifier = f"{provider_id}/{model_id}"

    if Api.models not in impls:
        raise ValueError(
            f"Models API is not available but vector_stores config requires rewrite query model '{model_identifier}'"
        )

    models_impl = impls[Api.models]
    response = await models_impl.list_models()
    llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}

    model = llm_models_list.get(model_identifier)
    if model is None:
        raise ValueError(
            f"Rewrite query model '{model_identifier}' not found. Available LLM models: {list(llm_models_list.keys())}"
        )

    logger.debug(f"Validated rewrite query model: {model_identifier}")
203 | 209 |
|
204 | 210 |
|
205 | 211 | async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]): |
@@ -466,9 +472,9 @@ async def initialize(self): |
466 | 472 | await validate_safety_config(self.run_config.safety, impls) |
467 | 473 |
|
468 | 474 | # Set global query expansion configuration from stack config |
469 | | - from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config |
| 475 | + from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config |
470 | 476 |
|
471 | | - set_default_query_expansion_config(self.run_config.vector_stores) |
| 477 | + set_default_rewrite_query_config(self.run_config.vector_stores) |
472 | 478 |
|
473 | 479 | self.impls = impls |
474 | 480 |
|
|
0 commit comments