35 changes: 35 additions & 0 deletions CONTRIBUTING.md
@@ -16,3 +16,38 @@ If you are working with integration packages, install them as well
```sh
pip install -e "integrations/langchain[dev]"
```

## Running tests

To run tests for the bridge library, use the following command:

```sh
pytest tests
```

To run tests for the langchain integration library, use the following command:

```sh
cd integrations/langchain
pytest tests
```

## Linting and formatting

To format code, run:

```sh
ruff format
```

To check linting rules (required for CI checks), run:

```sh
ruff check .
```

To fix linting issues such as import order, run:

```sh
ruff check . --fix
```
22 changes: 22 additions & 0 deletions integrations/langchain/README.md
@@ -13,6 +13,15 @@ pip install databricks-langchain

### Install from source


With HTTPS:

```sh
pip install git+https://github.com/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain
```

With SSH:

```sh
pip install git+ssh://[email protected]/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain
```
@@ -36,3 +45,16 @@ from databricks_langchain.genie import GenieAgent

genie_agent = GenieAgent("space-id", "Genie", description="This Genie space has access to sales data in Europe")
```
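
Once created, the agent can be invoked like any other LangGraph-compatible runnable. Below is a minimal sketch, assuming the agent accepts the same `messages`-style input used by the underlying bridge; the question is a placeholder:

```python
from langchain_core.messages import HumanMessage

# Placeholder question; replace with one relevant to your Genie space.
result = genie_agent.invoke(
    {"messages": [HumanMessage(content="What were total sales in Europe last quarter?")]}
)
print(result["messages"][-1].content)
```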

### (Preview) Use a Genie space as a tool

> [!NOTE]
> Requires Genie API Private Preview. Reach out to your account team for enablement.

Once the Genie tool is created, you can bind it to an [AgentExecutor](https://python.langchain.com/docs/how_to/agent_executor/#tools) or a LangGraph ReAct agent.

```python
from databricks_langchain.genie import GenieTool

genie_tool = GenieTool("space-id", "Genie", "This Genie space has access to sales data in Europe")
```
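
The tool takes a `question` and a `summarized_chat_history` (see `GenieToolInput`). Below is a minimal sketch of binding it to a LangGraph ReAct agent; it assumes `ChatDatabricks` is importable from this package and `langgraph` is installed, and the endpoint name is a placeholder:

```python
from databricks_langchain import ChatDatabricks
from langgraph.prebuilt import create_react_agent

# Placeholder endpoint; use a model serving endpoint available in your workspace.
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")

agent = create_react_agent(llm, tools=[genie_tool])
agent.invoke({"messages": [("user", "What were total sales in Europe last quarter?")]})
```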
70 changes: 69 additions & 1 deletion integrations/langchain/src/databricks_langchain/genie.py
@@ -1,4 +1,8 @@
from databricks_ai_bridge.genie import Genie
import uuid
from typing import Optional, Tuple, Type

from databricks_ai_bridge.genie import Genie, GenieResult
from pydantic import BaseModel, Field


def _concat_messages_array(messages):
@@ -32,6 +36,70 @@ def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
    return {"messages": [AIMessage(content="")]}


class GenieToolInput(BaseModel):
    question: str = Field(description="question to ask the agent")
    summarized_chat_history: str = Field(
        description="summarized chat history to provide the agent context of what may have been talked about. "
        "Say 'No history' if there is no history to provide."
    )


def GenieTool(genie_space_id: str, genie_agent_name: str, genie_space_description: str):
    """Create a LangChain tool that answers questions using the given Genie space."""
    from langchain_core.callbacks.manager import CallbackManagerForToolRun
    from langchain_core.tools import BaseTool

    genie = Genie(genie_space_id)

    class GenieQuestionToolWithTrace(BaseTool):
        name: str = f"{genie_agent_name}_details"
        description: str = genie_space_description
        args_schema: Type[BaseModel] = GenieToolInput
        response_format: str = "content_and_artifact"

        def _run(
            self,
            question: str,
            summarized_chat_history: str,
            run_manager: Optional[CallbackManagerForToolRun] = None,
        ) -> Tuple[str, Optional[GenieResult]]:
            message = (
                f"I will provide you a chat history, where your name is {genie_agent_name}. "
                f"Please answer the following question: {question} with the following chat history "
                f"for context: {summarized_chat_history}.\n"
            )
            response = genie.ask_question_with_details(message)
            if response:
                return response.response, response
            return "", None

    tool_with_details = GenieQuestionToolWithTrace()

    class GenieQuestionToolCall(BaseTool):
        name: str = genie_agent_name
        description: str = genie_space_description
        args_schema: Type[BaseModel] = GenieToolInput

        def _run(
            self,
            question: str,
            summarized_chat_history: str,
            run_manager: Optional[CallbackManagerForToolRun] = None,
        ) -> str:
            # Delegate to the traced tool and return only its text content.
            tool_result = tool_with_details.invoke(
                {
                    "args": {
                        "question": question,
                        "summarized_chat_history": summarized_chat_history,
                    },
                    "id": str(uuid.uuid4()),
                    "type": "tool_call",
                }
            )
            return tool_result.content

    return GenieQuestionToolCall()


def GenieAgent(genie_space_id, genie_agent_name="Genie", description=""):
    """Create a genie agent that can be used to query the API"""
    from functools import partial
41 changes: 41 additions & 0 deletions integrations/langchain/tests/test_genie.py
@@ -1,9 +1,12 @@
from unittest.mock import patch

from databricks_ai_bridge.genie import GenieResult
from langchain_core.messages import AIMessage

from databricks_langchain.genie import (
    GenieAgent,
    GenieTool,
    GenieToolInput,
    _concat_messages_array,
    _query_genie_as_agent,
)
@@ -69,3 +72,41 @@ def test_create_genie_agent(MockRunnableLambda):

    # Check that the partial function is created with the correct arguments
    MockRunnableLambda.assert_called()


@patch("databricks_langchain.genie.Genie")
def test_create_genie_tool(MockGenie):
mock_genie = MockGenie.return_value
mock_genie.ask_question_with_details.return_value = GenieResult(
description=None, sql_query=None, response="It is sunny."
)

agent = GenieTool("space-id", "Genie", "Description")

assert agent.name == "Genie"
assert agent.args_schema == GenieToolInput
assert agent.description == "Description"
assert (
agent.invoke({"question": "What is the weather?", "summarized_chat_history": "No history"})
== "It is sunny."
)

assert mock_genie.ask_question_with_details.call_count == 1


@patch("databricks_langchain.genie.Genie")
def test_create_genie_tool_no_response(MockGenie):
mock_genie = MockGenie.return_value
mock_genie.ask_question_with_details.return_value = None

agent = GenieTool("space-id", "Genie", "Description")

assert agent.name == "Genie"
assert agent.args_schema == GenieToolInput
assert agent.description == "Description"
assert (
agent.invoke({"question": "What is the weather?", "summarized_chat_history": "No history"})
== ""
)

assert mock_genie.ask_question_with_details.call_count == 1
34 changes: 27 additions & 7 deletions src/databricks_ai_bridge/genie.py
@@ -1,7 +1,8 @@
import logging
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Union
from typing import Optional, Union

import pandas as pd
import tiktoken
@@ -11,6 +12,13 @@
MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS


@dataclass(frozen=True, repr=True)
class GenieResult:
    description: Optional[str]
    sql_query: Optional[str]
    response: Optional[str]  # can be a query result or a text result


# Define a function to count tokens
def _count_tokens(text):
    encoding = tiktoken.encoding_for_model("gpt-4o")
@@ -92,7 +100,7 @@ def create_message(self, conversation_id, content):
        return resp

    def poll_for_result(self, conversation_id, message_id):
        def poll_result():
        def poll_result() -> Optional[GenieResult]:
            iteration_count = 0
            while iteration_count < MAX_ITERATIONS:
                iteration_count += 1
@@ -107,7 +115,9 @@ def poll_result():
                    sql = query.get("query", "")
                    logging.debug(f"Description: {description}")
                    logging.debug(f"SQL: {sql}")
                    return poll_query_results()
                    return GenieResult(
                        sql_query=sql, description=description, response=poll_query_results()
                    )
                elif resp["status"] == "COMPLETED":
                    # Check if there is a query object in the attachments for the COMPLETED status
                    query_attachment = next((r for r in resp["attachments"] if "query" in r), None)
@@ -117,12 +127,15 @@
                        sql = query.get("query", "")
                        logging.debug(f"Description: {description}")
                        logging.debug(f"SQL: {sql}")
                        return poll_query_results()
                        return GenieResult(
                            sql_query=sql, description=description, response=poll_query_results()
                        )
                    else:
                        # Handle the text object in the COMPLETED status
                        return next(r for r in resp["attachments"] if "text" in r)["text"][
                        text_content = next(r for r in resp["attachments"] if "text" in r)["text"][
                            "content"
                        ]
                        return GenieResult(sql_query=None, description=None, response=text_content)
                elif resp["status"] == "FAILED":
                    logging.debug("Genie failed to execute the query")
                    return None
@@ -157,7 +170,14 @@ def poll_query_results():

        return poll_result()

    def ask_question(self, question):
    def ask_question(self, question) -> Optional[str]:
        resp = self.start_conversation(question)
        genie_result = self.poll_for_result(resp["conversation_id"], resp["message_id"])
        # ask question will just return the string response
        if genie_result:
            return genie_result.response
        return None

    def ask_question_with_details(self, question: str) -> Optional[GenieResult]:
        resp = self.start_conversation(question)
        # TODO (prithvi): return the query and the result
        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
20 changes: 15 additions & 5 deletions tests/databricks_ai_bridge/test_genie.py
@@ -4,7 +4,7 @@
import pandas as pd
import pytest

from databricks_ai_bridge.genie import Genie, _count_tokens, _parse_query_result
from databricks_ai_bridge.genie import Genie, GenieResult, _count_tokens, _parse_query_result


@pytest.fixture
@@ -48,7 +48,7 @@ def test_poll_for_result_completed_with_text(genie, mock_workspace_client):
        {"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]},
    ]
    result = genie.poll_for_result("123", "456")
    assert result == "Result"
    assert result.response == "Result"


def test_poll_for_result_completed_with_query(genie, mock_workspace_client):
@@ -65,7 +65,7 @@ def test_poll_for_result_completed_with_query(genie, mock_workspace_client):
        },
    ]
    result = genie.poll_for_result("123", "456")
    assert result == pd.DataFrame().to_markdown()
    assert result.response == pd.DataFrame().to_markdown()


def test_poll_for_result_executing_query(genie, mock_workspace_client):
@@ -82,7 +82,7 @@ def test_poll_for_result_executing_query(genie, mock_workspace_client):
        },
    ]
    result = genie.poll_for_result("123", "456")
    assert result == pd.DataFrame().to_markdown()
    assert result.response == pd.DataFrame().to_markdown()


def test_poll_for_result_failed(genie, mock_workspace_client):
@@ -134,7 +134,7 @@ def test_poll_for_result_max_iterations(genie, mock_workspace_client):
        },
    ]
    result = genie.poll_for_result("123", "456")
    assert result is None
    assert result.response is None


def test_ask_question(genie, mock_workspace_client):
@@ -146,6 +146,16 @@ def test_ask_question(genie, mock_workspace_client):
    assert result == "Answer"


def test_ask_question_with_details(genie, mock_workspace_client):
    mock_workspace_client.genie._api.do.side_effect = [
        {"conversation_id": "123", "message_id": "456"},
        {"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]},
    ]
    result = genie.ask_question_with_details("What is the meaning of life?")
    assert isinstance(result, GenieResult)
    assert result.response == "Answer"


def test_parse_query_result_empty():
resp = {"manifest": {"schema": {"columns": []}}, "result": None}
result = _parse_query_result(resp)