Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 81 additions & 37 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,97 @@
import mlflow
from langchain_core.messages import AIMessage, BaseMessage

from databricks_ai_bridge.genie import Genie

from langchain_core.runnables import RunnableLambda

from typing import Dict, Any


class GenieAgent(RunnableLambda):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: add a docstring

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FMurray Doc string added to the class.

"""
A class that implements an agent to send user questions to Genie Space in Databricks through the Genie API.

The chat history passed to the agent is concatenated into a single prompt and sent to the Genie space through the Genie API.
If `return_metadata` is False, the agent's response is a dictionary containing a single key, `messages`,
which holds one `AIMessage` whose content is the result of the SQL query executed by the Genie space
(or an empty string when Genie returns no result).
If `return_metadata` is True, the agent's response is a dictionary containing two keys: `messages`
and `metadata`. The `messages` key holds the same single `AIMessage` as in the previous case.
The `metadata` key holds the `GenieResponse` returned by the API, which consists of the result of the SQL query,
the SQL query itself, and a brief description of what the query is doing; it is `None` when Genie returns no result.

Attributes:
genie_space_id (str): The ID of the Genie space created in Databricks that will be called by the Genie API.
description (str): Description of the Genie space created in Databricks that will be accessed by the GenieAPI.
genie_agent_name (str): The name of the genie agent that will be displayed in the trace.
return_metadata (bool): Whether to return the GenieResponse generated by the GenieAPI when the agent is called.
genie (Genie): The Genie API class.

Methods:
invoke(state): Returns a dictionary with the key "messages" and, when `return_metadata` is True, the key "metadata",
which contain the results of the query executed by the Genie space and the associated metadata.

Examples:
>>> genie_agent = GenieAgent("01ef92421857143785bb9e765454520f")
>>> genie_agent.invoke({"messages": [{"role": "user", "content": "What is the average total invoice across the different customers?"}]})
{'messages': [AIMessage(content='| | average_total_invoice |\n|---:|------------------------:|\n| 0 | 195.648 |',
additional_kwargs={}, response_metadata={})]}
>>> genie_agent = GenieAgent("01ef92421857143785bb9e765454520f", return_metadata=True)
>>> genie_agent.invoke({"messages": [{"role": "user", "content": "What is the average total invoice across the different customers?"}]})
{'messages': [AIMessage(content='| | avg_total_invoice |\n|---:|--------------------:|\n| 0 | 195.648 |',
additional_kwargs={}, response_metadata={})],
'metadata': GenieResponse(result='| | avg_total_invoice |\n|---:|--------------------:|\n| 0 | 195.648 |',
query='SELECT AVG(`total_invoice`) AS avg_total_invoice FROM `finance`.`external_customers`.`invoices`',
description='This query calculates the average total invoice amount from all customer invoices, providing insight into overall billing trends.')}
"""
def __init__(
    self,
    genie_space_id: str,
    genie_agent_name: str = "Genie",
    description: str = "",
    return_metadata: bool = False,
):
    """Store the agent configuration and create the Genie client.

    Args:
        genie_space_id: ID of the Genie space; also passed to ``Genie``.
        genie_agent_name: Name given to the runnable and used in the prompt
            sent to Genie.
        description: Free-text description of the Genie space.
        return_metadata: Whether responses also carry a ``metadata`` key.
    """
    # Plain configuration values.
    self.genie_space_id, self.genie_agent_name = genie_space_id, genie_agent_name
    self.description, self.return_metadata = description, return_metadata

    # Client bound to the configured space; reused for every question.
    self.genie = Genie(genie_space_id)

    # Hand the bound API call to RunnableLambda as the wrapped callable.
    super().__init__(self._call_genie_api, name=genie_agent_name)

