Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions examples/healthcare/mock_patient_taskgroup_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from dataclasses import dataclass
from typing import Literal

from livekit.agents import Agent, AgentTask
from livekit.agents.beta.workflows import TaskGroup
from livekit.agents.llm import function_tool


@dataclass
class UserData:
verified_intent: str | None = None
full_name: str | None = None
date_of_birth: str | None = None
preferred_date_time: str | None = None
status: str | None = None


class VerifyIntentTask(AgentTask[str]):
def __init__(self) -> None:
super().__init__(
instructions=("Verify the user's intent to schedule a patient appointment.")
)

async def on_enter(self) -> None:
await self.session.generate_reply(
instructions=(
"Ask user if they want to schedule an appointment. That said, do not say anything more. Just one brief sentence is enough."
),
tool_choice="none",
)

@function_tool()
async def verify_intent(self, intent: Literal["schedule"]) -> None:
self.session.userdata.verified_intent = intent
self.complete(intent)


class IdentifyPatientTask(AgentTask[dict]):
def __init__(self) -> None:
super().__init__(
instructions=(
"Ask for full name and date of birth, then call identify_patient immediately.",
),
)

async def on_enter(self) -> None:
await self.session.generate_reply(
instructions=(
"Ask for full name and date of birth, then call identify_patient immediately."
),
tool_choice="none",
)

@function_tool()
async def identify_patient(self, full_name: str, date_of_birth: str) -> None:
self.session.userdata.full_name = full_name
self.session.userdata.date_of_birth = date_of_birth
self.complete({"full_name": full_name, "date_of_birth": date_of_birth})


class SchedulePatientVisitTask(AgentTask[dict]):
def __init__(self) -> None:
super().__init__(
instructions=("Ask for a preferred date and time, then call schedule_patient_visit.")
)

async def on_enter(self) -> None:
await self.session.generate_reply(
instructions=("Ask for a preferred date and time, then call schedule_patient_visit."),
tool_choice="none",
)

@function_tool()
async def schedule_patient_visit(self, preferred_date_time: str) -> None:
self.session.userdata.preferred_date_time = preferred_date_time
self.session.userdata.status = "scheduled"
self.complete(
{
"preferred_date_time": preferred_date_time,
"status": "scheduled",
}
)


class MockPatientTaskGroupAgent(Agent):
def __init__(self) -> None:
super().__init__(
instructions=(
"You are a concise healthcare assistant. Help the user schedule a patient appointment."
)
)

async def on_enter(self) -> None:
task_group = TaskGroup(summarize_chat_ctx=True)
task_group.add(
lambda: VerifyIntentTask(),
id="verify_intent_task",
description="confirm user scheduling intent",
)
task_group.add(
lambda: IdentifyPatientTask(),
id="identify_patient_task",
description="collect patient identity details",
)
task_group.add(
lambda: SchedulePatientVisitTask(),
id="schedule_patient_visit_task",
description="schedule patient visit",
)
await task_group
104 changes: 104 additions & 0 deletions examples/healthcare/test_mock_patient_taskgroup_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import asyncio
import os
from pprint import pformat

import pytest

# Enable RunResult debug output in `run_result.py`.
os.environ.setdefault("LIVEKIT_EVALS_VERBOSE", "1")

from livekit.agents import AgentSession, inference, llm

from .mock_patient_taskgroup_agent import MockPatientTaskGroupAgent, UserData


def _pretty_events(events: list[object]) -> str:
rows: list[dict[str, object]] = []
for idx, event in enumerate(events):
row: dict[str, object] = {"index": idx, "type": getattr(event, "type", "<unknown>")}

item = getattr(event, "item", None)
if item is not None and hasattr(item, "model_dump"):
row["item"] = item.model_dump(
exclude_none=True,
exclude_defaults=True,
exclude={"id", "call_id", "created_at"},
)

old_agent = getattr(event, "old_agent", None)
new_agent = getattr(event, "new_agent", None)
if old_agent is not None or new_agent is not None:
row["handoff"] = {
"old_agent": type(old_agent).__name__ if old_agent is not None else None,
"new_agent": type(new_agent).__name__ if new_agent is not None else None,
}

rows.append(row)

return pformat(rows, width=100, compact=False, sort_dicts=False)


def _llm_model() -> llm.LLM:
return inference.LLM(
model="openai/gpt-4.1",
extra_kwargs={"temperature": 0.2},
)


@pytest.mark.asyncio
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY is required")
async def test_taskgroup_handoff_visibility_with_delay() -> None:
async def run_first_turn(
*, startup_delay: float, delay_after_wait: float
) -> tuple[list[str], list[object]]:
async with _llm_model() as model, AgentSession(llm=model, userdata=UserData()) as sess:
await sess.start(MockPatientTaskGroupAgent())
# Optionally give on_enter TaskGroup time to start first task.
if startup_delay > 0:
await asyncio.sleep(startup_delay)

result = await sess.run(user_input="Yes, please. Also who are you?")
if delay_after_wait > 0:
await asyncio.sleep(delay_after_wait)

# Accessing `expect` triggers LIVEKIT_EVALS_VERBOSE debug output.

event_types = [event.type for event in result.events]
print(
f"\nPython run #1 events (startup_delay={startup_delay}s, delay_after_wait={delay_after_wait}s):"
)
print(_pretty_events(list(result.events)))

# The first run may either call a tool immediately or ask a follow-up question first.
if "function_call" in event_types:
result.expect.contains_function_call(name="verify_intent")
await asyncio.sleep(2)

try:
result2 = await sess.run(
user_input="My name is Alice Johnson and my date of birth is 1991-04-11."
)
print(
f"\nPython run #2 events (startup_delay={startup_delay}s, delay_after_wait={delay_after_wait}s):"
)
print(_pretty_events(list(result2.events)))
except RuntimeError as e:
print(
f"\nPython run #2 raised RuntimeError "
f"(startup_delay={startup_delay}s, delay_after_wait={delay_after_wait}s): {e}"
)

return event_types, list(result.events)

no_delay_types, no_delay_events = await run_first_turn(startup_delay=1.0, delay_after_wait=0.0)
delayed_types, delayed_events = await run_first_turn(startup_delay=1.0, delay_after_wait=1.0)
cold_start_types, cold_start_events = await run_first_turn(
startup_delay=0.0, delay_after_wait=0.0
)

assert len(no_delay_events) > 0
assert len(delayed_events) > 0
assert len(cold_start_events) > 0
assert len(cold_start_types) > 0
assert len(no_delay_types) > 0
assert len(delayed_types) > 0
6 changes: 5 additions & 1 deletion livekit-agents/livekit/agents/beta/workflows/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ def __init__(
return_exceptions (bool): Whether or not to directly propagate an error. When set to True, the exception is added to the results dictionary and the sequence continues. Defaults to False.
on_task_completed (Callable[]): A callable that executes upon each task completion. The callback takes in a single argument of a TaskCompletedEvent.
"""
super().__init__(instructions="*empty*", chat_ctx=chat_ctx, llm=None)
super().__init__(
instructions="You are TaskGroup, orchestrate a sequence of multiple AgentTasks.",
chat_ctx=chat_ctx,
llm=NOT_GIVEN,
)

self._summarize_chat_ctx = summarize_chat_ctx
self._return_exceptions = return_exceptions
Expand Down
Loading