Skip to content

Commit 9638a51

Browse files
renaming to query_rewrite, consolidating, and cleaning up validation
Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent ae27fe3 commit 9638a51

File tree

6 files changed: +138 −120 lines changed

src/llama_stack/core/datatypes.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,27 @@ class QualifiedModel(BaseModel):
366366
model_id: str
367367

368368

369+
class RewriteQueryParams(BaseModel):
370+
"""Parameters for query rewriting/expansion."""
371+
372+
model: QualifiedModel | None = Field(
373+
default=None,
374+
description="LLM model for query rewriting/expansion in vector search.",
375+
)
376+
prompt: str = Field(
377+
default=DEFAULT_QUERY_EXPANSION_PROMPT,
378+
description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
379+
)
380+
max_tokens: int = Field(
381+
default=100,
382+
description="Maximum number of tokens for query expansion responses.",
383+
)
384+
temperature: float = Field(
385+
default=0.3,
386+
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
387+
)
388+
389+
369390
class VectorStoresConfig(BaseModel):
370391
"""Configuration for vector stores in the stack."""
371392

@@ -377,21 +398,9 @@ class VectorStoresConfig(BaseModel):
377398
default=None,
378399
description="Default embedding model configuration for vector stores.",
379400
)
380-
default_query_expansion_model: QualifiedModel | None = Field(
401+
rewrite_query_params: RewriteQueryParams | None = Field(
381402
default=None,
382-
description="Default LLM model for query expansion/rewriting in vector search.",
383-
)
384-
query_expansion_prompt: str = Field(
385-
default=DEFAULT_QUERY_EXPANSION_PROMPT,
386-
description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
387-
)
388-
query_expansion_max_tokens: int = Field(
389-
default=100,
390-
description="Maximum number of tokens for query expansion responses.",
391-
)
392-
query_expansion_temperature: float = Field(
393-
default=0.3,
394-
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
403+
description="Parameters for query rewriting/expansion. None disables query rewriting.",
395404
)
396405

397406

src/llama_stack/core/stack.py

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import yaml
1515

1616
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
17-
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
17+
from llama_stack.core.datatypes import Provider, QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig
1818
from llama_stack.core.distribution import get_provider_registry
1919
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
2020
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -145,61 +145,67 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
145145
return
146146

147147
# Validate default embedding model
148-
default_embedding_model = vector_stores_config.default_embedding_model
149-
if default_embedding_model is not None:
150-
provider_id = default_embedding_model.provider_id
151-
model_id = default_embedding_model.model_id
152-
default_model_id = f"{provider_id}/{model_id}"
148+
if vector_stores_config.default_embedding_model is not None:
149+
await _validate_embedding_model(vector_stores_config.default_embedding_model, impls)
153150

154-
if Api.models not in impls:
155-
raise ValueError(
156-
f"Models API is not available but vector_stores config requires model '{default_model_id}'"
157-
)
151+
# Validate default rewrite query model
152+
if vector_stores_config.rewrite_query_params and vector_stores_config.rewrite_query_params.model:
153+
await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)
158154

159-
models_impl = impls[Api.models]
160-
response = await models_impl.list_models()
161-
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
162155

163-
default_model = models_list.get(default_model_id)
164-
if default_model is None:
165-
raise ValueError(
166-
f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
167-
)
156+
async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
157+
"""Validate that an embedding model exists and has required metadata."""
158+
provider_id = embedding_model.provider_id
159+
model_id = embedding_model.model_id
160+
model_identifier = f"{provider_id}/{model_id}"
168161

169-
embedding_dimension = default_model.metadata.get("embedding_dimension")
170-
if embedding_dimension is None:
171-
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
162+
if Api.models not in impls:
163+
raise ValueError(f"Models API is not available but vector_stores config requires model '{model_identifier}'")
172164

173-
try:
174-
int(embedding_dimension)
175-
except ValueError as err:
176-
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
165+
models_impl = impls[Api.models]
166+
response = await models_impl.list_models()
167+
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
177168

178-
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
169+
model = models_list.get(model_identifier)
170+
if model is None:
171+
raise ValueError(
172+
f"Embedding model '{model_identifier}' not found. Available embedding models: {list(models_list.keys())}"
173+
)
179174

180-
# Validate default query expansion model
181-
default_query_expansion_model = vector_stores_config.default_query_expansion_model
182-
if default_query_expansion_model is not None:
183-
provider_id = default_query_expansion_model.provider_id
184-
model_id = default_query_expansion_model.model_id
185-
query_model_id = f"{provider_id}/{model_id}"
175+
embedding_dimension = model.metadata.get("embedding_dimension")
176+
if embedding_dimension is None:
177+
raise ValueError(f"Embedding model '{model_identifier}' is missing 'embedding_dimension' in metadata")
186178