@mlflow.trace()
def _concat_messages_array(messages):
    """Render a chat history as newline-separated ``speaker: content`` lines.

    Args:
        messages: Iterable of messages, or ``None``. A dict item is read
            through its ``role`` (falling back to ``name``) and ``content``
            keys; any other item is read through attributes of the same
            names. Missing speakers become ``"unknown"`` and missing content
            becomes an empty string.

    Returns:
        str: One ``"<speaker>: <content>"`` line per message joined by a
        newline, or an empty string when there are no messages.
    """
    # The caller passes ``input.get("messages")``, which is None when the key
    # is absent; treat that as an empty history instead of raising TypeError.
    if messages is None:
        return ""

    def _render(message):
        # Dicts and message-like objects expose the same fields differently.
        if isinstance(message, dict):
            speaker = message.get("role", message.get("name", "unknown"))
            content = message.get("content", "")
        else:
            speaker = getattr(message, "role", getattr(message, "name", "unknown"))
            content = getattr(message, "content", "")
        return f"{speaker}: {content}"

    return "\n".join(_render(message) for message in messages)
@mlflow.trace()
def _concat_messages_array(self, messages):

data = []

@mlflow.trace()
def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
from langchain_core.messages import AIMessage
for message in messages:
if isinstance(message, dict):
data.append(f"{message.get('role', 'unknown')}: {message.get('content', '')}")
elif isinstance(message, BaseMessage):
data.append(f"{message.type}: {message.content}")
else:
data.append(f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}")

genie = Genie(genie_space_id)
concatenated_message = "\n".join([e for e in data if e])

message = f"I will provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n"
return concatenated_message

# Concatenate messages to form the chat history
message += _concat_messages_array(input.get("messages"))
@mlflow.trace()
def _call_genie_api(self, state: Dict[str, Any]):
message = (f"I will provide you a chat history, where your name is {self.genie_agent_name}. "
f"Please help with the described information in the chat history.\n")

# Send the message and wait for a response
genie_response = genie.ask_question(message)
# Concatenate messages to form the chat history
message += self._concat_messages_array(state.get("messages"))

if query_result := genie_response.result:
return {"messages": [AIMessage(content=query_result)]}
else:
return {"messages": [AIMessage(content="")]}
# Send the message and wait for a response
genie_response = self.genie.ask_question(message)

content = ""
metadata = None

@mlflow.trace(span_type="AGENT")
def GenieAgent(genie_space_id, genie_agent_name: str = "Genie", description: str = ""):
"""Create a genie agent that can be used to query the API"""
from functools import partial
if genie_response.result:
content = genie_response.result
metadata = genie_response

from langchain_core.runnables import RunnableLambda
if self.return_metadata:
return {"messages": [AIMessage(content=content)], "metadata": metadata}

# Create a partial function with the genie_space_id pre-filled
partial_genie_agent = partial(
_query_genie_as_agent,
genie_space_id=genie_space_id,
genie_agent_name=genie_agent_name,
)
return {"messages": [AIMessage(content=content)]}

# Use the partial function in the RunnableLambda
return RunnableLambda(partial_genie_agent)
107 changes: 82 additions & 25 deletions integrations/langchain/tests/unit_tests/test_genie.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
from unittest.mock import patch

from databricks_ai_bridge.genie import GenieResponse
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage

from databricks_langchain.genie import (
GenieAgent,
_concat_messages_array,
_query_genie_as_agent,
)
from databricks_langchain.genie import GenieAgent

import pytest

def test_concat_messages_array():

@pytest.fixture
def agent():
    """Provide a ``GenieAgent`` built with the default ``return_metadata`` setting."""
    default_agent = GenieAgent("id-1", "Genie")
    return default_agent


@pytest.fixture
def agent_with_metadata():
    """Provide a ``GenieAgent`` built with ``return_metadata=True``."""
    metadata_agent = GenieAgent("id-1", "Genie", return_metadata=True)
    return metadata_agent


