Skip to content

Commit 31e28b6

Browse files
renaming to query_rewrite, consolidating, and cleaning up validation
Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent d887f1f commit 31e28b6

File tree

12 files changed

+138
-180
lines changed

12 files changed

+138
-180
lines changed

src/llama_stack/core/datatypes.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,27 @@ class QualifiedModel(BaseModel):
366366
model_id: str
367367

368368

369+
class RewriteQueryParams(BaseModel):
370+
"""Parameters for query rewriting/expansion."""
371+
372+
model: QualifiedModel | None = Field(
373+
default=None,
374+
description="LLM model for query rewriting/expansion in vector search.",
375+
)
376+
prompt: str = Field(
377+
default=DEFAULT_QUERY_EXPANSION_PROMPT,
378+
description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
379+
)
380+
max_tokens: int = Field(
381+
default=100,
382+
description="Maximum number of tokens for query expansion responses.",
383+
)
384+
temperature: float = Field(
385+
default=0.3,
386+
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
387+
)
388+
389+
369390
class VectorStoresConfig(BaseModel):
370391
"""Configuration for vector stores in the stack."""
371392

@@ -377,21 +398,9 @@ class VectorStoresConfig(BaseModel):
377398
default=None,
378399
description="Default embedding model configuration for vector stores.",
379400
)
380-
default_query_expansion_model: QualifiedModel | None = Field(
401+
rewrite_query_params: RewriteQueryParams | None = Field(
381402
default=None,
382-
description="Default LLM model for query expansion/rewriting in vector search.",
383-
)
384-
query_expansion_prompt: str = Field(
385-
default=DEFAULT_QUERY_EXPANSION_PROMPT,
386-
description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
387-
)
388-
query_expansion_max_tokens: int = Field(
389-
default=100,
390-
description="Maximum number of tokens for query expansion responses.",
391-
)
392-
query_expansion_temperature: float = Field(
393-
default=0.3,
394-
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
403+
description="Parameters for query rewriting/expansion. None disables query rewriting.",
395404
)
396405

397406

src/llama_stack/core/stack.py

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import yaml
1515

1616
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
17-
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
17+
from llama_stack.core.datatypes import Provider, QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig
1818
from llama_stack.core.distribution import get_provider_registry
1919
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
2020
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -145,61 +145,67 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
145145
return
146146

147147
# Validate default embedding model
148-
default_embedding_model = vector_stores_config.default_embedding_model
149-
if default_embedding_model is not None:
150-
provider_id = default_embedding_model.provider_id
151-
model_id = default_embedding_model.model_id
152-
default_model_id = f"{provider_id}/{model_id}"
148+
if vector_stores_config.default_embedding_model is not None:
149+
await _validate_embedding_model(vector_stores_config.default_embedding_model, impls)
153150

154-
if Api.models not in impls:
155-
raise ValueError(
156-
f"Models API is not available but vector_stores config requires model '{default_model_id}'"
157-
)
151+
# Validate default rewrite query model
152+
if vector_stores_config.rewrite_query_params and vector_stores_config.rewrite_query_params.model:
153+
await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)
158154

159-
models_impl = impls[Api.models]
160-
response = await models_impl.list_models()
161-
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
162155

163-
default_model = models_list.get(default_model_id)
164-
if default_model is None:
165-
raise ValueError(
166-
f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
167-
)
156+
async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
157+
"""Validate that an embedding model exists and has required metadata."""
158+
provider_id = embedding_model.provider_id
159+
model_id = embedding_model.model_id
160+
model_identifier = f"{provider_id}/{model_id}"
168161

169-
embedding_dimension = default_model.metadata.get("embedding_dimension")
170-
if embedding_dimension is None:
171-
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
162+
if Api.models not in impls:
163+
raise ValueError(f"Models API is not available but vector_stores config requires model '{model_identifier}'")
172164

173-
try:
174-
int(embedding_dimension)
175-
except ValueError as err:
176-
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
165+
models_impl = impls[Api.models]
166+
response = await models_impl.list_models()
167+
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
177168

