
Commit 859f4c2

adding query expansion model to vector store config
Signed-off-by: Francisco Javier Arceo <[email protected]>
Parent: 5349c33

3 files changed (+49, -39)


src/llama_stack/core/datatypes.py

Lines changed: 8 additions & 0 deletions
@@ -376,6 +376,14 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="Default embedding model configuration for vector stores.",
     )
+    default_query_expansion_model: QualifiedModel | None = Field(
+        default=None,
+        description="Default LLM model for query expansion/rewriting in vector search.",
+    )
+    query_expansion_prompt: str = Field(
+        default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
+        description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
+    )


 class SafetyConfig(BaseModel):
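
For reference, a minimal sketch of opting into the new fields. This is not part of the commit; it assumes QualifiedModel is importable from the same module and exposes provider_id and model_id (the f"{provider_id}/{model_id}" usage in vector_store.py below implies those fields), and the model choice is hypothetical:

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig

config = VectorStoresConfig(
    # Hypothetical model choice; any provider_id/model_id pair registered
    # with the stack would work here.
    default_query_expansion_model=QualifiedModel(
        provider_id="ollama",
        model_id="llama3.2:3b",
    ),
    # query_expansion_prompt is omitted, so the default template above is
    # used; a custom template must keep the {query} placeholder.
)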

src/llama_stack/core/routers/vector_io.py

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,12 @@ async def query_chunks(
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
         provider = await self.routing_table.get_provider_impl(vector_store_id)
+
+        # Ensure params dict exists and add vector_stores_config for query rewriting
+        if params is None:
+            params = {}
+        params["vector_stores_config"] = self.vector_stores_config
+
         return await provider.query_chunks(vector_store_id, query, params)

     # OpenAI Vector Stores API endpoints
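
The net effect is that every provider-level query_chunks call now carries the stack's vector stores config next to the usual search options. A rough sketch of the params dict a provider sees after this injection; the three search keys and their defaults are taken from VectorStoreWithIndex.query_chunks below, and the values are illustrative:

from llama_stack.core.datatypes import VectorStoresConfig

vector_stores_config = VectorStoresConfig()  # stand-in for the router's config

# What `params` may look like by the time provider.query_chunks receives it.
params = {
    "max_chunks": 3,          # downstream default if absent
    "mode": "vector",         # illustrative search mode
    "score_threshold": 0.0,   # downstream default if absent
    "vector_stores_config": vector_stores_config,  # injected by the router
}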

src/llama_stack/providers/utils/memory/vector_store.py

Lines changed: 35 additions & 39 deletions
@@ -17,6 +17,7 @@
 from numpy.typing import NDArray
 from pydantic import BaseModel

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -267,6 +268,7 @@ class VectorStoreWithIndex:
     vector_store: VectorStore
     index: EmbeddingIndex
     inference_api: Api.inference
+    vector_stores_config: VectorStoresConfig | None = None

     async def insert_chunks(
         self,
@@ -301,6 +303,11 @@ async def query_chunks(
     ) -> QueryChunksResponse:
         if params is None:
             params = {}
+
+        # Extract configuration if provided by router
+        if "vector_stores_config" in params:
+            self.vector_stores_config = params["vector_stores_config"]
+
         k = params.get("max_chunks", 3)
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
@@ -350,51 +357,40 @@ async def _rewrite_query_for_search(self, query: str) -> str:
         :param query: The original user query
         :returns: The rewritten query optimized for vector search
         """
-        # Get available models and find a suitable chat model
+        # Check if query expansion model is configured
+        if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model:
+            raise ValueError("No default_query_expansion_model configured for query rewriting")
+
+        # Use the configured model
+        expansion_model = self.vector_stores_config.default_query_expansion_model
+        chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
+
+        # Validate that the model is available and is an LLM
         try:
             models_response = await self.inference_api.routing_table.list_models()
         except Exception as e:
-            raise RuntimeError(f"Failed to list available models for query rewriting: {e}") from e
-
-        chat_model = None
-        # Look for an LLM model (for chat completion)
-        # Prefer local or non-cloud providers to avoid credential issues
-        llm_models = [m for m in models_response.data if m.model_type == ModelType.llm]
-
-        # Filter out models that are known to be embedding models (misclassified as LLM)
-        embedding_model_patterns = ["minilm", "embed", "embedding", "nomic-embed"]
-        llm_models = [
-            m for m in llm_models if not any(pattern in m.identifier.lower() for pattern in embedding_model_patterns)
-        ]
-
-        # Priority order: ollama (local), then OpenAI, then others
-        provider_priority = ["ollama", "openai", "gemini", "bedrock"]
-
-        for provider in provider_priority:
-            for model in llm_models:
-                model_id = model.identifier.lower()
-                if provider == "ollama" and "ollama/" in model_id:
-                    chat_model = model.identifier
-                    break
-                elif provider in model_id:
-                    chat_model = model.identifier
-                    break
-            if chat_model:
+            raise RuntimeError(f"Failed to list available models for validation: {e}") from e
+
+        model_found = False
+        for model in models_response.data:
+            if model.identifier == chat_model:
+                if model.model_type != ModelType.llm:
+                    raise ValueError(
+                        f"Configured query expansion model '{chat_model}' is not an LLM model "
+                        f"(found type: {model.model_type}). Query rewriting requires an LLM model."
+                    )
+                model_found = True
                 break

-        # Fallback: use first available LLM model if no preferred provider found
-        if not chat_model and llm_models:
-            chat_model = llm_models[0].identifier
-
-        # If no suitable model found, raise an error
-        if not chat_model:
-            raise ValueError("No LLM model available for query rewriting")
-
-        rewrite_prompt = f"""Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:
-
-{query}
+        if not model_found:
+            available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm]
+            raise ValueError(
+                f"Configured query expansion model '{chat_model}' is not available. "
+                f"Available LLM models: {available_llm_models}"
+            )

-Improved query:"""
+        # Use the configured prompt (has a default value)
+        rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query)

         chat_request = OpenAIChatCompletionRequestWithExtraBody(
             model=chat_model,
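
To make the new control flow concrete, here is a small standalone sketch of the two string operations the rewritten method now performs: building the qualified model id and filling the prompt template. The template is the default from datatypes.py above; the provider/model values are hypothetical:

provider_id, model_id = "ollama", "llama3.2:3b"   # hypothetical QualifiedModel fields
chat_model = f"{provider_id}/{model_id}"          # -> "ollama/llama3.2:3b"

template = (
    "Expand this query with relevant synonyms and related terms. "
    "Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
)
rewrite_prompt = template.format(query="how do I reset my password")
# rewrite_prompt now embeds the original query between the instruction
# and the trailing "Improved query:" cue, ready for the chat request.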