def test_concat_messages_array_base_messages(agent):
    """LangChain message objects are rendered as ``<type>: <content>`` lines."""
    history = [HumanMessage("What is the weather?"), AIMessage("It is sunny.")]

    rendered = agent._concat_messages_array(history)

    assert rendered == "human: What is the weather?\nai: It is sunny."


def test_concat_messages_array(agent):
# Test a simple case with multiple messages
messages = [
{"role": "user", "content": "What is the weather?"},
{"role": "assistant", "content": "It is sunny."},
]
result = _concat_messages_array(messages)
result = agent._concat_messages_array(messages)
expected = "user: What is the weather?\nassistant: It is sunny."
assert result == expected

# Test case with missing content
messages = [{"role": "user"}, {"role": "assistant", "content": "I don't know."}]
result = _concat_messages_array(messages)
result = agent._concat_messages_array(messages)
expected = "user: \nassistant: I don't know."
assert result == expected

Expand All @@ -36,37 +54,76 @@ def __init__(self, role, content):
Message("user", "Tell me a joke."),
Message("assistant", "Why did the chicken cross the road?"),
]
result = _concat_messages_array(messages)
result = agent._concat_messages_array(messages)
expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?"
assert result == expected


@patch("databricks_langchain.genie.Genie")
def test_query_genie_as_agent(MockGenie):
# Mock the Genie class and its response
mock_genie = MockGenie.return_value
mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.")
@patch("databricks_ai_bridge.genie.Genie.ask_question")
def test_query_genie_as_agent(mock_ask_question, agent):

genie_response = GenieResponse(result="It is sunny.")

mock_ask_question.return_value = genie_response

input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}
result = _query_genie_as_agent(input_data, "space-id", "Genie")

result = agent._call_genie_api(input_data)

expected_message = {"messages": [AIMessage(content="It is sunny.")]}

assert result == expected_message

# Test the case when genie_response is empty
mock_genie.ask_question.return_value = GenieResponse(result=None)
result = _query_genie_as_agent(input_data, "space-id", "Genie")
genie_empty_response = GenieResponse(result=None)

mock_ask_question.return_value = genie_empty_response

result = agent._call_genie_api(input_data)

expected_message = {"messages": [AIMessage(content="")]}

assert result == expected_message


@patch("langchain_core.runnables.RunnableLambda")
def test_create_genie_agent(MockRunnableLambda):
mock_runnable = MockRunnableLambda.return_value
@patch("databricks_ai_bridge.genie.Genie.ask_question")
def test_query_genie_as_agent_with_metadata(mock_ask_question, agent_with_metadata):

genie_response = GenieResponse(result="It is sunny.", query="select a from data_table", description="description")

mock_ask_question.return_value = genie_response

input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}

agent = GenieAgent("space-id", "Genie")
assert agent == mock_runnable
result = agent_with_metadata._call_genie_api(input_data)

# Check that the partial function is created with the correct arguments
MockRunnableLambda.assert_called()
expected_message = {"messages": [AIMessage(content="It is sunny.")], "metadata": genie_response}

assert result == expected_message

# Test the case when genie_response is empty
genie_empty_response = GenieResponse(result=None)

mock_ask_question.return_value = genie_empty_response

result = agent_with_metadata._call_genie_api(input_data)

expected_message = {"messages": [AIMessage(content="")], "metadata": None}

assert result == expected_message


@patch("databricks_ai_bridge.genie.Genie.ask_question")
def test_query_genie_as_agent_invoke(mock_ask_question, agent):
    """``invoke`` on an agent without metadata returns only the ``messages`` key."""
    # Stub the Genie API call with a fully populated response.
    mock_ask_question.return_value = GenieResponse(
        result="It is sunny.",
        query="select a from data_table",
        description="description",
    )
    state = {"messages": [{"role": "user", "content": "What is the weather?"}]}

    outcome = agent.invoke(state)

    assert outcome == {"messages": [AIMessage(content="It is sunny.")]}