Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ols/src/query_helpers/docs_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,22 @@ def _prepare_prompt(
# Retrieve RAG content
if rag_retriever:
retrieved_nodes = rag_retriever.retrieve(query)
logger.info("Retrieved %d documents from indexes", len(retrieved_nodes))

retrieved_nodes = reranker.rerank(retrieved_nodes)
logger.info("After reranking: %d documents", len(retrieved_nodes))

# Logging top retrieved candidates with scores
for i, node in enumerate(retrieved_nodes[:5]):
logger.info(
"Retrieved doc #%d: title='%s', url='%s', index='%s', score=%.4f",
i + 1,
node.metadata.get("title", "unknown"),
node.metadata.get("docs_url", "unknown"),
node.metadata.get("index_origin", "unknown"),
node.get_score(raise_error=False),
)

rag_chunks, available_tokens = token_handler.truncate_rag_context(
retrieved_nodes, available_tokens
)
Expand Down
51 changes: 48 additions & 3 deletions ols/src/rag_index/index_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,17 @@ class QueryFusionRetrieverCustom(QueryFusionRetriever): # pylint: disable=W0612

def __init__(self, **kwargs):
    """Initialize custom query fusion class.

    Accepts the parent retriever's keyword arguments plus two custom ones
    that are consumed here and never forwarded to the parent:

    - ``retriever_weights``: optional list of per-retriever weights. When
      missing or empty, every retriever gets a weight of 1.0.
    - ``index_configs``: optional list of index configs, looked up by
      retriever position in ``_simple_fusion`` to tag nodes with their
      originating index.
    """
    # Extract custom parameters before passing the rest to the parent.
    retriever_weights = kwargs.pop("retriever_weights", None)
    index_configs = kwargs.pop("index_configs", None)
    # "retrievers" is still forwarded to the parent, so read it without
    # popping; default to an empty list so a missing key cannot raise.
    retrievers = kwargs.get("retrievers", [])

    super().__init__(**kwargs)

    # Do not re-read "retriever_weights" from kwargs at this point: the key
    # was popped above, so a second lookup would always yield None and
    # discard any weights supplied by the caller.
    if not retriever_weights:
        retriever_weights = [1.0] * len(retrievers)
    self._custom_retriever_weights = retriever_weights
    self._index_configs = index_configs

def _simple_fusion(self, results):
"""Override internal method and apply weighted score."""
Expand All @@ -72,16 +77,38 @@ def _simple_fusion(self, results):
# Current dynamic weights marginally penalize the score.
all_nodes = {}
for i, nodes_with_scores in enumerate(results.values()):
# Getting index metadata based on available index configs
index_id = ""
index_origin = ""
if self._index_configs and i < len(self._index_configs):
index_config = self._index_configs[i]
if index_config is not None:
index_id = index_config.product_docs_index_id or ""
index_origin = index_config.product_docs_origin or "default"

for j, node_with_score in enumerate(nodes_with_scores):
# Add index metadata to node
node_with_score.node.metadata["index_id"] = index_id
node_with_score.node.metadata["index_origin"] = index_origin

node_index_id = f"{i}_{j}"
all_nodes[node_index_id] = node_with_score
# weighted_score = node_with_score.score * self._custom_retriever_weights[i]
# Uncomment above and delete below, if we decide weights to be set from config.
weighted_score = node_with_score.score * (
original_score = node_with_score.score
weighted_score = original_score * (
1 - min(i, SCORE_DILUTION_DEPTH - 1) * SCORE_DILUTION_WEIGHT
)
all_nodes[node_index_id].score = weighted_score

logger.debug(
"Document from index #%d (%s): original_score=%.4f, weighted_score=%.4f",
i,
index_origin or index_id or "unknown",
original_score,
weighted_score,
)

return sorted(
all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True
)
Expand All @@ -95,6 +122,7 @@ def __init__(self, index_config: Optional[ReferenceContent]) -> None:
load_llama_index_deps()
self._indexes = None
self._retriever = None
self._loaded_index_configs = None

self._index_config = index_config
logger.debug("Config used for index load: %s", str(self._index_config))
Expand Down Expand Up @@ -132,6 +160,7 @@ def _load_index(self) -> None:
Settings.llm = resolve_llm(None)

indexes = []
loaded_configs = []
for i, index_config in enumerate(self._index_config.indexes):
if index_config.product_docs_index_path is None:
logger.warning("Index path is not set for index #%d, skip loading.", i)
Expand Down Expand Up @@ -159,13 +188,15 @@ def _load_index(self) -> None:
index_id=index_config.product_docs_index_id,
)
indexes.append(index)
loaded_configs.append(index_config)
logger.info("Vector index #%d is loaded.", i)
except Exception as err:
logger.exception(
"Error loading vector index #%d:\n%s, skipped.", i, err
)
if len(indexes) == 0:
logger.warning("No indexes are loaded.")
self._loaded_index_configs = loaded_configs
return
if len(indexes) < len(self._index_config.indexes):
logger.warning(
Expand All @@ -175,6 +206,7 @@ def _load_index(self) -> None:
else:
logger.info("All indexes are loaded.")
self._indexes = indexes
self._loaded_index_configs = loaded_configs

@property
def vector_indexes(self) -> Optional[list[BaseIndex]]:
Expand All @@ -199,6 +231,18 @@ def get_retriever(
):
return self._retriever

# Log index information
index_info = [
f"{i}: {cfg.product_docs_origin or cfg.product_docs_index_id or 'unknown'}"
for i, cfg in enumerate(self._loaded_index_configs or [])
]
logger.info(
"Creating retriever for %d indexes (similarity_top_k=%d): %s",
len(self._indexes),
similarity_top_k,
index_info,
)

# Note: we are using a custom retriever, based on our need
retriever = QueryFusionRetrieverCustom(
retrievers=[
Expand All @@ -207,6 +251,7 @@ def get_retriever(
],
similarity_top_k=similarity_top_k,
retriever_weights=None, # Setting as None, until this gets added to config
index_configs=self._loaded_index_configs,
mode="simple", # Don't modify this as we are adding our own logic
num_queries=1, # set this to 1 to disable query generation
use_async=False,
Expand Down
55 changes: 47 additions & 8 deletions ols/utils/token_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,24 @@ def truncate_rag_context(
list of `RagChunk` objects, available tokens after context usage
"""
rag_chunks = []
logger.info(
"Processing %d retrieved nodes for RAG context", len(retrieved_nodes)
)

for node in retrieved_nodes:
for idx, node in enumerate(retrieved_nodes):
score = float(node.get_score(raise_error=False))
doc_title = node.metadata.get("title", "unknown")
doc_url = node.metadata.get("docs_url", "unknown")
index_id = node.metadata.get("index_id", "")
index_origin = node.metadata.get("index_origin", "")

if score < RAG_SIMILARITY_CUTOFF:
logger.debug(
"RAG content similarity score: %f is less than threshold %f.",
logger.info(
"Document #%d rejected: '%s' (index: %s) - "
"similarity score %.4f < threshold %.4f",
idx + 1,
doc_title,
index_origin or index_id or "unknown",
score,
RAG_SIMILARITY_CUTOFF,
)
Expand All @@ -137,26 +149,53 @@ def truncate_rag_context(
tokens = self.text_to_tokens(node_text)
tokens_count = TokenHandler._get_token_count(tokens)
tokens_count += 1 # for new-line char
logger.debug("RAG content tokens count: %d.", tokens_count)
logger.debug("RAG content tokens count: %d", tokens_count)

available_tokens = min(tokens_count, max_tokens)
logger.debug("Available tokens: %d.", tokens_count)
logger.debug(
"Tokens used for this chunk: %d, remaining budget: %d",
available_tokens,
max_tokens - available_tokens,
)

if available_tokens < MINIMUM_CONTEXT_TOKEN_LIMIT:
logger.debug("%d tokens are less than threshold.", available_tokens)
logger.info(
"Document #%d rejected: '%s' (index: %s) - "
"insufficient tokens (%d < %d minimum)",
idx + 1,
doc_title,
index_origin or index_id or "unknown",
available_tokens,
MINIMUM_CONTEXT_TOKEN_LIMIT,
)
break

logger.info(
"Document #%d selected: title='%s', url='%s', index='%s', "
"score=%.4f, tokens=%d, remaining_context=%d",
idx + 1,
doc_title,
doc_url,
index_origin or index_id or "unknown",
score,
available_tokens,
max_tokens - available_tokens,
)

node_text = self.tokens_to_text(tokens[:available_tokens])
rag_chunks.append(
RagChunk(
text=node_text,
doc_url=node.metadata.get("docs_url", ""),
doc_title=node.metadata.get("title", ""),
doc_url=doc_url,
doc_title=doc_title,
)
)

max_tokens -= available_tokens

logger.info(
"Final selection: %d documents chosen for RAG context", len(rag_chunks)
)
return rag_chunks, max_tokens

def limit_conversation_history(
Expand Down