Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# OpenAI-compatible LLM (using Gemini via OpenAI API)
OPENAI_BASE_URL="https://generativelanguage.googleapis.com/v1beta/openai/"
OPENAI_API_KEY="..." # Or use GEMINI_API_KEY or GOOGLE_API_KEY
GEMINI_API_KEY="..." # Or use GOOGLE_API_KEY

# Model selection (see https://ai.google.dev/gemini-api/docs/models)
# Stable: gemini-2.5-pro, gemini-2.5-flash, gemini-2.5-flash-lite
Expand Down
42 changes: 15 additions & 27 deletions aieng-eval-agents/aieng/agent_evals/async_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
like OpenAI to prevent event loop conflicts during Gradio's hot-reload process.
"""

import logging
import sqlite3
from pathlib import Path
from typing import Any

from aieng.agent_evals.configs import Configs
from langfuse import Langfuse
from openai import AsyncOpenAI


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


class SQLiteConnection:
Expand All @@ -27,7 +31,7 @@ def __init__(self, db_path: Path) -> None:
self.db_path = db_path
self.connection = sqlite3.connect(db_path)

def execute(self, query: str) -> list[Any]:
def execute(self, query: str) -> list[Any] | str:
"""Execute a SQLite query.

Parameters
Expand All @@ -37,11 +41,16 @@ def execute(self, query: str) -> list[Any]:

Returns
-------
list[Any]
list[Any] | str
The result of the query. Will return the result of
`execute(query).fetchall()`.
Returns a string with an error message if the query fails.
"""
return self.connection.execute(query).fetchall()
try:
return self.connection.execute(query).fetchall()
except Exception as e:
logger.exception(f"Error executing query: {e}")
return [str(e)]

def close(self) -> None:
"""Close the SQLite connection."""
Expand All @@ -63,7 +72,7 @@ class AsyncClientManager:
--------
>>> manager = AsyncClientManager()
>>> # Access clients (created on first access)
>>> openai = manager.openai_client
>>> sqlite_connection = manager.sqlite_connection("my_sqlite.db")
>>> langfuse = manager.langfuse_client
>>> # In finally block or cleanup
>>> await manager.close()
Expand Down Expand Up @@ -94,7 +103,6 @@ def __init__(self, configs: Configs | None = None) -> None:
is created.
"""
self._configs: Configs | None = configs
self._openai_client: AsyncOpenAI | None = None
self._sqlite_connection: SQLiteConnection | None = None
self._langfuse_client: Langfuse | None = None
self._otel_instrumented: bool = False
Expand All @@ -113,22 +121,6 @@ def configs(self) -> Configs:
self._configs = Configs() # type: ignore[call-arg]
return self._configs

@property
def openai_client(self) -> AsyncOpenAI:
"""Get or create OpenAI client.

Returns
-------
AsyncOpenAI
The OpenAI async client instance.
"""
if self._openai_client is None:
api_key = self.configs.openai_api_key.get_secret_value()

self._openai_client = AsyncOpenAI(api_key=api_key, base_url=self.configs.openai_base_url)
self._initialized = True
return self._openai_client

def sqlite_connection(self, db_path: Path) -> SQLiteConnection:
"""Get or create SQLite session.

Expand Down Expand Up @@ -192,13 +184,9 @@ def otel_instrumented(self, value: bool) -> None:
async def close(self) -> None:
"""Close all initialized async clients.

