diff --git a/graphrag/cli/main.py b/graphrag/cli/main.py
index 610427871f..91031370d4 100644
--- a/graphrag/cli/main.py
+++ b/graphrag/cli/main.py
@@ -375,6 +375,14 @@ def _prompt_tune_cli(
 def _query_cli(
     method: Annotated[SearchMethod, typer.Option(help="The query algorithm to use.")],
     query: Annotated[str, typer.Option(help="The query to execute.")],
+    raw_chunks: Annotated[
+        bool,
+        typer.Option(
+            "--raw-chunks",
+            help="Show raw chunks retrieved from vector store before final response.",
+            is_flag=True
+        ),
+    ] = False,
     config: Annotated[
         Path | None,
         typer.Option(
@@ -451,6 +459,7 @@ def _query_cli(
                 response_type=response_type,
                 streaming=streaming,
                 query=query,
+                raw_chunks=raw_chunks,  # Added for local search
             )
         case SearchMethod.GLOBAL:
             run_global_search(
@@ -462,6 +471,7 @@ def _query_cli(
                 response_type=response_type,
                 streaming=streaming,
                 query=query,
+                raw_chunks=raw_chunks,  # Added for global search
             )
         case SearchMethod.DRIFT:
             run_drift_search(
@@ -472,6 +482,7 @@ def _query_cli(
                 streaming=streaming,
                 response_type=response_type,
                 query=query,
+                raw_chunks=raw_chunks,  # Added for drift search
             )
         case SearchMethod.BASIC:
             run_basic_search(
diff --git a/graphrag/cli/query.py b/graphrag/cli/query.py
index 937ca69bbf..86038f198e 100644
--- a/graphrag/cli/query.py
+++ b/graphrag/cli/query.py
@@ -21,7 +21,168 @@
 logger = PrintProgressLogger("")
 
-
+class RawChunksCallback(NoopQueryCallbacks):
+    def on_context(self, context: Any) -> None:
+        try:
+            # For DRIFT search's three-step process
+            if isinstance(context, dict) and 'initial_context' in context:
+                print("\n=== DRIFT SEARCH RAW CHUNKS ===")
+
+                # Step 1: Primer Search
+                print("\nSTEP 1 - PRIMER SEARCH:")
+                if hasattr(context['initial_context'], 'context_chunks'):
+                    chunks = context['initial_context'].context_chunks
+                    if isinstance(chunks, dict) and 'reports' in chunks:
+                        for i, report in enumerate(chunks['reports'], 1):
+                            print(f"\nReport {i}:")
+                            print(f"Title: {report.get('title', 'N/A')}")
+                            print(f"Text: {report.get('text', 'N/A')}")
+                    else:
+                        print(chunks)
+
+                # Step 2: Follow-up Searches
+                print("\nSTEP 2 - FOLLOW-UP SEARCHES:")
+                if 'followup_contexts' in context:
+                    for i, followup in enumerate(context['followup_contexts'], 1):
+                        print(f"\nFollow-up {i}:")
+                        if hasattr(followup, 'query'):
+                            print(f"Question: {followup.query}")
+                        if hasattr(followup, 'context_chunks'):
+                            print("Retrieved Context:")
+                            if isinstance(followup.context_chunks, dict):
+                                for key, value in followup.context_chunks.items():
+                                    print(f"\n{key}: {value}")
+                            else:
+                                print(followup.context_chunks)
+
+                # Step 3: Final Synthesis
+                print("\nSTEP 3 - FINAL SYNTHESIS:")
+                if 'final_context' in context and hasattr(context['final_context'], 'context_chunks'):
+                    final_chunks = context['final_context'].context_chunks
+                    if isinstance(final_chunks, dict):
+                        for key, value in final_chunks.items():
+                            print(f"\n{key}: {value}")
+                    else:
+                        print(final_chunks)
+
+                print("\n=== END DRIFT SEARCH RAW CHUNKS ===\n")
+
+
+            # For Global and Local searches
+            else:
+                print("\n=== RAW CHUNKS FROM VECTOR STORE ===")
+
+                # First try to access context_records if available
+                if hasattr(context, 'context_records'):
+                    records = context.context_records
+                    if isinstance(records, dict):
+                        # Handle reports
+                        if 'reports' in records:
+                            print("\nReports:")
+                            for i, report in enumerate(records['reports'], 1):
+                                print(f"\nReport {i}:")
+                                if isinstance(report, dict):
+                                    if 'title' in report:
+                                        print(f"Title: {report['title']}")
+                                    if 'text' in report:
+                                        print(f"Text: {report['text']}")
+                                    if 'content' in report:
+                                        print(f"Content: {report['content']}")
+
+                        # Handle text units
+                        if 'text_units' in records:
+                            print("\nText Units:")
+                            for i, unit in enumerate(records['text_units'], 1):
+                                print(f"\nText Unit {i}:")
+                                if isinstance(unit, dict):
+                                    if 'text' in unit:
+                                        print(f"Text: {unit['text']}")
+                                    if 'source' in unit:
+                                        print(f"Source: {unit['source']}")
+
+                        # Handle relationships
+                        if 'relationships' in records:
+                            print("\nRelationships:")
+                            for i, rel in enumerate(records['relationships'], 1):
+                                print(f"\nRelationship {i}: {rel}")
+
+                # Fallback to direct attributes if context_records not available
+                else:
+                    # Handle reports
+                    if hasattr(context, 'reports'):
+                        print("\nReports:")
+                        for i, report in enumerate(context.reports, 1):
+                            print(f"\nReport {i}:")
+                            if isinstance(report, dict):
+                                if 'title' in report:
+                                    print(f"Title: {report['title']}")
+                                if 'text' in report:
+                                    print(f"Text: {report['text']}")
+                                if 'content' in report:
+                                    print(f"Content: {report['content']}")
+
+                    # Handle text units
+                    if hasattr(context, 'text_units'):
+                        print("\nText Units:")
+                        for i, unit in enumerate(context.text_units, 1):
+                            print(f"\nText Unit {i}:")
+                            if isinstance(unit, dict):
+                                if 'text' in unit:
+                                    print(f"Text: {unit['text']}")
+                                if 'source' in unit:
+                                    print(f"Source: {unit['source']}")
+
+                    # Handle relationships
+                    if hasattr(context, 'relationships'):
+                        print("\nRelationships:")
+                        for i, rel in enumerate(context.relationships, 1):
+                            print(f"\nRelationship {i}: {rel}")
+
+                # Final fallback to context_chunks
+                if not (hasattr(context, 'context_records') or
+                        hasattr(context, 'reports') or
+                        hasattr(context, 'text_units') or
+                        hasattr(context, 'relationships')):
+                    if hasattr(context, 'context_chunks'):
+                        print("\nContext Chunks:")
+                        chunks = context.context_chunks
+                        if isinstance(chunks, dict):
+                            for key, value in chunks.items():
+                                print(f"\n{key}:")
+                                if isinstance(value, list):
+                                    for i, item in enumerate(value, 1):
+                                        if isinstance(item, dict):
+                                            print(f"\nItem {i}:")
+                                            for k, v in item.items():
+                                                print(f"{k}: {v}")
+                                        else:
+                                            print(f"\nItem {i}: {item}")
+                                else:
+                                    print(value)
+                        elif isinstance(chunks, list):
+                            for i, chunk in enumerate(chunks, 1):
+                                if isinstance(chunk, dict):
+                                    print(f"\nChunk {i}:")
+                                    for k, v in chunk.items():
+                                        print(f"{k}: {v}")
+                                else:
+                                    print(f"\nChunk {i}: {chunk}")
+
+                # If nothing was found, print debug info
+                if not any([hasattr(context, attr) for attr in ['context_records', 'reports', 'text_units', 'relationships', 'context_chunks']]):
+                    # print("\nDebug Info:")
+                    # print(f"Context type: {type(context)}")
+                    # print(f"Available attributes: {dir(context)}")
+                    print(f"Raw context: {context}")
+
+                print("\n=== END RAW CHUNKS ===\n")
+
+        except Exception as e:
+            print(f"\nError displaying chunks: {str(e)}")
+            print(f"Context type: {type(context)}")
+            print(f"Context attributes: {dir(context)}")
+
+
 def run_global_search(
     config_filepath: Path | None,
     data_dir: Path | None,
@@ -31,16 +192,24 @@ def run_global_search(
     response_type: str,
     streaming: bool,
     query: str,
+    raw_chunks: bool = False
 ):
     """Perform a global search with a given query.
 
     Loads index files required for global search and calls the Query API.
""" + #print(f"\nDEBUG: run_global_search called with raw_chunks={raw_chunks}") + root = root_dir.resolve() cli_overrides = {} if data_dir: cli_overrides["output.base_dir"] = str(data_dir) config = load_config(root, config_filepath, cli_overrides) + + # Initialize callbacks list + callbacks = [] + if raw_chunks: + callbacks.append(RawChunksCallback()) dataframe_dict = _resolve_output_files( config=config, @@ -75,6 +244,7 @@ def run_global_search( response_type=response_type, streaming=streaming, query=query, + callbacks=callbacks ) ) logger.success(f"Global Search Response:\n{response}") @@ -88,7 +258,6 @@ def run_global_search( final_community_reports: pd.DataFrame = dataframe_dict["community_reports"] if streaming: - async def run_streaming_search(): full_response = "" context_data = {} @@ -97,8 +266,8 @@ def on_context(context: Any) -> None: nonlocal context_data context_data = context - callbacks = NoopQueryCallbacks() - callbacks.on_context = on_context + global_callbacks = callbacks + [NoopQueryCallbacks()] # Combine with existing callbacks + global_callbacks[-1].on_context = on_context async for stream_chunk in api.global_search_streaming( config=config, @@ -109,7 +278,7 @@ def on_context(context: Any) -> None: dynamic_community_selection=dynamic_community_selection, response_type=response_type, query=query, - callbacks=[callbacks], + callbacks=global_callbacks, # Use combined callbacks ): full_response += stream_chunk print(stream_chunk, end="") # noqa: T201 @@ -129,6 +298,7 @@ def on_context(context: Any) -> None: dynamic_community_selection=dynamic_community_selection, response_type=response_type, query=query, + callbacks=callbacks ) ) logger.success(f"Global Search Response:\n{response}") @@ -137,6 +307,8 @@ def on_context(context: Any) -> None: return response, context_data + + def run_local_search( config_filepath: Path | None, data_dir: Path | None, @@ -145,17 +317,26 @@ def run_local_search( response_type: str, streaming: bool, query: str, + raw_chunks: bool = False, ): """Perform a local search with a given query. Loads index files required for local search and calls the Query API. 
""" + # Add debug print at start of function + print(f"\nDEBUG: run_local_search called with raw_chunks={raw_chunks}") + root = root_dir.resolve() cli_overrides = {} if data_dir: cli_overrides["output.base_dir"] = str(data_dir) config = load_config(root, config_filepath, cli_overrides) - + + # Initialize callbacks list + callbacks = [] + if raw_chunks: + callbacks.append(RawChunksCallback()) + dataframe_dict = _resolve_output_files( config=config, output_list=[ @@ -169,6 +350,7 @@ def run_local_search( "covariates", ], ) + # Call the Multi-Index Local Search API if dataframe_dict["multi-index"]: final_entities_list = dataframe_dict["entities"] @@ -202,6 +384,7 @@ def run_local_search( response_type=response_type, streaming=streaming, query=query, + callbacks=callbacks, ) ) logger.success(f"Local Search Response:\n{response}") @@ -218,7 +401,6 @@ def run_local_search( final_covariates: pd.DataFrame | None = dataframe_dict["covariates"] if streaming: - async def run_streaming_search(): full_response = "" context_data = {} @@ -227,8 +409,8 @@ def on_context(context: Any) -> None: nonlocal context_data context_data = context - callbacks = NoopQueryCallbacks() - callbacks.on_context = on_context + local_callbacks = callbacks + [NoopQueryCallbacks()] + local_callbacks[-1].on_context = on_context async for stream_chunk in api.local_search_streaming( config=config, @@ -241,30 +423,31 @@ def on_context(context: Any) -> None: community_level=community_level, response_type=response_type, query=query, - callbacks=[callbacks], + callbacks=local_callbacks, ): full_response += stream_chunk - print(stream_chunk, end="") # noqa: T201 - sys.stdout.flush() # flush output buffer to display text immediately - print() # noqa: T201 + print(stream_chunk, end="") + sys.stdout.flush() + print() return full_response, context_data return asyncio.run(run_streaming_search()) - # not streaming - response, context_data = asyncio.run( - api.local_search( - config=config, - entities=final_entities, - communities=final_communities, - community_reports=final_community_reports, - text_units=final_text_units, - relationships=final_relationships, - covariates=final_covariates, - community_level=community_level, - response_type=response_type, - query=query, + else: + response, context_data = asyncio.run( + api.local_search( + config=config, + entities=final_entities, + communities=final_communities, + community_reports=final_community_reports, + text_units=final_text_units, + relationships=final_relationships, + covariates=final_covariates, + community_level=community_level, + response_type=response_type, + query=query, + callbacks=callbacks, + ) ) - ) logger.success(f"Local Search Response:\n{response}") # NOTE: we return the response and context data here purely as a complete demonstration of the API. # External users should use the API directly to get the response and context data. @@ -279,16 +462,24 @@ def run_drift_search( response_type: str, streaming: bool, query: str, + raw_chunks: bool = False # Added raw_chunks parameter ): """Perform a local search with a given query. Loads index files required for local search and calls the Query API. 
""" + print(f"\nDEBUG: run_drift_search called with raw_chunks={raw_chunks}") + root = root_dir.resolve() cli_overrides = {} if data_dir: cli_overrides["output.base_dir"] = str(data_dir) config = load_config(root, config_filepath, cli_overrides) + + # Initialize callbacks list + callbacks = [] + if raw_chunks: + callbacks.append(RawChunksCallback()) dataframe_dict = _resolve_output_files( config=config, @@ -327,6 +518,7 @@ def run_drift_search( response_type=response_type, streaming=streaming, query=query, + callbacks=callbacks # Added callbacks parameter ) ) logger.success(f"DRIFT Search Response:\n{response}") @@ -351,8 +543,9 @@ def on_context(context: Any) -> None: nonlocal context_data context_data = context - callbacks = NoopQueryCallbacks() - callbacks.on_context = on_context + drift_callbacks = callbacks + [NoopQueryCallbacks()] # Combine with existing callbacks + drift_callbacks[-1].on_context = on_context + async for stream_chunk in api.drift_search_streaming( config=config, @@ -364,7 +557,7 @@ def on_context(context: Any) -> None: community_level=community_level, response_type=response_type, query=query, - callbacks=[callbacks], + callbacks=drift_callbacks, # Use combined callbacks ): full_response += stream_chunk print(stream_chunk, end="") # noqa: T201 @@ -386,6 +579,7 @@ def on_context(context: Any) -> None: community_level=community_level, response_type=response_type, query=query, + callbacks=callbacks # Added callbacks parameter ) ) logger.success(f"DRIFT Search Response:\n{response}") diff --git a/graphrag/language_model/providers/fnllm/utils.py b/graphrag/language_model/providers/fnllm/utils.py index f50b0250e2..c82fc7edf9 100644 --- a/graphrag/language_model/providers/fnllm/utils.py +++ b/graphrag/language_model/providers/fnllm/utils.py @@ -126,7 +126,6 @@ def run_coroutine_sync(coroutine: Coroutine[Any, Any, T]) -> T: future = asyncio.run_coroutine_threadsafe(coroutine, _loop) return future.result() - def is_reasoning_model(model: str) -> bool: """Return whether the model uses a known OpenAI reasoning model.""" return model.lower() in {"o1", "o1-mini", "o3-mini"} diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py index 907c83cacf..278577063e 100644 --- a/graphrag/query/factory.py +++ b/graphrag/query/factory.py @@ -39,15 +39,16 @@ def get_local_search_engine( config: GraphRagConfig, - reports: list[CommunityReport], - text_units: list[TextUnit], - entities: list[Entity], - relationships: list[Relationship], + reports: dict[str, list[CommunityReport]], + text_units: dict[str, list[TextUnit]], + entities: dict[str, list[Entity]], + relationships: dict[str, list[Relationship]], covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, system_prompt: str | None = None, callbacks: list[QueryCallbacks] | None = None, + raw_chunks: bool = True, ) -> LocalSearch: """Create a local search engine based on data + configuration.""" model_settings = config.get_language_model_config(config.local_search.chat_model_id) @@ -110,13 +111,14 @@ def get_local_search_engine( "include_community_rank": False, "return_candidate_context": False, "embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids - "max_context_tokens": ls_config.max_context_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + "max_tokens": ls_config.max_tokens # change this based on the 
+
         },
         response_type=response_type,
         callbacks=callbacks,
+        raw_chunks=raw_chunks  # Only pass raw_chunks here
     )
 
-
 def get_global_search_engine(
     config: GraphRagConfig,
     reports: list[CommunityReport],
@@ -128,6 +130,7 @@ def get_global_search_engine(
     reduce_system_prompt: str | None = None,
     general_knowledge_inclusion_prompt: str | None = None,
     callbacks: list[QueryCallbacks] | None = None,
+    raw_chunks: bool = True,  # Added raw_chunks parameter
 ) -> GlobalSearch:
     """Create a global search engine based on data + configuration."""
     model_settings = config.get_language_model_config(
@@ -200,6 +203,7 @@ def get_global_search_engine(
         concurrent_coroutines=model_settings.concurrent_requests,
         response_type=response_type,
         callbacks=callbacks,
+        raw_chunks=raw_chunks  # Added raw_chunks parameter
     )
@@ -214,6 +218,7 @@ def get_drift_search_engine(
     local_system_prompt: str | None = None,
     reduce_system_prompt: str | None = None,
     callbacks: list[QueryCallbacks] | None = None,
+    raw_chunks: bool = True,  # Added raw_chunks parameter
 ) -> DRIFTSearch:
     """Create a local search engine based on data + configuration."""
     chat_model_settings = config.get_language_model_config(
@@ -267,6 +272,7 @@ def get_drift_search_engine(
         ),
         token_encoder=token_encoder,
         callbacks=callbacks,
+        raw_chunks=raw_chunks  # Added raw_chunks parameter
     )
diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py
index 10d6234327..cb27e7c6f4 100644
--- a/graphrag/query/structured_search/drift_search/search.py
+++ b/graphrag/query/structured_search/drift_search/search.py
@@ -41,6 +41,7 @@ def __init__(
         token_encoder: tiktoken.Encoding | None = None,
         query_state: QueryState | None = None,
         callbacks: list[QueryCallbacks] | None = None,
+        raw_chunks: bool = True,  # Added raw_chunks parameter
     ):
         """
         Initialize the DRIFTSearch class.
@@ -63,6 +64,7 @@ def __init__(
             token_encoder=token_encoder,
         )
         self.callbacks = callbacks or []
+        self.raw_chunks = raw_chunks  # Store raw_chunks parameter
         self.local_search = self.init_local_search()
 
     def init_local_search(self) -> LocalSearch:
@@ -105,6 +107,7 @@ def init_local_search(self) -> LocalSearch:
             context_builder_params=local_context_params,
             response_type="multiple paragraphs",
             callbacks=self.callbacks,
+            raw_chunks=self.raw_chunks,  # Pass raw_chunks to LocalSearch
         )
 
     def _process_primer_results(
@@ -209,8 +212,17 @@ async def search(
         # Check if query state is empty
         if not self.query_state.graph:
-            # Prime the search with the primer
+            if self.raw_chunks:
+                print("\n=== STEP 1: PRIMER SEARCH ===")
+                print("Query:", query)
+
             primer_context, token_ct = await self.context_builder.build_context(query)
+
+            if self.raw_chunks:
+                print("\nPrimer Context:")
+                print(primer_context)
+
+
             llm_calls["build_context"] = token_ct["llm_calls"]
             prompt_tokens["build_context"] = token_ct["prompt_tokens"]
             output_tokens["build_context"] = token_ct["output_tokens"]
@@ -218,6 +230,13 @@ async def search(
             primer_response = await self.primer.search(
                 query=query, top_k_reports=primer_context
             )
+
+            if self.raw_chunks:
+                print("\nPrimer Response:")
+                print(primer_response.response)
+                print("=== END PRIMER SEARCH ===\n")
+
+
             llm_calls["primer"] = primer_response.llm_calls
             prompt_tokens["primer"] = primer_response.prompt_tokens
             output_tokens["primer"] = primer_response.output_tokens
@@ -231,6 +250,10 @@ async def search(
         epochs = 0
         llm_call_offset = 0
         while epochs < self.context_builder.config.n_depth:
+
+            if self.raw_chunks:
+                print(f"\n=== STEP 2: ACTION SEARCH (Epoch {epochs + 1}) ===")
+
             actions = self.query_state.rank_incomplete_actions()
             if len(actions) == 0:
                 log.info("No more actions to take. Exiting DRIFT loop.")
                 break
             llm_call_offset += (
                 len(actions) - self.context_builder.config.drift_k_followups
             )
+
+            if self.raw_chunks:
+                print(f"\nProcessing {len(actions)} actions:")
+                for i, action in enumerate(actions, 1):
+                    print(f"\nAction {i}:")
+                    print(f"Query: {action.query}")
+                    print(f"Follow-ups: {action.follow_ups}")
+
             # Process actions
             results = await self._search_step(
                 global_query=query, search_engine=self.local_search, actions=actions
             )
+
+            if self.raw_chunks:
+                print("\nAction Results:")
+                for i, result in enumerate(results, 1):
+                    print(f"\nResult {i}:")
+                    print(result.response if hasattr(result, 'response') else result)
+                print(f"=== END ACTION SEARCH (Epoch {epochs + 1}) ===\n")
 
             # Update query state
             for action in results:
@@ -265,6 +303,13 @@ async def search(
         reduced_response = response_state
 
         if reduce:
+
+            if self.raw_chunks:
+                print("\n=== STEP 3: REDUCTION ===")
+                print("Response state to be reduced:")
+                print(response_state)
+
+
             # Reduce response_state to a single comprehensive response
             for callback in self.callbacks:
                 callback.on_reduce_response_start(response_state)
@@ -284,6 +329,11 @@
                 output_tokens=output_tokens,
                 model_params=model_params,
             )
+
+            if self.raw_chunks:
+                print("\nReduced Response:")
+                print(reduced_response)
+                print("=== END REDUCTION ===\n")
 
             for callback in self.callbacks:
                 callback.on_reduce_response_end(reduced_response)
diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py
index b7f75a43ee..cc605c120f 100644
--- a/graphrag/query/structured_search/global_search/search.py
+++ b/graphrag/query/structured_search/global_search/search.py
@@ -67,6 +67,7 @@ def __init__(
         reduce_max_length: int = 2000,
         context_builder_params: dict[str, Any] | None = None,
         concurrent_coroutines: int = 32,
+        raw_chunks: bool = True,  # Added raw_chunks parameter
     ):
         super().__init__(
             model=model,
@@ -83,6 +84,7 @@
         )
         self.callbacks = callbacks or []
         self.max_data_tokens = max_data_tokens
+        self.raw_chunks = raw_chunks  # Store raw_chunks parameter
         self.map_llm_params = map_llm_params if map_llm_params else {}
         self.reduce_llm_params = reduce_llm_params if reduce_llm_params else {}
 
@@ -155,6 +157,20 @@ async def search(
             conversation_history=conversation_history,
             **self.context_builder_params,
         )
+
+        # Print raw chunks if enabled
+        if self.raw_chunks:
+            print("\n=== CONTEXT SENT TO LLM (GLOBAL SEARCH) ===")
+            print("\nInitial Context Chunks:")
+            print(context_result.context_chunks)
+            print("\nCommunity Reports:")
+            if hasattr(context_result, 'community_reports'):
+                for i, report in enumerate(context_result.community_reports, 1):
+                    print(f"\nReport {i}:")
+                    print(report)
+            print("=== END INITIAL CONTEXT ===\n")
+
+
         llm_calls["build_context"] = context_result.llm_calls
         prompt_tokens["build_context"] = context_result.prompt_tokens
         output_tokens["build_context"] = context_result.output_tokens
@@ -171,6 +187,14 @@
             )
             for data in context_result.context_chunks
         ])
+
+        # Print map responses if raw_chunks is enabled
+        if self.raw_chunks:
+            print("\n=== MAP RESPONSES ===")
+            for i, response in enumerate(map_responses, 1):
+                print(f"\nBatch {i} Response:")
+                print(response.response)
+            print("=== END MAP RESPONSES ===\n")
 
         for callback in self.callbacks:
             callback.on_map_response_end(map_responses)
@@ -186,6 +210,15 @@
             query=query,
             **self.reduce_llm_params,
         )
+
+        # Print reduce context if raw_chunks is enabled
+        if self.raw_chunks:
+            print("\n=== REDUCE CONTEXT ===")
+            print("\nReduce Input:")
+            print(reduce_response.context_text)
+            print("=== END REDUCE CONTEXT ===\n")
+
+
         llm_calls["reduce"] = reduce_response.llm_calls
         prompt_tokens["reduce"] = reduce_response.prompt_tokens
         output_tokens["reduce"] = reduce_response.output_tokens
diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py
index 3a02caaf44..bea20eb29b 100644
--- a/graphrag/query/structured_search/local_search/search.py
+++ b/graphrag/query/structured_search/local_search/search.py
@@ -38,6 +38,7 @@ def __init__(
         callbacks: list[QueryCallbacks] | None = None,
         model_params: dict[str, Any] | None = None,
         context_builder_params: dict | None = None,
+        raw_chunks: bool = True,
     ):
         super().__init__(
             model=model,
@@ -49,6 +50,7 @@
         self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
         self.callbacks = callbacks or []
         self.response_type = response_type
+        self.raw_chunks = raw_chunks
 
     async def search(
         self,
@@ -66,6 +68,13 @@
             **kwargs,
             **self.context_builder_params,
         )
+
+        if self.raw_chunks:
+            print("\n=== CONTEXT SENT TO LLM ===")
+            print(f"Context chunks used for LLM prompt:")
+            print(context_result.context_chunks)
+            print("=== END CONTEXT ===\n")
+
         llm_calls["build_context"] = context_result.llm_calls
         prompt_tokens["build_context"] = context_result.prompt_tokens
         output_tokens["build_context"] = context_result.output_tokens
@@ -90,6 +99,10 @@
 
         full_response = ""
 
+        # Call callbacks with context before formatting prompt
+        for callback in self.callbacks:
+            callback.on_context(context_result)
+
         async for response in self.model.achat_stream(
             prompt=query,
             history=history_messages,
@@ -145,9 +158,14 @@ async def stream_search(
             **self.context_builder_params,
         )
         log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
+
+
         search_prompt = self.system_prompt.format(
-            context_data=context_result.context_chunks, response_type=self.response_type
+            context_data=context_result.context_chunks,
+            response_type=self.response_type
         )
+
+
         history_messages = [
             {"role": "system", "content": search_prompt},
         ]
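
Usage sketch (not part of the patch; assumes the diff applies cleanly, that the project root and index live under ./ragtest, and that the query string is a placeholder): with these changes the new flag would be combined with the existing query options roughly as follows.

    graphrag query --root ./ragtest --method local --query "What are the top themes in this dataset?" --raw-chunks

The same --raw-chunks flag is wired through for the global and drift methods in this patch; basic search is left unchanged. When the flag is set, the CLI attaches a RawChunksCallback (defined in graphrag/cli/query.py above) so the retrieved context is printed before the final response.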