Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from langchain_core.embeddings import Embeddings
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

from databricks_langchain import DatabricksEmbeddings
from databricks_langchain.vectorstores import DatabricksVectorSearch
Expand Down Expand Up @@ -72,6 +72,12 @@ def _validate_tool_inputs(self):
IndexDetails(dbvs.index),
)

# Create args_schema based on dynamic_filter setting
if self.dynamic_filter:
self.args_schema = self._create_enhanced_input_model()
else:
self.args_schema = self._create_basic_input_model()

return self

@vector_search_retriever_tool_trace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,129 @@ def test_kwargs_override_both_num_results_and_query_type() -> None:
query_type="HYBRID", # Should use overridden value
filter={},
)


def test_enhanced_filter_description_with_column_metadata() -> None:
"""Test that the tool args_schema includes enhanced filter descriptions with column metadata."""
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True)

# The LangChain implementation calls index.describe() to get column information
# and includes them in the filter description
args_schema = vector_search_tool.args_schema
filter_field = args_schema.model_fields["filters"]

# Check that the filter description is enhanced with available columns
# Note: The actual columns will depend on the mocked index.describe() response
assert "Available columns for filtering:" in filter_field.description or "Optional filters" in filter_field.description

# Should include comprehensive filter syntax
assert "Inclusion:" in filter_field.description
assert "Exclusion:" in filter_field.description
assert "Comparisons:" in filter_field.description
assert "Pattern match:" in filter_field.description
assert "OR logic:" in filter_field.description

# Should include examples
assert "Examples:" in filter_field.description
assert 'Filter by category:' in filter_field.description
assert 'Filter by price range:' in filter_field.description


def test_enhanced_filter_description_without_column_metadata() -> None:
"""Test that the tool args_schema gracefully handles missing column metadata."""
from unittest.mock import Mock

with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class:
mock_ws_client = Mock()
mock_ws_client.tables.get.side_effect = Exception("Cannot retrieve table info")
mock_ws_client_class.return_value = mock_ws_client

vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True)

# Check that the args_schema still includes filter description
args_schema = vector_search_tool.args_schema
filter_field = args_schema.model_fields["filters"]

# Should not include available columns section
assert "Available columns for filtering:" not in filter_field.description

# Should still include comprehensive filter syntax
assert "Inclusion:" in filter_field.description
assert "Exclusion:" in filter_field.description
assert "Comparisons:" in filter_field.description
assert "Pattern match:" in filter_field.description
assert "OR logic:" in filter_field.description

# Should still include examples
assert "Examples:" in filter_field.description


def test_cannot_use_both_dynamic_filter_and_predefined_filters() -> None:
"""Test that using both dynamic_filter and predefined filters raises an error."""
# Try to initialize tool with both dynamic_filter=True and predefined filters
with pytest.raises(ValueError, match="Cannot use both dynamic_filter=True and predefined filters"):
init_vector_search_tool(
DELTA_SYNC_INDEX,
filters={"status": "active", "category": "electronics"},
dynamic_filter=True
)


def test_predefined_filters_work_without_dynamic_filter() -> None:
"""Test that predefined filters work correctly when dynamic_filter is False."""
# Initialize tool with only predefined filters (dynamic_filter=False by default)
vector_search_tool = init_vector_search_tool(
DELTA_SYNC_INDEX,
filters={"status": "active", "category": "electronics"}
)

# The filters parameter should NOT be exposed since dynamic_filter=False
args_schema = vector_search_tool.args_schema
assert "filters" not in args_schema.model_fields

# Test that predefined filters are used
vector_search_tool._vector_store.similarity_search = MagicMock()

vector_search_tool.invoke({
"query": "what electronics are available"
})

vector_search_tool._vector_store.similarity_search.assert_called_once_with(
query="what electronics are available",
k=vector_search_tool.num_results,
query_type=vector_search_tool.query_type,
filter={"status": "active", "category": "electronics"}, # Only predefined filters
)


def test_filter_item_serialization() -> None:
"""Test that FilterItem objects are properly converted to dictionaries."""
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX)
vector_search_tool._vector_store.similarity_search = MagicMock()

# Test various filter types
filters = [
FilterItem(key="category", value="electronics"),
FilterItem(key="price >=", value=100),
FilterItem(key="status NOT", value="discontinued"),
FilterItem(key="tags", value=["wireless", "bluetooth"]),
]

vector_search_tool.invoke({
"query": "find products",
"filters": filters
})

expected_filters = {
"category": "electronics",
"price >=": 100,
"status NOT": "discontinued",
"tags": ["wireless", "bluetooth"]
}

vector_search_tool._vector_store.similarity_search.assert_called_once_with(
query="find products",
k=vector_search_tool.num_results,
query_type=vector_search_tool.query_type,
filter=expected_filters,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we add error handling per my comment in src/databricks_ai_bridge/vector_search_retriever_tool.py - can we add a test that simulates failure during WorkspaceClient.tables.get() process

Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
VectorSearchRetrieverToolMixin,
vector_search_retriever_tool_trace,
)
from pydantic import Field, PrivateAttr, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

from openai import OpenAI, pydantic_function_tool
from openai.types.chat import ChatCompletionToolParam
Expand Down Expand Up @@ -145,8 +145,14 @@ def _validate_tool_inputs(self):

tool_name = self._get_tool_name()

# Create tool input model based on dynamic_filter setting
if self.dynamic_filter:
tool_input_class = self._create_enhanced_input_model()
else:
tool_input_class = self._create_basic_input_model()