This method closes the OpenAI client, SQLite connection, and Langfuse
This method closes the SQLite connection, and Langfuse
client if they have been initialized.
"""
if self._openai_client is not None:
await self._openai_client.close()
self._openai_client = None

if self._sqlite_connection is not None:
self._sqlite_connection.close()
self._sqlite_connection = None
Expand Down
10 changes: 3 additions & 7 deletions aieng-eval-agents/aieng/agent_evals/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,9 @@ class Configs(BaseSettings):
)

# === Core LLM Settings ===
openai_base_url: str = Field(
default="https://generativelanguage.googleapis.com/v1beta/openai/",
description="Base URL for OpenAI-compatible API (defaults to Gemini endpoint).",
)
openai_api_key: SecretStr = Field(
validation_alias=AliasChoices("OPENAI_API_KEY", "GEMINI_API_KEY", "GOOGLE_API_KEY"),
description="API key for OpenAI-compatible API (accepts OPENAI_API_KEY, GEMINI_API_KEY, or GOOGLE_API_KEY).",
gemini_api_key: SecretStr = Field(
validation_alias=AliasChoices("GEMINI_API_KEY", "GOOGLE_API_KEY"),
description="API key for Google/Gemini API (accepts GEMINI_API_KEY, or GOOGLE_API_KEY).",
)
default_planner_model: str = Field(
default="gemini-2.5-pro",
Expand Down
16 changes: 0 additions & 16 deletions aieng-eval-agents/aieng/agent_evals/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from pathlib import Path
from typing import Any, Literal

import logfire
import nest_asyncio
from aieng.agent_evals.async_client_manager import AsyncClientManager
from aieng.agent_evals.configs import Configs
from aieng.agent_evals.progress import track_with_progress
Expand All @@ -24,19 +22,6 @@
logger = logging.getLogger(__name__)


def configure_oai_agents_sdk(service_name: str) -> None:
"""Register Langfuse as tracing provider for OAI Agents SDK.

Parameters
----------
service_name : str
The name of the service to configure.
"""
nest_asyncio.apply()
logfire.configure(service_name=service_name, send_to_logfire=False, scrubbing=False)
logfire.instrument_openai_agents()


def set_up_langfuse_otlp_env_vars():
"""Set up environment variables for Langfuse OpenTelemetry integration.

Expand Down Expand Up @@ -71,7 +56,6 @@ def setup_langfuse_tracer(service_name: str = "aieng-eval-agents") -> "trace.Tra
tracer: OpenTelemetry Tracer
"""
set_up_langfuse_otlp_env_vars()
configure_oai_agents_sdk(service_name)

# Create a TracerProvider for OpenTelemetry
trace_provider = TracerProvider()
Expand Down
183 changes: 158 additions & 25 deletions aieng-eval-agents/aieng/agent_evals/report_generation/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,29 @@
>>> )
"""

import logging
from enum import Enum
from pathlib import Path
from typing import Any

import agents
from aieng.agent_evals.async_client_manager import AsyncClientManager
from aieng.agent_evals.langfuse import setup_langfuse_tracer
from aieng.agent_evals.report_generation.file_writer import ReportFileWriter
from google.adk.agents import Agent
from google.adk.events.event import Event
from pydantic import BaseModel


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


