diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2f83adc6..051ed0e7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,3 +16,38 @@ If you are working with integration packages install them as well ```sh pip install -e "integrations/langchain[dev]" ``` + +## Running tests + +To run tests for the bridge library, use the following command: + +```sh +pytest tests +``` + +To run tests for the langchain integration library, use the following command: + +```sh +cd integrations/langchain +pytest tests +``` + +## Linting and formatting + +For formatting code run: + +```sh +ruff format +``` + +For checking linting rules run (required for CI checks): + +```sh +ruff check . +``` + +For fixing the linting issues such as import order, etc run: + +```sh +ruff check . --fix +``` diff --git a/integrations/langchain/README.md b/integrations/langchain/README.md index c60cef0d..e94de5a5 100644 --- a/integrations/langchain/README.md +++ b/integrations/langchain/README.md @@ -13,6 +13,15 @@ pip install databricks-langchain ### Install from source + +With https: + +```sh +pip install git+https://github.com/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain +``` + +With SSH: + ```sh pip install git+ssh://git@github.com/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain ``` @@ -36,3 +45,16 @@ from databricks_langchain.genie import GenieAgent genie_agent = GenieAgent("space-id", "Genie", description="This Genie space has access to sales data in Europe") ``` + +### (Preview) Use a Genie space as an tool + +> [!NOTE] +> Requires Genie API Private Preview. Reach out to your account team for enablement. + +Once the genie tool is created, you can then bind it to a [AgentExecutor](https://python.langchain.com/docs/how_to/agent_executor/#tools) or Langgraph React Agent. + +```python +from databricks_langchain.genie import GenieTool + +genie_tool = GenieTool("space-id", "Genie", "This Genie space has access to sales data in Europe") +``` diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 153c2df9..a5093e9e 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -1,4 +1,8 @@ -from databricks_ai_bridge.genie import Genie +import uuid +from typing import Optional, Tuple, Type + +from databricks_ai_bridge.genie import Genie, GenieResult +from pydantic import BaseModel, Field def _concat_messages_array(messages): @@ -32,6 +36,70 @@ def _query_genie_as_agent(input, genie_space_id, genie_agent_name): return {"messages": [AIMessage(content="")]} +class GenieToolInput(BaseModel): + question: str = Field(description="question to ask the agent") + summarized_chat_history: str = Field( + description="summarized chat history to provide the agent context of what may have been talked about. " + "Say 'No history' if there is no history to provide." + ) + + +def GenieTool(genie_space_id: str, genie_agent_name: str, genie_space_description: str): + from langchain_core.callbacks.manager import CallbackManagerForToolRun + from langchain_core.tools import BaseTool + + genie = Genie(genie_space_id) + + class GenieQuestionToolWithTrace(BaseTool): + name: str = f"{genie_agent_name}_details" + description: str = genie_space_description + args_schema: Type[BaseModel] = GenieToolInput + response_format: str = "content_and_artifact" + + def _run( + self, + question: str, + summarized_chat_history: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> Tuple[str, Optional[GenieResult]]: + message = ( + f"I will provide you a chat history, where your name is {genie_agent_name}. " + f"Please answer the following question: {question} with the following chat history " + f"for context: {summarized_chat_history}.\n" + ) + response = genie.ask_question_with_details(message) + if response: + return response.response, response + return "", None + + tool_with_details = GenieQuestionToolWithTrace() + + class GenieQuestionToolCall(BaseTool): + name: str = genie_agent_name + description: str = genie_space_description + args_schema: Type[BaseModel] = GenieToolInput + + def _run( + self, + question: str, + summarized_chat_history: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> Tuple[str, GenieResult]: + tool_result = tool_with_details.invoke( + { + "args": { + "question": question, + "summarized_chat_history": summarized_chat_history, + }, + "id": str(uuid.uuid4()), + "type": "tool_call", + } + ) + return tool_result.content + + return GenieQuestionToolCall() + + def GenieAgent(genie_space_id, genie_agent_name="Genie", description=""): """Create a genie agent that can be used to query the API""" from functools import partial diff --git a/integrations/langchain/tests/test_genie.py b/integrations/langchain/tests/test_genie.py index 70c6c287..60204213 100644 --- a/integrations/langchain/tests/test_genie.py +++ b/integrations/langchain/tests/test_genie.py @@ -1,9 +1,12 @@ from unittest.mock import patch +from databricks_ai_bridge.genie import GenieResult from langchain_core.messages import AIMessage from databricks_langchain.genie import ( GenieAgent, + GenieTool, + GenieToolInput, _concat_messages_array, _query_genie_as_agent, ) @@ -69,3 +72,41 @@ def test_create_genie_agent(MockRunnableLambda): # Check that the partial function is created with the correct arguments MockRunnableLambda.assert_called() + + +@patch("databricks_langchain.genie.Genie") +def test_create_genie_tool(MockGenie): + mock_genie = MockGenie.return_value + mock_genie.ask_question_with_details.return_value = GenieResult( + description=None, sql_query=None, response="It is sunny." + ) + + agent = GenieTool("space-id", "Genie", "Description") + + assert agent.name == "Genie" + assert agent.args_schema == GenieToolInput + assert agent.description == "Description" + assert ( + agent.invoke({"question": "What is the weather?", "summarized_chat_history": "No history"}) + == "It is sunny." + ) + + assert mock_genie.ask_question_with_details.call_count == 1 + + +@patch("databricks_langchain.genie.Genie") +def test_create_genie_tool_no_response(MockGenie): + mock_genie = MockGenie.return_value + mock_genie.ask_question_with_details.return_value = None + + agent = GenieTool("space-id", "Genie", "Description") + + assert agent.name == "Genie" + assert agent.args_schema == GenieToolInput + assert agent.description == "Description" + assert ( + agent.invoke({"question": "What is the weather?", "summarized_chat_history": "No history"}) + == "" + ) + + assert mock_genie.ask_question_with_details.call_count == 1 diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index 6cefde87..a25ac8f0 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -1,7 +1,8 @@ import logging import time +from dataclasses import dataclass from datetime import datetime -from typing import Union +from typing import Optional, Union import pandas as pd import tiktoken @@ -11,6 +12,13 @@ MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS +@dataclass(frozen=True, repr=True) +class GenieResult: + description: Optional[str] + sql_query: Optional[str] + response: Optional[str] # can be a query result or a text result + + # Define a function to count tokens def _count_tokens(text): encoding = tiktoken.encoding_for_model("gpt-4o") @@ -92,7 +100,7 @@ def create_message(self, conversation_id, content): return resp def poll_for_result(self, conversation_id, message_id): - def poll_result(): + def poll_result() -> Optional[GenieResult]: iteration_count = 0 while iteration_count < MAX_ITERATIONS: iteration_count += 1 @@ -107,7 +115,9 @@ def poll_result(): sql = query.get("query", "") logging.debug(f"Description: {description}") logging.debug(f"SQL: {sql}") - return poll_query_results() + return GenieResult( + sql_query=sql, description=description, response=poll_query_results() + ) elif resp["status"] == "COMPLETED": # Check if there is a query object in the attachments for the COMPLETED status query_attachment = next((r for r in resp["attachments"] if "query" in r), None) @@ -117,12 +127,15 @@ def poll_result(): sql = query.get("query", "") logging.debug(f"Description: {description}") logging.debug(f"SQL: {sql}") - return poll_query_results() + return GenieResult( + sql_query=sql, description=description, response=poll_query_results() + ) else: # Handle the text object in the COMPLETED status - return next(r for r in resp["attachments"] if "text" in r)["text"][ + text_content = next(r for r in resp["attachments"] if "text" in r)["text"][ "content" ] + return GenieResult(sql_query=None, description=None, response=text_content) elif resp["status"] == "FAILED": logging.debug("Genie failed to execute the query") return None @@ -157,7 +170,14 @@ def poll_query_results(): return poll_result() - def ask_question(self, question): + def ask_question(self, question) -> Optional[str]: + resp = self.start_conversation(question) + genie_result = self.poll_for_result(resp["conversation_id"], resp["message_id"]) + # ask question will just return the string response + if genie_result: + return genie_result.response + return None + + def ask_question_with_details(self, question: str) -> Optional[GenieResult]: resp = self.start_conversation(question) - # TODO (prithvi): return the query and the result return self.poll_for_result(resp["conversation_id"], resp["message_id"]) diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py index f7a2cafa..66765c93 100644 --- a/tests/databricks_ai_bridge/test_genie.py +++ b/tests/databricks_ai_bridge/test_genie.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from databricks_ai_bridge.genie import Genie, _count_tokens, _parse_query_result +from databricks_ai_bridge.genie import Genie, GenieResult, _count_tokens, _parse_query_result @pytest.fixture @@ -48,7 +48,7 @@ def test_poll_for_result_completed_with_text(genie, mock_workspace_client): {"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]}, ] result = genie.poll_for_result("123", "456") - assert result == "Result" + assert result.response == "Result" def test_poll_for_result_completed_with_query(genie, mock_workspace_client): @@ -65,7 +65,7 @@ def test_poll_for_result_completed_with_query(genie, mock_workspace_client): }, ] result = genie.poll_for_result("123", "456") - assert result == pd.DataFrame().to_markdown() + assert result.response == pd.DataFrame().to_markdown() def test_poll_for_result_executing_query(genie, mock_workspace_client): @@ -82,7 +82,7 @@ def test_poll_for_result_executing_query(genie, mock_workspace_client): }, ] result = genie.poll_for_result("123", "456") - assert result == pd.DataFrame().to_markdown() + assert result.response == pd.DataFrame().to_markdown() def test_poll_for_result_failed(genie, mock_workspace_client): @@ -134,7 +134,7 @@ def test_poll_for_result_max_iterations(genie, mock_workspace_client): }, ] result = genie.poll_for_result("123", "456") - assert result is None + assert result.response is None def test_ask_question(genie, mock_workspace_client): @@ -146,6 +146,16 @@ def test_ask_question(genie, mock_workspace_client): assert result == "Answer" +def test_ask_question_with_details(genie, mock_workspace_client): + mock_workspace_client.genie._api.do.side_effect = [ + {"conversation_id": "123", "message_id": "456"}, + {"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]}, + ] + result = genie.ask_question_with_details("What is the meaning of life?") + assert isinstance(result, GenieResult) + assert result.response == "Answer" + + def test_parse_query_result_empty(): resp = {"manifest": {"schema": {"columns": []}}, "result": None} result = _parse_query_result(resp)