Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 25 additions & 55 deletions examples/python/a2a_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# /// script
# dependencies = [
# "a2a-sdk[http-server]",
# "mlflow",
# "openai",
# "uvicorn",
# ]
Expand All @@ -13,13 +14,12 @@

import os
import uuid
import json
import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
from fastapi import FastAPI, Request, HTTPException, APIRouter
from fastapi.responses import JSONResponse
from openai import OpenAI
import mlflow
import mlflow.openai
from a2a.server.agent_execution.agent_executor import AgentExecutor
from a2a.server.agent_execution.context import RequestContext
from a2a.server.events.event_queue import EventQueue
Expand All @@ -32,6 +32,28 @@
# Initialize OpenAI client
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# MLflow tracing configuration
MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "timestep-a2a")
MLFLOW_TRACING_ENABLED = os.getenv("MLFLOW_TRACING_ENABLED", "true").lower() in {"1", "true", "yes"}
_MLFLOW_TRACING_CONFIGURED = False


def setup_mlflow_tracing() -> None:
    """Enable MLflow autologging for OpenAI calls.

    Idempotent and env-gated: does nothing when tracing is disabled via
    MLFLOW_TRACING_ENABLED, or when a previous call already ran the setup.
    """
    global _MLFLOW_TRACING_CONFIGURED
    if not MLFLOW_TRACING_ENABLED:
        return
    if _MLFLOW_TRACING_CONFIGURED:
        return

    # Point at an explicit tracking server only when one is configured;
    # otherwise MLflow falls back to its default local store.
    uri = os.getenv("MLFLOW_TRACKING_URI")
    if uri:
        mlflow.set_tracking_uri(uri)
    mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
    mlflow.openai.autolog()
    _MLFLOW_TRACING_CONFIGURED = True


setup_mlflow_tracing()

# Agent IDs
PERSONAL_ASSISTANT_ID = "00000000-0000-0000-0000-000000000000"
WEATHER_ASSISTANT_ID = "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF"
Expand Down Expand Up @@ -148,35 +170,6 @@ def build_system_message(agent_id: str, tools: List[Dict[str, Any]]) -> str:
# Track all task IDs per agent for listing
agent_task_ids: Dict[str, List[str]] = {}

def write_trace(
    task_id: str,
    agent_id: str,
    input_messages: List[Dict],
    input_tools: List[Dict],
    output_message: Dict,
    traces_dir: Path = Path("/workspace/traces"),
) -> None:
    """Persist one model interaction as a JSON file under *traces_dir*.

    Args:
        task_id: A2A task id; an empty string is recorded as "unknown" in the
            filename (the full value is still stored in the JSON body).
        agent_id: Id of the agent that produced the response; same fallback.
        input_messages: Chat messages that were sent to the model.
        input_tools: Tool schemas that were sent to the model.
        output_message: Mapping with optional "content" and "tool_calls" keys;
            missing keys default to "" / [].
        traces_dir: Destination directory, created if missing. Defaults to the
            container path used by the original script.
    """
    # parents=True so this also works when the parent directory
    # (e.g. /workspace) does not exist yet.
    traces_dir.mkdir(parents=True, exist_ok=True)

    # Colons are not filesystem-safe on all platforms, so strip them from the
    # ISO timestamp. NOTE(review): timestamp is naive local time — confirm
    # whether UTC is expected downstream.
    timestamp = datetime.datetime.now().isoformat().replace(":", "-")
    # Short (8-char) ids keep filenames readable while staying unique enough.
    task_id_short = task_id[:8] if task_id else "unknown"
    agent_id_short = agent_id[:8] if agent_id else "unknown"
    trace_file = traces_dir / f"{timestamp}_{task_id_short}_{agent_id_short}.json"

    trace = {
        "task_id": task_id,
        "agent_id": agent_id,
        "timestamp": timestamp,
        "input": {
            "messages": input_messages,
            "tools": input_tools,
        },
        "output": {
            "content": output_message.get("content", ""),
            "tool_calls": output_message.get("tool_calls", []),
        },
    }

    with open(trace_file, "w") as f:
        json.dump(trace, f, indent=2)


class MultiAgentExecutor(AgentExecutor):
"""Agent executor that uses OpenAI directly and configures tools based on agent_id."""

Expand Down Expand Up @@ -291,29 +284,6 @@ async def execute(
# Convert OpenAI response to A2A format
assistant_content = assistant_message.content or ""

# Capture trace: input messages + output message
output_message_dict = {
"content": assistant_content,
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
] if tool_calls else [],
}
write_trace(
task_id=task_id or "",
agent_id=self.agent_id,
input_messages=openai_messages_with_system,
input_tools=self.tools or [],
output_message=output_message_dict,
)

