     RAGDocument,
     VectorStore,
 )
+from llama_stack_api.inference import (
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
+from llama_stack_api.models import ModelType
 
 log = get_logger(name=__name__, category="providers::utils")
 
@@ -318,6 +323,11 @@ async def query_chunks(
             reranker_params = {"impact_factor": k_value}
 
         query_string = interleaved_content_as_str(query)
+
+        # Apply query rewriting if enabled
+        if params.get("rewrite_query", False):
+            query_string = await self._rewrite_query_for_search(query_string)
+
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)
@@ -333,3 +343,78 @@ async def query_chunks(
             )
         else:
             return await self.index.query_vector(query_vector, k, score_threshold)
+
+    async def _rewrite_query_for_search(self, query: str) -> str:
+        """Rewrite the user query to improve vector search performance.
+
+        :param query: The original user query
+        :returns: The rewritten query optimized for vector search
+        """
+        # Get available models and find a suitable chat model
+        try:
+            models_response = await self.inference_api.routing_table.list_models()
+        except Exception as e:
+            raise RuntimeError(f"Failed to list available models for query rewriting: {e}") from e
+
+        chat_model = None
+        # Look for an LLM model (for chat completion)
+        # Prefer local or non-cloud providers to avoid credential issues
+        llm_models = [m for m in models_response.data if m.model_type == ModelType.llm]
+
+        # Filter out models that are known to be embedding models (misclassified as LLM)
+        embedding_model_patterns = ["minilm", "embed", "embedding", "nomic-embed"]
+        llm_models = [
+            m for m in llm_models if not any(pattern in m.identifier.lower() for pattern in embedding_model_patterns)
+        ]
+
+        # Priority order: ollama (local), then OpenAI, then others
+        provider_priority = ["ollama", "openai", "gemini", "bedrock"]
+
+        for provider in provider_priority:
+            for model in llm_models:
+                model_id = model.identifier.lower()
+                if provider == "ollama" and "ollama/" in model_id:
+                    chat_model = model.identifier
+                    break
+                elif provider != "ollama" and provider in model_id:
+                    chat_model = model.identifier
+                    break
+            if chat_model:
+                break
+
+        # Fallback: use the first available LLM model if no preferred provider matched
+        if not chat_model and llm_models:
+            chat_model = llm_models[0].identifier
+
+        # If no suitable model was found, raise an error
+        if not chat_model:
+            raise ValueError("No LLM model available for query rewriting")
+
+        rewrite_prompt = f"""Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:
+
+{query}
+
+Improved query:"""
+
+        chat_request = OpenAIChatCompletionRequestWithExtraBody(
+            model=chat_model,
+            messages=[
+                OpenAIUserMessageParam(
+                    role="user",
+                    content=rewrite_prompt,
+                )
+            ],
+            max_tokens=100,
+        )
+
+        try:
+            response = await self.inference_api.openai_chat_completion(chat_request)
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
+
+        if response.choices and response.choices[0].message.content:
+            rewritten_query = response.choices[0].message.content.strip()
+            log.info(f"Query rewritten: '{query}' → '{rewritten_query}'")
+            return rewritten_query
+        else:
+            raise RuntimeError("No response received from LLM model for query rewriting")
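
Note for reviewers: callers opt in by passing rewrite_query=True in the params dict given to query_chunks. Below is a minimal, self-contained sketch of the chat-model selection heuristic in the patch, not code from the patch itself: the identifiers in the example are hypothetical, and the real method reads Model objects from the routing table via list_models() rather than taking a plain list of strings.

# Standalone sketch of the provider-priority selection above (assumptions noted inline).
EMBEDDING_MODEL_PATTERNS = ["minilm", "embed", "embedding", "nomic-embed"]
PROVIDER_PRIORITY = ["ollama", "openai", "gemini", "bedrock"]


def pick_chat_model(identifiers: list[str]) -> str | None:
    # Drop ids that look like embedding models even if tagged as LLMs
    candidates = [
        i for i in identifiers
        if not any(p in i.lower() for p in EMBEDDING_MODEL_PATTERNS)
    ]
    # Walk providers in priority order; ollama models are matched by the "ollama/" prefix
    for provider in PROVIDER_PRIORITY:
        for identifier in candidates:
            lowered = identifier.lower()
            if provider == "ollama" and "ollama/" in lowered:
                return identifier
            if provider != "ollama" and provider in lowered:
                return identifier
    # Fallback: first surviving candidate, or None if nothing qualifies
    return candidates[0] if candidates else None


# Hypothetical identifiers, for illustration only:
print(pick_chat_model(["openai/gpt-4o-mini", "ollama/all-minilm", "ollama/llama3.2:3b"]))
# -> "ollama/llama3.2:3b" (the minilm id is filtered out as an embedding model)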