Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Type
from typing import List, Optional, Type, Dict, Any

from databricks_ai_bridge.utils.vector_search import IndexDetails
from databricks_ai_bridge.vector_search_retriever_tool import (
Expand All @@ -21,26 +21,35 @@ class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin):
This class integrates with Databricks Vector Search and provides a convenient interface
for building a retriever tool for agents.

**Note**: Any additional keyword arguments passed to the constructor will be passed along to
`databricks.vector_search.client.VectorSearchIndex.similarity_search` when executing the tool. `See
documentation <https://api-docs.databricks.com/python/vector-search/databricks.vector_search.html#databricks.vector_search.index.VectorSearchIndex.similarity_search>`_
to see the full set of supported keyword arguments,
e.g. `score_threshold`. Also, see documentation for
:class:`~databricks_ai_bridge.vector_search_retriever_tool.VectorSearchRetrieverToolMixin` for additional supported constructor
arguments not listed below, including `query_type` and `num_results`.
**Note**: Any additional keyword arguments passed to the constructor will be forwarded to
`databricks.vector_search.index.VectorSearchIndex.similarity_search` when executing the tool.
See documentation for the full set of supported keyword arguments (e.g., `score_threshold`).
Also see the mixin docs for additional supported constructor arguments (e.g., `query_type`, `num_results`).

**New**: `client_args` (optional) is forwarded to `VectorSearchClient` via `DatabricksVectorSearch`.
Use this to pass service principal credentials (e.g., `service_principal_client_id`,
`service_principal_client_secret`) or other client options such as `disable_notice`.
"""

text_column: Optional[str] = Field(
None,
description="The name of the text column to use for the embeddings. "
"Required for direct-access index or delta-sync index with "
"self-managed embeddings.",
"Required for direct-access index or delta-sync index with self-managed embeddings.",
)
embedding: Optional[Embeddings] = Field(
None, description="Embedding model for self-managed embeddings."
)

# The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs()
# Optional pass-through for VectorSearchClient (SP/M2M auth, flags like disable_notice, etc.)
client_args: Optional[Dict[str, Any]] = Field(
default=None,
description=(
"Additional args forwarded to VectorSearchClient via DatabricksVectorSearch "
"(e.g., service_principal_client_id/service_principal_client_secret, disable_notice)."
),
)

# BaseTool requires these; populated in validate_tool_inputs()
name: str = Field(default="", description="The name of the tool")
description: str = Field(default="", description="The description of the tool")
args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput
Expand All @@ -49,7 +58,7 @@ class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin):

@model_validator(mode="after")
def _validate_tool_inputs(self):
kwargs = {
kwargs: Dict[str, Any] = {
"index_name": self.index_name,
"embedding": self.embedding,
"text_column": self.text_column,
Expand All @@ -59,6 +68,9 @@ def _validate_tool_inputs(self):
"workspace_client": self.workspace_client,
"include_score": self.include_score,
}
if self.client_args:
kwargs["client_args"] = self.client_args # <-- pass-through

dbvs = DatabricksVectorSearch(**kwargs)
self._vector_store = dbvs

Expand All @@ -71,7 +83,6 @@ def _validate_tool_inputs(self):
(self.embedding.endpoint if isinstance(self.embedding, DatabricksEmbeddings) else None),
IndexDetails(dbvs.index),
)

return self

@vector_search_retriever_tool_trace
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# integrations/langchain/tests/unit_tests/test_vs_retriever_tool_client_args.py

from typing import List
from langchain_core.embeddings import Embeddings

def test_client_args_pass_through(monkeypatch):
captured = {}

class FakeVSClient:
def __init__(self, **client_kwargs):
captured["kwargs"] = client_kwargs

def get_index(self, **_):
# Return a fake index object with a minimal describe() the tool expects
class _Idx:
def describe(self):
return {
"name": "catalog.schema.index",
"endpoint_name": "vs_endpoint",
"index_type": "DELTA_SYNC", # validator sees delta-sync
"primary_key": "id",
"status": {"status": "ONLINE"},
# (No need to declare managed vs self-managed if we supply embedding+text_column)
}
return _Idx()

# Patch the canonical SDK client path
monkeypatch.setattr(
"databricks.vector_search.client.VectorSearchClient",
FakeVSClient,
raising=True,
)

# Minimal embeddings stub to satisfy self-managed requirement
class DummyEmbeddings(Embeddings):
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [[0.0] * 3 for _ in texts]
def embed_query(self, text: str) -> List[float]:
return [0.0, 0.0, 0.0]

from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool

tool = VectorSearchRetrieverTool(
index_name="catalog.schema.index",
text_column="body", # required with self-managed/delta-sync
embedding=DummyEmbeddings(), # satisfy validator
client_args={
"service_principal_client_id": "abc",
"service_principal_client_secret": "xyz",
"disable_notice": True,
},
)

assert captured["kwargs"]["service_principal_client_id"] == "abc"
assert captured["kwargs"]["service_principal_client_secret"] == "xyz"
assert captured["kwargs"]["disable_notice"] is True