187-
if Api.models not in impls:
188-
raise ValueError(
189-
f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
190-
)
179+
try:
180+
int(embedding_dimension)
181+
except ValueError as err:
182+
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
191183

192-
models_impl = impls[Api.models]
193-
response = await models_impl.list_models()
194-
llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
184+
logger.debug(f"Validated embedding model: {model_identifier} (dimension: {embedding_dimension})")
195185

196-
query_expansion_model = llm_models_list.get(query_model_id)
197-
if query_expansion_model is None:
198-
raise ValueError(
199-
f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
200-
)
201186

202-
logger.debug(f"Validated default query expansion model: {query_model_id}")
187+
async def _validate_rewrite_query_model(rewrite_query_model: QualifiedModel, impls: dict[Api, Any]) -> None:
188+
"""Validate that a rewrite query model exists and is accessible."""
189+
provider_id = rewrite_query_model.provider_id
190+
model_id = rewrite_query_model.model_id
191+
model_identifier = f"{provider_id}/{model_id}"
192+
193+
if Api.models not in impls:
194+
raise ValueError(
195+
f"Models API is not available but vector_stores config requires rewrite query model '{model_identifier}'"
196+
)
197+
198+
models_impl = impls[Api.models]
199+
response = await models_impl.list_models()
200+
llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
201+
202+
model = llm_models_list.get(model_identifier)
203+
if model is None:
204+
raise ValueError(
205+
f"Rewrite query model '{model_identifier}' not found. Available LLM models: {list(llm_models_list.keys())}"
206+
)
207+
208+
logger.debug(f"Validated rewrite query model: {model_identifier}")
203209

204210

205211
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
@@ -466,9 +472,9 @@ async def initialize(self):
466472
await validate_safety_config(self.run_config.safety, impls)
467473

468474
# Set global query expansion configuration from stack config
469-
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
475+
from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
470476

471-
set_default_query_expansion_config(self.run_config.vector_stores)
477+
set_default_rewrite_query_config(self.run_config.vector_stores)
472478

473479
self.impls = impls
474480

src/llama_stack/providers/utils/memory/query_expansion_config.py

Lines changed: 0 additions & 37 deletions
This file was deleted.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
8+
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
9+
10+
# Global configuration for query rewriting - set during stack startup
11+
_DEFAULT_REWRITE_QUERY_MODEL: QualifiedModel | None = None
12+
_DEFAULT_REWRITE_QUERY_MAX_TOKENS: int = 100
13+
_DEFAULT_REWRITE_QUERY_TEMPERATURE: float = 0.3
14+
_REWRITE_QUERY_PROMPT_OVERRIDE: str | None = None
15+
16+
17+
def set_default_rewrite_query_config(vector_stores_config: VectorStoresConfig | None):
18+
"""Set the global default query rewriting configuration from stack config."""
19+
global \
20+
_DEFAULT_REWRITE_QUERY_MODEL, \
21+
_REWRITE_QUERY_PROMPT_OVERRIDE, \
22+
_DEFAULT_REWRITE_QUERY_MAX_TOKENS, \
23+
_DEFAULT_REWRITE_QUERY_TEMPERATURE
24+
if vector_stores_config and vector_stores_config.rewrite_query_params:
25+
params = vector_stores_config.rewrite_query_params
26+
_DEFAULT_REWRITE_QUERY_MODEL = params.model
27+
# Only set override if user provided a custom prompt different from default
28+
if params.prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
29+
_REWRITE_QUERY_PROMPT_OVERRIDE = params.prompt
30+
else:
31+
_REWRITE_QUERY_PROMPT_OVERRIDE = None
32+
_DEFAULT_REWRITE_QUERY_MAX_TOKENS = params.max_tokens
33+
_DEFAULT_REWRITE_QUERY_TEMPERATURE = params.temperature
34+
else:
35+
_DEFAULT_REWRITE_QUERY_MODEL = None
36+
_REWRITE_QUERY_PROMPT_OVERRIDE = None
37+
_DEFAULT_REWRITE_QUERY_MAX_TOKENS = 100
38+
_DEFAULT_REWRITE_QUERY_TEMPERATURE = 0.3

src/llama_stack/providers/utils/memory/vector_store.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
log = get_logger(name=__name__, category="providers::utils")
4040

41-
from llama_stack.providers.utils.memory import query_expansion_config
41+
from llama_stack.providers.utils.memory import rewrite_query_config
4242
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
4343

4444