self.tool = pydantic_function_tool(
VectorSearchRetrieverToolInput,
tool_input_class,
name=tool_name,
description=self.tool_description
or self._get_default_tool_description(self._index_details),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,169 @@ def test_kwargs_override_both_num_results_and_query_type() -> None:
query_type="HYBRID", # Should use overridden value
query_vector=None,
)


def test_get_filter_param_description_with_column_metadata() -> None:
"""Test that _get_filter_param_description includes column metadata when available."""
# Mock table info with column metadata
mock_column1 = Mock()
mock_column1.name = "category"
mock_column1.type_name.name = "STRING"

mock_column2 = Mock()
mock_column2.name = "price"
mock_column2.type_name.name = "FLOAT"

mock_column3 = Mock()
mock_column3.name = "__internal_column" # Should be excluded
mock_column3.type_name.name = "STRING"

mock_table_info = Mock()
mock_table_info.columns = [mock_column1, mock_column2, mock_column3]

with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class:
mock_ws_client = Mock()
mock_ws_client.tables.get.return_value = mock_table_info
mock_ws_client_class.return_value = mock_ws_client

vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX)

# Test the _get_filter_param_description method directly
description = vector_search_tool._get_filter_param_description()

# Should include available columns in description
assert "Available columns for filtering: category (STRING), price (FLOAT)" in description

# Should include comprehensive filter syntax
assert "Inclusion:" in description
assert "Exclusion:" in description
assert "Comparisons:" in description
assert "Pattern match:" in description
assert "OR logic:" in description

# Should include examples
assert "Examples:" in description
assert 'Filter by category:' in description
assert 'Filter by price range:' in description


def test_enhanced_filter_description_used_in_tool_schema() -> None:
"""Test that the tool schema includes comprehensive filter descriptions."""
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True)

# Check that the tool schema includes enhanced filter description
tool_schema = vector_search_tool.tool
filter_param = tool_schema["function"]["parameters"]["properties"]["filters"]

# Check that it includes the comprehensive filter syntax
assert "Inclusion:" in filter_param["description"]
assert "Exclusion:" in filter_param["description"]
assert "Comparisons:" in filter_param["description"]
assert "Pattern match:" in filter_param["description"]
assert "OR logic:" in filter_param["description"]

# Check that it includes useful filter information
assert "array of key-value pairs" in filter_param["description"]
assert "column" in filter_param["description"]


def test_enhanced_filter_description_without_column_metadata() -> None:
"""Test that the tool schema gracefully handles missing column metadata."""
with patch("databricks.sdk.WorkspaceClient") as mock_ws_client_class:
mock_ws_client = Mock()
mock_ws_client.tables.get.side_effect = Exception("Cannot retrieve table info")
mock_ws_client_class.return_value = mock_ws_client

vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True)

# Check that the tool schema still includes filter description
tool_schema = vector_search_tool.tool
filter_param = tool_schema["function"]["parameters"]["properties"]["filters"]

# Should not include available columns section
assert "Available columns for filtering:" not in filter_param["description"]

# Should still include comprehensive filter syntax
assert "Inclusion:" in filter_param["description"]
assert "Exclusion:" in filter_param["description"]
assert "Comparisons:" in filter_param["description"]
assert "Pattern match:" in filter_param["description"]
assert "OR logic:" in filter_param["description"]

# Should still include examples
assert "Examples:" in filter_param["description"]


def test_cannot_use_both_dynamic_filter_and_predefined_filters() -> None:
"""Test that using both dynamic_filter and predefined filters raises an error."""
# Try to initialize tool with both dynamic_filter=True and predefined filters
with pytest.raises(ValueError, match="Cannot use both dynamic_filter=True and predefined filters"):
init_vector_search_tool(
DELTA_SYNC_INDEX,
filters={"status": "active", "category": "electronics"},
dynamic_filter=True
)


def test_predefined_filters_work_without_dynamic_filter() -> None:
"""Test that predefined filters work correctly when dynamic_filter is False."""
# Initialize tool with only predefined filters (dynamic_filter=False by default)
vector_search_tool = init_vector_search_tool(
DELTA_SYNC_INDEX,
filters={"status": "active", "category": "electronics"}
)

# The filters parameter should NOT be exposed since dynamic_filter=False
tool_schema = vector_search_tool.tool
assert "filters" not in tool_schema["function"]["parameters"]["properties"]

# Test that predefined filters are used
vector_search_tool._index.similarity_search = MagicMock()

vector_search_tool.execute(
query="what electronics are available"
)

vector_search_tool._index.similarity_search.assert_called_once_with(
columns=vector_search_tool.columns,
query_text="what electronics are available",
filters={"status": "active", "category": "electronics"}, # Only predefined filters
num_results=vector_search_tool.num_results,
query_type=vector_search_tool.query_type,
query_vector=None,
)


def test_filter_item_serialization() -> None:
"""Test that FilterItem objects are properly converted to dictionaries."""
vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX)
vector_search_tool._index.similarity_search = MagicMock()

# Test various filter types
filters = [
FilterItem(key="category", value="electronics"),
FilterItem(key="price >=", value=100),
FilterItem(key="status NOT", value="discontinued"),
FilterItem(key="tags", value=["wireless", "bluetooth"]),
]

vector_search_tool.execute(
"find products",
filters=filters
)

expected_filters = {
"category": "electronics",
"price >=": 100,
"status NOT": "discontinued",
"tags": ["wireless", "bluetooth"]
}

vector_search_tool._index.similarity_search.assert_called_once_with(
columns=vector_search_tool.columns,
query_text="find products",
filters=expected_filters,
num_results=vector_search_tool.num_results,
query_type=vector_search_tool.query_type,
query_vector=None,
)
Loading
Loading