178-
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
169+
model = models_list.get(model_identifier)
170+
if model is None:
171+
raise ValueError(
172+
f"Embedding model '{model_identifier}' not found. Available embedding models: {list(models_list.keys())}"
173+
)
179174

180-
# Validate default query expansion model
181-
default_query_expansion_model = vector_stores_config.default_query_expansion_model
182-
if default_query_expansion_model is not None:
183-
provider_id = default_query_expansion_model.provider_id
184-
model_id = default_query_expansion_model.model_id
185-
query_model_id = f"{provider_id}/{model_id}"
175+
embedding_dimension = model.metadata.get("embedding_dimension")
176+
if embedding_dimension is None:
177+
raise ValueError(f"Embedding model '{model_identifier}' is missing 'embedding_dimension' in metadata")
186178

187-
if Api.models not in impls:
188-
raise ValueError(
189-
f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
190-
)
179+
try:
180+
int(embedding_dimension)
181+
except ValueError as err:
182+
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
191183

192-
models_impl = impls[Api.models]
193-
response = await models_impl.list_models()
194-
llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
184+
logger.debug(f"Validated embedding model: {model_identifier} (dimension: {embedding_dimension})")
195185

196-
query_expansion_model = llm_models_list.get(query_model_id)
197-
if query_expansion_model is None:
198-
raise ValueError(
199-
f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
200-
)
201186

202-
logger.debug(f"Validated default query expansion model: {query_model_id}")
187+
async def _validate_rewrite_query_model(rewrite_query_model: QualifiedModel, impls: dict[Api, Any]) -> None:
188+
"""Validate that a rewrite query model exists and is accessible."""
189+
provider_id = rewrite_query_model.provider_id
190+
model_id = rewrite_query_model.model_id
191+
model_identifier = f"{provider_id}/{model_id}"
192+
193+
if Api.models not in impls:
194+
raise ValueError(
195+
f"Models API is not available but vector_stores config requires rewrite query model '{model_identifier}'"
196+
)
197+
198+
models_impl = impls[Api.models]
199+
response = await models_impl.list_models()
200+
llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
201+
202+
model = llm_models_list.get(model_identifier)
203+
if model is None:
204+
raise ValueError(
205+
f"Rewrite query model '{model_identifier}' not found. Available LLM models: {list(llm_models_list.keys())}"
206+
)
207+
208+
logger.debug(f"Validated rewrite query model: {model_identifier}")
203209

204210

205211
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
@@ -466,9 +472,9 @@ async def initialize(self):
466472
await validate_safety_config(self.run_config.safety, impls)
467473

468474
# Set global query expansion configuration from stack config
469-
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
475+
from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
470476

471-
set_default_query_expansion_config(self.run_config.vector_stores)
477+
set_default_rewrite_query_config(self.run_config.vector_stores)
472478

473479
self.impls = impls
474480

src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,5 @@ vector_stores:
288288
default_embedding_model:
289289
provider_id: sentence-transformers
290290
model_id: nomic-ai/nomic-embed-text-v1.5
291-
query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
292-
Return only the improved query, no explanations:
293-
294-
295-
{query}
296-
297-
298-
Improved query:'
299-
query_expansion_max_tokens: 100
300-
query_expansion_temperature: 0.3
301291
safety:
302292
default_shield_id: llama-guard

src/llama_stack/distributions/ci-tests/run.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,5 @@ vector_stores:
279279
default_embedding_model:
280280
provider_id: sentence-transformers
281281
model_id: nomic-ai/nomic-embed-text-v1.5
282-
query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
283-
Return only the improved query, no explanations:
284-
285-
286-
{query}
287-
288-
289-
Improved query:'
290-
query_expansion_max_tokens: 100
291-
query_expansion_temperature: 0.3
292282
safety:
293283
default_shield_id: llama-guard

src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -291,15 +291,5 @@ vector_stores:
291291
default_embedding_model:
292292
provider_id: sentence-transformers
293293
model_id: nomic-ai/nomic-embed-text-v1.5
294-
query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
295-
Return only the improved query, no explanations:
296-
297-
298-
{query}
299-
300-
301-
Improved query:'
302-
query_expansion_max_tokens: 100
303-
query_expansion_temperature: 0.3
304294
safety:
305295
default_shield_id: llama-guard