@@ -295,20 +295,20 @@ async def insert_chunks(
295295

296296
async def _rewrite_query_for_file_search(self, query: str) -> str:
297297
"""Rewrite a search query using the globally configured LLM model for better retrieval results."""
298-
if not query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL:
299-
log.debug("No default query expansion model configured, using original query")
298+
if not rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL:
299+
log.debug("No default query rewriting model configured, using original query")
300300
return query
301301

302-
model_id = f"{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.provider_id}/{query_expansion_config._DEFAULT_QUERY_EXPANSION_MODEL.model_id}"
302+
model_id = f"{rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL.provider_id}/{rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL.model_id}"
303303

304304
# Use custom prompt from config if provided, otherwise use built-in default
305305
# Users only need to configure the model - prompt is automatic with optional override
306-
if query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE:
306+
if rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE:
307307
# Custom prompt from config - format if it contains {query} placeholder
308308
prompt = (
309-
query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE.format(query=query)
310-
if "{query}" in query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
311-
else query_expansion_config._QUERY_EXPANSION_PROMPT_OVERRIDE
309+
rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE.format(query=query)
310+
if "{query}" in rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE
311+
else rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE
312312
)
313313
else:
314314
# Use built-in default prompt and format with query
@@ -317,8 +317,8 @@ async def _rewrite_query_for_file_search(self, query: str) -> str:
317317
request = OpenAIChatCompletionRequestWithExtraBody(
318318
model=model_id,
319319
messages=[{"role": "user", "content": prompt}],
320-
max_tokens=query_expansion_config._DEFAULT_QUERY_EXPANSION_MAX_TOKENS,
321-
temperature=query_expansion_config._DEFAULT_QUERY_EXPANSION_TEMPERATURE,
320+
max_tokens=rewrite_query_config._DEFAULT_REWRITE_QUERY_MAX_TOKENS,
321+
temperature=rewrite_query_config._DEFAULT_REWRITE_QUERY_TEMPERATURE,
322322
)
323323

324324
response = await self.inference_api.openai_chat_completion(request)

tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,9 +1236,9 @@ async def test_query_expansion_functionality(vector_io_adapter):
12361236
"""Test query expansion with simplified global configuration approach."""
12371237
from unittest.mock import MagicMock
12381238

1239-
from llama_stack.core.datatypes import QualifiedModel
1239+
from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
12401240
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
1241-
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
1241+
from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
12421242
from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
12431243
from llama_stack_api import QueryChunksResponse
12441244

@@ -1266,13 +1266,12 @@ async def test_query_expansion_functionality(vector_io_adapter):
12661266

12671267
# Test 1: Query expansion with default prompt (no custom prompt configured)
12681268
mock_vector_stores_config = MagicMock()
1269-
mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
1270-
mock_vector_stores_config.query_expansion_prompt = None # Use built-in default prompt
1271-
mock_vector_stores_config.query_expansion_max_tokens = 100 # Default value
1272-
mock_vector_stores_config.query_expansion_temperature = 0.3 # Default value
1269+
mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
1270+
model=QualifiedModel(provider_id="test", model_id="llama"), max_tokens=100, temperature=0.3
1271+
)
12731272

12741273
# Set global config
1275-
set_default_query_expansion_config(mock_vector_stores_config)
1274+
set_default_rewrite_query_config(mock_vector_stores_config)
12761275

12771276
# Mock chat completion for query rewriting
12781277
mock_inference_api.openai_chat_completion = AsyncMock(
@@ -1305,10 +1304,13 @@ async def test_query_expansion_functionality(vector_io_adapter):
13051304
mock_inference_api.reset_mock()
13061305
mock_index.reset_mock()
13071306

1308-
mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
1309-
mock_vector_stores_config.query_expansion_max_tokens = 150
1310-
mock_vector_stores_config.query_expansion_temperature = 0.7
1311-
set_default_query_expansion_config(mock_vector_stores_config)
1307+
mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
1308+
model=QualifiedModel(provider_id="test", model_id="llama"),
1309+
prompt="Custom prompt for rewriting: {query}",
1310+
max_tokens=150,
1311+
temperature=0.7,
1312+
)
1313+
set_default_rewrite_query_config(mock_vector_stores_config)
13121314

13131315
result = await vector_store_with_index.query_chunks("test query", params)
13141316

@@ -1328,7 +1330,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
13281330
mock_index.reset_mock()
13291331

13301332
# Clear global config
1331-
set_default_query_expansion_config(None)
1333+
set_default_rewrite_query_config(None)
13321334

13331335
params = {"rewrite_query": True, "max_chunks": 5}
13341336
result2 = await vector_store_with_index.query_chunks("test query", params)

0 commit comments

Comments (0)