diff --git a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py index ac8b62b8..edd26099 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py +++ b/integrations/langchain/src/databricks_langchain/vector_search_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Type +from typing import List, Optional, Type, Dict, Any from databricks_ai_bridge.utils.vector_search import IndexDetails from databricks_ai_bridge.vector_search_retriever_tool import ( @@ -21,26 +21,35 @@ class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): This class integrates with Databricks Vector Search and provides a convenient interface for building a retriever tool for agents. - **Note**: Any additional keyword arguments passed to the constructor will be passed along to - `databricks.vector_search.client.VectorSearchIndex.similarity_search` when executing the tool. `See - documentation `_ - to see the full set of supported keyword arguments, - e.g. `score_threshold`. Also, see documentation for - :class:`~databricks_ai_bridge.vector_search_retriever_tool.VectorSearchRetrieverToolMixin` for additional supported constructor - arguments not listed below, including `query_type` and `num_results`. + **Note**: Any additional keyword arguments passed to the constructor will be forwarded to + `databricks.vector_search.index.VectorSearchIndex.similarity_search` when executing the tool. + See documentation for the full set of supported keyword arguments (e.g., `score_threshold`). + Also see the mixin docs for additional supported constructor arguments (e.g., `query_type`, `num_results`). + + **New**: `client_args` (optional) is forwarded to `VectorSearchClient` via `DatabricksVectorSearch`. + Use this to pass service principal credentials (e.g., `service_principal_client_id`, + `service_principal_client_secret`) or other client options such as `disable_notice`. """ text_column: Optional[str] = Field( None, description="The name of the text column to use for the embeddings. " - "Required for direct-access index or delta-sync index with " - "self-managed embeddings.", + "Required for direct-access index or delta-sync index with self-managed embeddings.", ) embedding: Optional[Embeddings] = Field( None, description="Embedding model for self-managed embeddings." ) - # The BaseTool class requires 'name' and 'description' fields which we will populate in validate_tool_inputs() + # Optional pass-through for VectorSearchClient (SP/M2M auth, flags like disable_notice, etc.) + client_args: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Additional args forwarded to VectorSearchClient via DatabricksVectorSearch " + "(e.g., service_principal_client_id/service_principal_client_secret, disable_notice)." + ), + ) + + # BaseTool requires these; populated in validate_tool_inputs() name: str = Field(default="", description="The name of the tool") description: str = Field(default="", description="The description of the tool") args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput @@ -49,7 +58,7 @@ class VectorSearchRetrieverTool(BaseTool, VectorSearchRetrieverToolMixin): @model_validator(mode="after") def _validate_tool_inputs(self): - kwargs = { + kwargs: Dict[str, Any] = { "index_name": self.index_name, "embedding": self.embedding, "text_column": self.text_column, @@ -59,6 +68,9 @@ def _validate_tool_inputs(self): "workspace_client": self.workspace_client, "include_score": self.include_score, } + if self.client_args: + kwargs["client_args"] = self.client_args # <-- pass-through + dbvs = DatabricksVectorSearch(**kwargs) self._vector_store = dbvs @@ -71,7 +83,6 @@ def _validate_tool_inputs(self): (self.embedding.endpoint if isinstance(self.embedding, DatabricksEmbeddings) else None), IndexDetails(dbvs.index), ) - return self @vector_search_retriever_tool_trace diff --git a/integrations/langchain/tests/unit_tests/test_vs_retriever_tool_client_args.py b/integrations/langchain/tests/unit_tests/test_vs_retriever_tool_client_args.py new file mode 100644 index 00000000..2c8aea78 --- /dev/null +++ b/integrations/langchain/tests/unit_tests/test_vs_retriever_tool_client_args.py @@ -0,0 +1,56 @@ +# integrations/langchain/tests/unit_tests/test_vs_retriever_tool_client_args.py + +from typing import List +from langchain_core.embeddings import Embeddings + +def test_client_args_pass_through(monkeypatch): + captured = {} + + class FakeVSClient: + def __init__(self, **client_kwargs): + captured["kwargs"] = client_kwargs + + def get_index(self, **_): + # Return a fake index object with a minimal describe() the tool expects + class _Idx: + def describe(self): + return { + "name": "catalog.schema.index", + "endpoint_name": "vs_endpoint", + "index_type": "DELTA_SYNC", # validator sees delta-sync + "primary_key": "id", + "status": {"status": "ONLINE"}, + # (No need to declare managed vs self-managed if we supply embedding+text_column) + } + return _Idx() + + # Patch the canonical SDK client path + monkeypatch.setattr( + "databricks.vector_search.client.VectorSearchClient", + FakeVSClient, + raising=True, + ) + + # Minimal embeddings stub to satisfy self-managed requirement + class DummyEmbeddings(Embeddings): + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return [[0.0] * 3 for _ in texts] + def embed_query(self, text: str) -> List[float]: + return [0.0, 0.0, 0.0] + + from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool + + tool = VectorSearchRetrieverTool( + index_name="catalog.schema.index", + text_column="body", # required with self-managed/delta-sync + embedding=DummyEmbeddings(), # satisfy validator + client_args={ + "service_principal_client_id": "abc", + "service_principal_client_secret": "xyz", + "disable_notice": True, + }, + ) + + assert captured["kwargs"]["service_principal_client_id"] == "abc" + assert captured["kwargs"]["service_principal_client_secret"] == "xyz" + assert captured["kwargs"]["disable_notice"] is True