src/llama_stack/distributions/starter-gpu/run.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -282,15 +282,5 @@ vector_stores:
282282
default_embedding_model:
283283
provider_id: sentence-transformers
284284
model_id: nomic-ai/nomic-embed-text-v1.5
285-
query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
286-
Return only the improved query, no explanations:
287-
288-
289-
{query}
290-
291-
292-
Improved query:'
293-
query_expansion_max_tokens: 100
294-
query_expansion_temperature: 0.3
295285
safety:
296286
default_shield_id: llama-guard

src/llama_stack/distributions/starter/run-with-postgres-store.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,5 @@ vector_stores:
288288
default_embedding_model:
289289
provider_id: sentence-transformers
290290
model_id: nomic-ai/nomic-embed-text-v1.5
291-
query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
292-
Return only the improved query, no explanations:
293-
294-
295-
{query}
296-
297-
298-
Improved query:'
299-
query_expansion_max_tokens: 100
300-
query_expansion_temperature: 0.3
301291
safety:
302292
default_shield_id: llama-guard

src/llama_stack/distributions/starter/run.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,5 @@ vector_stores:
279279
default_embedding_model:
280280
provider_id: sentence-transformers
281281
model_id: nomic-ai/nomic-embed-text-v1.5
282-
query_expansion_prompt: 'Expand this query with relevant synonyms and related terms.
283-
Return only the improved query, no explanations:
284-
285-
286-
{query}
287-
288-
289-
Improved query:'
290-
query_expansion_max_tokens: 100
291-
query_expansion_temperature: 0.3
292282
safety:
293283
default_shield_id: llama-guard

src/llama_stack/providers/utils/memory/query_expansion_config.py

Lines changed: 0 additions & 37 deletions
This file was deleted.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
8+
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
9+
10+
# Global configuration for query rewriting - set during stack startup
11+
_DEFAULT_REWRITE_QUERY_MODEL: QualifiedModel | None = None
12+
_DEFAULT_REWRITE_QUERY_MAX_TOKENS: int = 100
13+
_DEFAULT_REWRITE_QUERY_TEMPERATURE: float = 0.3
14+
_REWRITE_QUERY_PROMPT_OVERRIDE: str | None = None
15+
16+
17+
def set_default_rewrite_query_config(vector_stores_config: VectorStoresConfig | None):
18+
"""Set the global default query rewriting configuration from stack config."""
19+
global \
20+
_DEFAULT_REWRITE_QUERY_MODEL, \
21+
_REWRITE_QUERY_PROMPT_OVERRIDE, \
22+
_DEFAULT_REWRITE_QUERY_MAX_TOKENS, \
23+
_DEFAULT_REWRITE_QUERY_TEMPERATURE
24+
if vector_stores_config and vector_stores_config.rewrite_query_params:
25+
params = vector_stores_config.rewrite_query_params
26+
_DEFAULT_REWRITE_QUERY_MODEL = params.model
27+
# Only set override if user provided a custom prompt different from default
28+
if params.prompt != DEFAULT_QUERY_EXPANSION_PROMPT:
29+
_REWRITE_QUERY_PROMPT_OVERRIDE = params.prompt
30+
else:
31+
_REWRITE_QUERY_PROMPT_OVERRIDE = None
32+
_DEFAULT_REWRITE_QUERY_MAX_TOKENS = params.max_tokens
33+
_DEFAULT_REWRITE_QUERY_TEMPERATURE = params.temperature
34+
else:
35+
_DEFAULT_REWRITE_QUERY_MODEL = None
36+
_REWRITE_QUERY_PROMPT_OVERRIDE = None
37+
_DEFAULT_REWRITE_QUERY_MAX_TOKENS = 100
38+
_DEFAULT_REWRITE_QUERY_TEMPERATURE = 0.3

0 commit comments

Comments
 (0)