Skip to content

Commit ae27fe3

Browse files
refactor to only configure the model at build time
Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent edb22e9 commit ae27fe3

File tree

31 files changed

+279
-314
lines changed

31 files changed

+279
-314
lines changed

src/llama_stack/core/datatypes.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
StorageConfig,
1919
)
2020
from llama_stack.log import LoggingConfig
21+
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
2122
from llama_stack_api import (
2223
Api,
2324
Benchmark,
@@ -381,9 +382,17 @@ class VectorStoresConfig(BaseModel):
381382
description="Default LLM model for query expansion/rewriting in vector search.",
382383
)
383384
query_expansion_prompt: str = Field(
384-
default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
385+
default=DEFAULT_QUERY_EXPANSION_PROMPT,
385386
description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
386387
)
388+
query_expansion_max_tokens: int = Field(
389+
default=100,
390+
description="Maximum number of tokens for query expansion responses.",
391+
)
392+
query_expansion_temperature: float = Field(
393+
default=0.3,
394+
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
395+
)
387396

388397

389398
class SafetyConfig(BaseModel):

src/llama_stack/core/resolver.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -374,13 +374,6 @@ async def instantiate_provider(
374374
method = "get_adapter_impl"
375375
args = [config, deps]
376376

377-
# Add vector_stores_config for vector_io providers
378-
if (
379-
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
380-
and provider_spec.api == Api.vector_io
381-
):
382-
args.append(run_config.vector_stores)
383-
384377
elif isinstance(provider_spec, AutoRoutedProviderSpec):
385378
method = "get_auto_router_impl"
386379

@@ -401,11 +394,6 @@ async def instantiate_provider(
401394
args.append(policy)
402395
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
403396
args.append(run_config.telemetry.enabled)
404-
if (
405-
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
406-
and provider_spec.api == Api.vector_io
407-
):
408-
args.append(run_config.vector_stores)
409397

410398
fn = getattr(module, method)
411399
impl = await fn(*args)

src/llama_stack/core/routers/vector_io.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,19 +99,6 @@ async def query_chunks(
9999
) -> QueryChunksResponse:
100100
logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
101101
provider = await self.routing_table.get_provider_impl(vector_store_id)
102-
103-
# Ensure params dict exists and add vector_stores_config for query rewriting
104-
if params is None:
105-
params = {}
106-
107-
logger.debug(f"Router vector_stores_config: {self.vector_stores_config}")
108-
if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"):
109-
logger.debug(
110-
f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
111-
)
112-
113-
params["vector_stores_config"] = self.vector_stores_config
114-
115102
return await provider.query_chunks(vector_store_id, query, params)
116103

117104
# OpenAI Vector Stores API endpoints

src/llama_stack/core/stack.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -144,35 +144,62 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
144144
if vector_stores_config is None:
145145
return
146146

147+
# Validate default embedding model
147148
default_embedding_model = vector_stores_config.default_embedding_model
148-
if default_embedding_model is None:
149-
return
149+
if default_embedding_model is not None:
150+
provider_id = default_embedding_model.provider_id
151+
model_id = default_embedding_model.model_id
152+
default_model_id = f"{provider_id}/{model_id}"
150153

151-
provider_id = default_embedding_model.provider_id
152-
model_id = default_embedding_model.model_id
153-
default_model_id = f"{provider_id}/{model_id}"
154+
if Api.models not in impls:
155+
raise ValueError(
156+
f"Models API is not available but vector_stores config requires model '{default_model_id}'"
157+
)
154158

155-
if Api.models not in impls:
156-
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
159+
models_impl = impls[Api.models]
160+
response = await models_impl.list_models()
161+
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
157162

158-
models_impl = impls[Api.models]
159-
response = await models_impl.list_models()
160-
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
163+
default_model = models_list.get(default_model_id)
164+
if default_model is None:
165+
raise ValueError(
166+
f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
167+
)
161168

162-
default_model = models_list.get(default_model_id)
163-
if default_model is None:
164-
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
169+
embedding_dimension = default_model.metadata.get("embedding_dimension")
170+
if embedding_dimension is None:
171+
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
165172

166-
embedding_dimension = default_model.metadata.get("embedding_dimension")
167-
if embedding_dimension is None:
168-
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
173+
try:
174+
int(embedding_dimension)
175+
except ValueError as err:
176+
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
169177

170-
try:
171-
int(embedding_dimension)
172-
except ValueError as err:
173-
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
178+
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
179+
180+
# Validate default query expansion model
181+
default_query_expansion_model = vector_stores_config.default_query_expansion_model
182+
if default_query_expansion_model is not None:
183+
provider_id = default_query_expansion_model.provider_id
184+
model_id = default_query_expansion_model.model_id
185+
query_model_id = f"{provider_id}/{model_id}"
186+
187+
if Api.models not in impls:
188+
raise ValueError(
189+
f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
190+
)
191+
192+
models_impl = impls[Api.models]
193+
response = await models_impl.list_models()
194+
llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
174195

175-
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
196+
query_expansion_model = llm_models_list.get(query_model_id)
197+
if query_expansion_model is None:
198+
raise ValueError(
199+
f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
200+
)
201+
202+
logger.debug(f"Validated default query expansion model: {query_model_id}")
176203

177204

178205
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
@@ -437,6 +464,12 @@ async def initialize(self):
437464
await refresh_registry_once(impls)
438465
await validate_vector_stores_config(self.run_config.vector_stores, impls)
439466
await validate_safety_config(self.run_config.safety, impls)
467+
468+
# Set global query expansion configuration from stack config
469+
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
470+
471+
set_default_query_expansion_config(self.run_config.vector_stores)
472+
440473
self.impls = impls
441474

442475
def create_registry_refresh_task(self):

src/llama_stack/distributions/ci-tests/run-with-postgres-store.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,5 +296,7 @@ vector_stores:
296296
297297
298298
Improved query:'
299+
query_expansion_max_tokens: 100
300+
query_expansion_temperature: 0.3
299301
safety:
300302
default_shield_id: llama-guard

src/llama_stack/distributions/ci-tests/run.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,5 +287,7 @@ vector_stores:
287287
288288
289289
Improved query:'
290+
query_expansion_max_tokens: 100
291+
query_expansion_temperature: 0.3
290292
safety:
291293
default_shield_id: llama-guard

src/llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,5 +299,7 @@ vector_stores:
299299
300300
301301
Improved query:'
302+
query_expansion_max_tokens: 100
303+
query_expansion_temperature: 0.3
302304
safety:
303305
default_shield_id: llama-guard

src/llama_stack/distributions/starter-gpu/run.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,5 +290,7 @@ vector_stores:
290290
291291
292292
Improved query:'
293+
query_expansion_max_tokens: 100
294+
query_expansion_temperature: 0.3
293295
safety:
294296
default_shield_id: llama-guard

src/llama_stack/distributions/starter/run-with-postgres-store.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,5 +296,7 @@ vector_stores:
296296
297297
298298
Improved query:'
299+
query_expansion_max_tokens: 100
300+
query_expansion_temperature: 0.3
299301
safety:
300302
default_shield_id: llama-guard

src/llama_stack/distributions/starter/run.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,5 +287,7 @@ vector_stores:
287287
288288
289289
Improved query:'
290+
query_expansion_max_tokens: 100
291+
query_expansion_temperature: 0.3
290292
safety:
291293
default_shield_id: llama-guard

0 commit comments

Comments
 (0)