def get_report_generation_agent(
instructions: str,
sqlite_db_path: Path,
reports_output_path: Path,
langfuse_project_name: str | None,
) -> agents.Agent:
) -> Agent:
"""
Define the report generation agent.

Expand Down Expand Up @@ -54,29 +63,153 @@ def get_report_generation_agent(
client_manager = AsyncClientManager.get_instance()
report_file_writer = ReportFileWriter(reports_output_path)

# Define an agent using the OpenAI Agent SDK
return agents.Agent(
name="Report Generation Agent", # Agent name for logging and debugging purposes
instructions=instructions, # System instructions for the agent
# Tools available to the agent
# We wrap the `execute_sql_query` and `write_report_to_file` methods
# with `function_tool`, which will construct the tool definition JSON
# schema by extracting the necessary information from the method
# signature and docstring.
# Define an agent using Google ADK
return Agent(
name="ReportGenerationAgent",
model=client_manager.configs.default_worker_model,
instruction=instructions,
tools=[
agents.function_tool(
client_manager.sqlite_connection(sqlite_db_path).execute,
name_override="execute_sql_query",
description_override="Execute a SQL query against the SQLite database.",
),
agents.function_tool(
report_file_writer.write,
name_override="write_report_to_file",
description_override="Write the report data to a downloadable XLSX file.",
),
client_manager.sqlite_connection(sqlite_db_path).execute,
report_file_writer.write,
],
model=agents.OpenAIChatCompletionsModel(
model=client_manager.configs.default_worker_model,
openai_client=client_manager.openai_client,
),
)


class EventType(Enum):
    """Types of events from agents."""

    # Final answer emitted by the agent at the end of a run.
    FINAL_RESPONSE = "final_response"
    # The model invoking a tool (function call).
    TOOL_CALL = "tool_call"
    # Intermediate reasoning/plain text produced by the model.
    THOUGHT = "thought"
    # The payload returned by a previously invoked tool.
    TOOL_RESPONSE = "tool_response"


class ParsedEvent(BaseModel):
    """Parsed event from an agent."""

    # Which kind of event this record represents.
    type: EventType
    # Main payload: the response/thought text, or — for tool calls and
    # tool responses — the tool's name.
    text: str
    # Tool-call arguments or tool-response payload; None for text-only events.
    arguments: Any | None = None


class EventParser:
    """Parser for agent events.

    Converts raw Google ADK ``Event`` objects into a flat list of
    ``ParsedEvent`` records (final responses, thoughts, tool calls,
    and tool responses). Malformed or unrecognized events are logged
    and skipped rather than raising.
    """

    @classmethod
    def parse(cls, event: Event) -> list[ParsedEvent]:
        """Parse an agent event into a list of parsed events.

        The event can be a final response, a thought, a tool call,
        or a tool response.

        Parameters
        ----------
        event : Event
            The event to parse.

        Returns
        -------
        list[ParsedEvent]
            A list of parsed events. Empty when the event is malformed
            or not recognized.
        """
        parsed_events: list[ParsedEvent] = []

        if event.is_final_response():
            parsed_events.extend(cls._parse_final_response(event))

        elif event.content:
            if event.content.role == "model":
                parsed_events.extend(cls._parse_model_response(event))

            elif event.content.role == "user":
                parsed_events.extend(cls._parse_user_response(event))

            else:
                # Content is present but carries an unexpected role.
                logger.warning(f"Unknown stream event: {event}")

        return parsed_events

    @classmethod
    def _parse_final_response(cls, event: Event) -> list[ParsedEvent]:
        """Extract the final response text; returns [] if the event is malformed."""
        # `not event.content.parts` already covers the empty-list case,
        # so the original's extra `len(...) == 0` check was redundant.
        if not event.content or not event.content.parts or not event.content.parts[0].text:
            logger.warning(f"Final response's content is not valid: {event}")
            return []

        return [
            ParsedEvent(
                type=EventType.FINAL_RESPONSE,
                text=event.content.parts[0].text,
            )
        ]

    @classmethod
    def _parse_model_response(cls, event: Event) -> list[ParsedEvent]:
        """Extract tool calls and thoughts from a model-role event."""
        if not event.content or not event.content.parts:
            logger.warning(f"Model response's content is not valid: {event}")
            return []

        parsed_events: list[ParsedEvent] = []

        for part in event.content.parts:
            # Parsing tool calls and their arguments.
            if part.function_call:
                if not part.function_call.name:
                    logger.warning(f"No name in function call: {part}")
                    continue

                parsed_events.append(
                    ParsedEvent(
                        type=EventType.TOOL_CALL,
                        text=part.function_call.name,
                        arguments=part.function_call.args,
                    )
                )

            # Parsing the agent's thoughts. The original condition
            # `thought_signature or (text and not thought_signature)` is
            # logically equivalent to `thought_signature or text`.
            elif part.thought_signature or part.text:
                if not part.text:
                    # A thought signature without text has nothing to surface.
                    logger.warning(f"No text in part: {part}")
                    continue

                parsed_events.append(
                    ParsedEvent(
                        type=EventType.THOUGHT,
                        text=part.text,
                    )
                )

            else:
                logger.warning(f"Unknown part type: {part}")

        return parsed_events

    @classmethod
    def _parse_user_response(cls, event: Event) -> list[ParsedEvent]:
        """Extract tool responses from a user-role event."""
        if not event.content or not event.content.parts:
            # Fixed copy-pasted log message: this branch handles *user* events,
            # not model events.
            logger.warning(f"User response's content is not valid: {event}")
            return []

        parsed_events: list[ParsedEvent] = []

        for part in event.content.parts:
            if part.function_response:
                if not part.function_response.name:
                    logger.warning(f"No name in function response: {part}")
                    continue

                parsed_events.append(
                    ParsedEvent(
                        type=EventType.TOOL_RESPONSE,
                        text=part.function_response.name,
                        arguments=part.function_response.response,
                    )
                )
            else:
                logger.warning(f"Unknown part type: {part}")

        return parsed_events
Loading