# Build A2A message using helper function
# Role.agent is the correct role for assistant messages in A2A
a2a_message = create_text_message_object(
Expand Down
3 changes: 3 additions & 0 deletions examples/python/compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ services:
- "8000:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- MLFLOW_TRACKING_URI=${MLFLOW_TRACKING_URI}
- MLFLOW_EXPERIMENT_NAME=${MLFLOW_EXPERIMENT_NAME}
- MLFLOW_TRACING_ENABLED=${MLFLOW_TRACING_ENABLED}
- UV_CACHE_DIR=/workspace/.cache/uv
develop:
watch:
Expand Down
87 changes: 82 additions & 5 deletions examples/python/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# /// script
# dependencies = [
# "a2a-sdk",
# "mlflow",
# "pandas",
# "mcp",
# "httpx",
# ]
Expand All @@ -19,6 +21,8 @@
import datetime
from pathlib import Path
from typing import Dict, Any, List, Optional
import mlflow
import pandas as pd
from a2a.client import ClientFactory, ClientConfig
from a2a.client.helpers import create_text_message_object
from a2a.types import TransportProtocol, Role
Expand All @@ -35,6 +39,24 @@
PERSONAL_ASSISTANT_ID = "00000000-0000-0000-0000-000000000000"
WEATHER_ASSISTANT_ID = "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF"

# MLflow configuration
MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "timestep-evals")
MLFLOW_EVAL_ENABLED = os.getenv("MLFLOW_EVAL_ENABLED", "true").lower() in {"1", "true", "yes"}
_MLFLOW_CONFIGURED = False


def setup_mlflow() -> None:
    """Configure MLflow tracking for eval runs; safe to call repeatedly."""
    global _MLFLOW_CONFIGURED
    if _MLFLOW_CONFIGURED:
        return
    # Only override the tracking URI when the env var is set; otherwise
    # MLflow keeps its default backend.
    uri = os.getenv("MLFLOW_TRACKING_URI")
    if uri:
        mlflow.set_tracking_uri(uri)
    mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
    _MLFLOW_CONFIGURED = True


def write_task(task: Any, agent_id: str) -> None:
"""Write task to tasks/ folder in proper A2A Task format."""
Expand Down Expand Up @@ -135,6 +157,52 @@ def parse_tool_call(tool_call: Dict[str, Any]) -> tuple[Optional[str], Dict[str,
return tool_name, tool_args


def run_mlflow_eval(prompt: str, response: str, agent_id: str, task_id: Optional[str]) -> None:
    """Log one prompt/response pair to MLflow as an eval run.

    Best-effort: every failure is printed to stderr instead of raised, so
    evaluation problems never break the client loop.
    """
    if not MLFLOW_EVAL_ENABLED:
        return

    try:
        setup_mlflow()

        row = {
            "inputs": prompt,
            "predictions": response,
            "targets": "",
        }
        eval_df = pd.DataFrame([row])

        task_part = task_id[:8] if task_id else 'unknown'
        run_name = f"eval-{agent_id[:8]}-{task_part}"

        with mlflow.start_run(run_name=run_name):
            mlflow.set_tags({"a2a.agent_id": agent_id, "a2a.task_id": task_id or ""})
            mlflow.log_text(prompt, "prompt.txt")
            mlflow.log_text(response, "response.txt")
            mlflow.log_metric("response_length", float(len(response)))

            # The genai metric is optional extra credit: it may fail (e.g. if
            # it needs additional configuration), in which case we record the
            # error on the run and carry on.
            try:
                from mlflow.metrics.genai import relevance

                mlflow.evaluate(
                    data=eval_df,
                    model_type="question-answering",
                    targets="targets",
                    predictions="predictions",
                    extra_metrics=[relevance()],
                )
            except Exception as eval_error:
                mlflow.log_param("eval_error", str(eval_error))
                print(f"[MLflow eval skipped: {eval_error}]", file=sys.stderr)
    except Exception as e:
        print(f"[MLflow eval setup failed: {e}]", file=sys.stderr)


async def mcp_sampling_callback(
context: RequestContext["ClientSession", Any],
params: mcp_types.CreateMessageRequestParams,
Expand Down Expand Up @@ -303,8 +371,9 @@ async def run_client_loop(
message = create_text_message_object(role="user", content=initial_message)
print(f"\n[DEBUG: Starting to send message to A2A server]", file=sys.stderr)

async def process_with_output(a2a_client: Any, message_obj: Any, agent_id: str) -> None:
"""Process message stream and print output."""
async def process_with_output(a2a_client: Any, message_obj: Any, agent_id: str) -> str:
"""Process message stream, print output, and return final response."""
final_message = ""
async for event in a2a_client.send_message(message_obj):
task = extract_task_from_event(event)
print(f"\n[DEBUG: Received task, id={getattr(task, 'id', 'NO_ID')}, type={type(task)}]", file=sys.stderr)
Expand All @@ -326,6 +395,7 @@ async def process_with_output(a2a_client: Any, message_obj: Any, agent_id: str)

if task.status.state.value == "completed":
print("\n[Task completed]")
final_message = extract_final_message(task)
break

if task.status.state.value == "input-required":
Expand All @@ -349,10 +419,17 @@ async def process_with_output(a2a_client: Any, message_obj: Any, agent_id: str)
tool_result_msg.context_id = task.context_id

# Recursively process tool result
await process_with_output(a2a_client, tool_result_msg, agent_id)
result_message = await process_with_output(a2a_client, tool_result_msg, agent_id)
if result_message:
final_message = result_message
break

await process_with_output(a2a_client, message, agent_id)

return final_message.strip()

final_message = await process_with_output(a2a_client, message, agent_id)
if final_message:
task_id = task_ids[-1] if task_ids else None
run_mlflow_eval(initial_message, final_message, agent_id, task_id)
except Exception as e:
print(f"\n[Error in client loop: {e}]")
raise
Expand Down