17 | 17 | from numpy.typing import NDArray |
18 | 18 | from pydantic import BaseModel |
19 | 19 |
| 20 | +from llama_stack.core.datatypes import VectorStoresConfig |
20 | 21 | from llama_stack.log import get_logger |
21 | 22 | from llama_stack.models.llama.llama3.tokenizer import Tokenizer |
22 | 23 | from llama_stack.providers.utils.inference.prompt_adapter import ( |
@@ -267,6 +268,7 @@ class VectorStoreWithIndex: |
267 | 268 | vector_store: VectorStore |
268 | 269 | index: EmbeddingIndex |
269 | 270 | inference_api: Api.inference |
| 271 | + vector_stores_config: VectorStoresConfig | None = None |
270 | 272 |
271 | 273 | async def insert_chunks( |
272 | 274 | self, |
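For reference, the `vector_stores_config` field added above (together with the `VectorStoresConfig` import in the first hunk) relies on a config shape roughly like the sketch below. This is an assumption reconstructed from the attributes the rest of this diff reads (`default_query_expansion_model.provider_id`, `.model_id`, and `query_expansion_prompt`); the actual datatype lives in `llama_stack.core.datatypes`.

```python
# Hypothetical sketch only -- not the real llama_stack.core.datatypes definition,
# just the fields this diff depends on.
from pydantic import BaseModel


class QueryExpansionModel(BaseModel):
    provider_id: str  # e.g. "ollama" (illustrative value)
    model_id: str     # e.g. "llama3.2:3b" (illustrative value)


class VectorStoresConfig(BaseModel):
    # Model used for query rewriting; None means no model is configured and
    # _rewrite_query_for_search raises a ValueError (see below).
    default_query_expansion_model: QueryExpansionModel | None = None
    # Prompt template with a "{query}" placeholder; assumed to ship with a default
    # mirroring the inline prompt this diff removes.
    query_expansion_prompt: str = (
        "Expand this query with relevant synonyms and related terms. "
        "Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
    )
```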
@@ -301,6 +303,11 @@ async def query_chunks( |
301 | 303 | ) -> QueryChunksResponse: |
302 | 304 | if params is None: |
303 | 305 | params = {} |
| 306 | + |
| 307 | + # Extract configuration if provided by router |
| 308 | + if "vector_stores_config" in params: |
| 309 | + self.vector_stores_config = params["vector_stores_config"] |
| 310 | + |
304 | 311 | k = params.get("max_chunks", 3) |
305 | 312 | mode = params.get("mode") |
306 | 313 | score_threshold = params.get("score_threshold", 0.0) |
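The config is threaded through the existing `params` dict rather than a new argument, so a caller (for example the vector-IO router) passes it alongside the usual search knobs. A minimal sketch of that call, assuming `query_chunks(query, params)` accepts a plain dict and that `vs_index` (a `VectorStoreWithIndex`) and `config` already exist in the caller:

```python
# Illustrative call only; vs_index and config are assumed to exist in the caller.
response = await vs_index.query_chunks(
    query="how do I rotate my API keys?",
    params={
        "max_chunks": 5,
        "mode": "vector",
        "score_threshold": 0.2,
        # Picked up by the new block above and stored on self.vector_stores_config.
        "vector_stores_config": config,
    },
)
```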
@@ -350,51 +357,40 @@ async def _rewrite_query_for_search(self, query: str) -> str: |
350 | 357 | :param query: The original user query |
351 | 358 | :returns: The rewritten query optimized for vector search |
352 | 359 | """ |
353 | | - # Get available models and find a suitable chat model |
| 360 | + # Check if query expansion model is configured |
| 361 | + if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model: |
| 362 | + raise ValueError("No default_query_expansion_model configured for query rewriting") |
| 363 | + |
| 364 | + # Use the configured model |
| 365 | + expansion_model = self.vector_stores_config.default_query_expansion_model |
| 366 | + chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}" |
| 367 | + |
| 368 | + # Validate that the model is available and is an LLM |
354 | 369 | try: |
355 | 370 | models_response = await self.inference_api.routing_table.list_models() |
356 | 371 | except Exception as e: |
357 | | - raise RuntimeError(f"Failed to list available models for query rewriting: {e}") from e |
358 | | - |
359 | | - chat_model = None |
360 | | - # Look for an LLM model (for chat completion) |
361 | | - # Prefer local or non-cloud providers to avoid credential issues |
362 | | - llm_models = [m for m in models_response.data if m.model_type == ModelType.llm] |
363 | | - |
364 | | - # Filter out models that are known to be embedding models (misclassified as LLM) |
365 | | - embedding_model_patterns = ["minilm", "embed", "embedding", "nomic-embed"] |
366 | | - llm_models = [ |
367 | | - m for m in llm_models if not any(pattern in m.identifier.lower() for pattern in embedding_model_patterns) |
368 | | - ] |
369 | | - |
370 | | - # Priority order: ollama (local), then OpenAI, then others |
371 | | - provider_priority = ["ollama", "openai", "gemini", "bedrock"] |
372 | | - |
373 | | - for provider in provider_priority: |
374 | | - for model in llm_models: |
375 | | - model_id = model.identifier.lower() |
376 | | - if provider == "ollama" and "ollama/" in model_id: |
377 | | - chat_model = model.identifier |
378 | | - break |
379 | | - elif provider in model_id: |
380 | | - chat_model = model.identifier |
381 | | - break |
382 | | - if chat_model: |
| 372 | + raise RuntimeError(f"Failed to list available models for validation: {e}") from e |
| 373 | + |
| 374 | + model_found = False |
| 375 | + for model in models_response.data: |
| 376 | + if model.identifier == chat_model: |
| 377 | + if model.model_type != ModelType.llm: |
| 378 | + raise ValueError( |
| 379 | + f"Configured query expansion model '{chat_model}' is not an LLM model " |
| 380 | + f"(found type: {model.model_type}). Query rewriting requires an LLM model." |
| 381 | + ) |
| 382 | + model_found = True |
383 | 383 | break |
384 | 384 |
385 | | - # Fallback: use first available LLM model if no preferred provider found |
386 | | - if not chat_model and llm_models: |
387 | | - chat_model = llm_models[0].identifier |
388 | | - |
389 | | - # If no suitable model found, raise an error |
390 | | - if not chat_model: |
391 | | - raise ValueError("No LLM model available for query rewriting") |
392 | | - |
393 | | - rewrite_prompt = f"""Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations: |
394 | | -
395 | | -{query} |
| 385 | + if not model_found: |
| 386 | + available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm] |
| 387 | + raise ValueError( |
| 388 | + f"Configured query expansion model '{chat_model}' is not available. " |
| 389 | + f"Available LLM models: {available_llm_models}" |
| 390 | + ) |
396 | 391 |
397 | | -Improved query:""" |
| 392 | + # Use the configured prompt (has a default value) |
| 393 | + rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query) |
398 | 394 |
399 | 395 | chat_request = OpenAIChatCompletionRequestWithExtraBody( |
400 | 396 | model=chat_model, |
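As a quick worked example of the two values assembled above (`chat_model` from the configured provider/model pair, and `rewrite_prompt` from the template), with purely illustrative config values:

```python
# Illustrative values only; the real ones come from VectorStoresConfig.
provider_id, model_id = "ollama", "llama3.2:3b"
chat_model = f"{provider_id}/{model_id}"  # -> "ollama/llama3.2:3b"

# query_expansion_prompt is a template containing a "{query}" placeholder.
query_expansion_prompt = "Rewrite this query for vector search, adding synonyms: {query}"
rewrite_prompt = query_expansion_prompt.format(query="reset 2fa")
# -> "Rewrite this query for vector search, adding synonyms: reset 2fa"
```

The rendered prompt is then presumably sent as the chat message for `chat_model`, with the completion text returned as the rewritten query (per the docstring at the top of this hunk).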