diff --git a/tests/agents/__init__.py b/tests/agents/__init__.py
new file mode 100644
index 0000000..eed62c1
--- /dev/null
+++ b/tests/agents/__init__.py
@@ -0,0 +1,3 @@
+# Copyright AGNTCY Contributors (https://github.com/agntcy)
+# SPDX-License-Identifier: Apache-2.0
+
diff --git a/tests/agents/jokeflow.py b/tests/agents/jokeflow.py
new file mode 100644
index 0000000..78961b8
--- /dev/null
+++ b/tests/agents/jokeflow.py
@@ -0,0 +1,43 @@
+# Copyright AGNTCY Contributors (https://github.com/agntcy)
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+
+from llama_index.core.workflow import (
+    Event,
+    StartEvent,
+    StopEvent,
+    Workflow,
+    step,
+)
+
+
+class JokeEvent(Event):
+    joke: str
+
+
+class JokeFlow(Workflow):
+    @step
+    async def generate_joke(self, ev: StartEvent) -> JokeEvent:
+        topic = ev.topic
+
+        response = "this is a joke about " + topic
+        await asyncio.sleep(1)
+        return JokeEvent(joke=str(response))
+
+    @step
+    async def critique_joke(self, ev: JokeEvent) -> StopEvent:
+        joke = ev.joke
+
+        response = "this is a critique of the joke: " + joke
+        return StopEvent(result=str(response))
+
+
+async def main():
+    w = JokeFlow(timeout=60, verbose=False)
+    result = await w.run(topic="pirates")
+    print(str(result))
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/tests/agents/jokeflow_manifest.json b/tests/agents/jokeflow_manifest.json
new file mode 100644
index 0000000..2b15830
--- /dev/null
+++ b/tests/agents/jokeflow_manifest.json
@@ -0,0 +1,70 @@
+{
+  "authors": ["Cisco Systems Inc."],
+  "annotations": {
+    "type": "llama-index"
+  },
+  "created_at": "2025-05-21T00:00:00Z",
+  "name": "org.agntcy.jokeflow",
+  "description": "create and review Jokes",
+  "version": "0.0.1",
+  "schema_version": "0.3.1",
+  "skills": [],
+  "locators": [],
+  "extensions": [{
+    "name": "oasf.agntcy.org/feature/runtime/manifest",
+    "data": {
+      "acp": {
+        "capabilities": {
+          "threads": false,
+          "interrupts": false,
+          "callbacks": false
+        },
+        "input": {
+          "properties": {
+            "topic": {
+              "description": "The topic of the Joke",
+              "title": "topic",
+              "type": "string"
+            }
+          },
+          "required": ["topic"],
+          "title": "StartEvent",
+          "type": "object"
+        },
+        "output": {
+          "properties": {
+            "joke": {
+              "description": "the created Joke",
+              "title": "joke",
+              "type": "string"
+            }
+          },
+          "required": ["joke"],
+          "title": "JokeEvent",
+          "type": "object"
+        },
+        "config": {
+          "properties": {},
+          "title": "ConfigSchema",
+          "type": "object"
+        }
+      },
+      "deployment": {
+        "deployment_options": [{
+          "type": "source_code",
+          "name": "source_code_local",
+          "url": "./../",
+          "framework_config": {
+            "framework_type": "llamaindex",
+            "path": "jokeflow:jokeflow_workflow"
+          }
+        }
+        ],
+        "env_vars": [],
+        "dependencies": []
+      }
+    },
+    "version": "v0.2.2"
+  }
+  ]
+}
diff --git a/tests/agents/jokereviewer.py b/tests/agents/jokereviewer.py
new file mode 100644
index 0000000..2353476
--- /dev/null
+++ b/tests/agents/jokereviewer.py
@@ -0,0 +1,111 @@
+# Copyright AGNTCY Contributors (https://github.com/agntcy)
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+
+from dotenv import find_dotenv, load_dotenv
+from llama_index.core.workflow import (
+    Context,
+    Event,
+    StartEvent,
+    StopEvent,
+    Workflow,
+    step,
+)
+
+load_dotenv(dotenv_path=find_dotenv(usecwd=True))
+
+
+class SecondEvent(Event):
+    joke: str
+    ai_answer: str
+
+
+class FirstInterruptEvent(Event):
+    joke: str
+    first_question: str
+    needs_answer: bool
+
+
+class InterruptResponseEvent(Event):
+    answer: str
+
+
+class JokeEvent(Event):
+    joke: str
+
+
+class JokeReviewer(Workflow):
+    @step
+    async def generate_joke(self, ev: StartEvent) -> JokeEvent:
+        topic = ev.topic
+        response = "this is a joke about " + topic
+        await asyncio.sleep(1)
+        return JokeEvent(joke=str(response))
+
+    @step
+    async def step_interrupt_one(self, ctx: Context, ev: JokeEvent) -> SecondEvent:
+        print(f"> step_interrupt_one input : {ev.joke}")
+        await asyncio.sleep(1)
+        # wait until we receive an InterruptResponseEvent
+        needs_answer = True
+        response = await ctx.wait_for_event(
+            InterruptResponseEvent,
+            waiter_id="waiter_step_interrupt_one",
+            waiter_event=FirstInterruptEvent(
+                joke=ev.joke,
+                first_question=f"What is your review about the Joke '{ev.joke}'?",
+                needs_answer=needs_answer,
+            ),
+        )
+
+        print(f"> received response : {response.answer}")
+        if needs_answer and not response.answer:
+            raise ValueError("This needs a non-empty answer")
+
+        return SecondEvent(
+            joke=ev.joke, ai_answer=f"Received human answer: {response.answer}"
+        )
+
+    @step
+    async def critique_joke(self, ev: SecondEvent) -> StopEvent:
+        joke = ev.joke
+
+        await asyncio.sleep(1)
+        response = "this is a review for the joke: " + joke + "\n" + ev.ai_answer
+        result = {
+            "joke": joke,
+            "review": str(response),
+        }
+        return StopEvent(result=result)
+
+
+def interrupt_workflow() -> JokeReviewer:
+    joke_reviewer = JokeReviewer(timeout=None, verbose=True)
+    return joke_reviewer
+
+
+async def main():
+    print("Joke Reviewer with Interrupts")
+    workflow = interrupt_workflow()
+
+    handler = workflow.run(topic="pirates")
+
+    print("Reading events from the workflow...")
+    async for ev in handler.stream_events():
+        if isinstance(ev, FirstInterruptEvent):
+            # capture keyboard input
+            response = input(ev.first_question)
+            print("Sending response event...")
+            handler.ctx.send_event(
+                InterruptResponseEvent(
+                    answer=response,
+                )
+            )
+
+    print("waiting for final result...")
+    final_result = await handler
+    print("Final result: ", final_result)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/tests/agents/jokereviewer_manifest.json b/tests/agents/jokereviewer_manifest.json
new file mode 100644
index 0000000..129f0ca
--- /dev/null
+++ b/tests/agents/jokereviewer_manifest.json
@@ -0,0 +1,129 @@
+{
+  "authors": ["Cisco Systems Inc."],
+  "annotations": {
+    "type": "llama-index"
+  },
+  "created_at": "2025-05-21T00:00:00Z",
+  "name": "org.agntcy.jokereviewer",
+  "description": "create and review Jokes",
+  "version": "0.0.1",
+  "schema_version": "0.3.1",
+  "skills": [],
+  "locators": [],
+  "extensions": [{
+    "name": "oasf.agntcy.org/feature/runtime/manifest",
+    "data": {
+      "acp": {
+        "input": {
+          "type": "object",
+          "properties": {
+            "input": {
+              "type": "string",
+              "description": "The human input"
+            }
+          }
+        },
+        "output": {
+          "type": "object",
+          "oneOf": [{
+            "properties": {
+              "human_answer": {
+                "type": "string"
+              }
+            },
+            "required": [
+              "human_answer"
+            ]
+          }, {
+            "properties": {
+              "ai_answer": {
+                "type": "string"
+              }
+            },
+            "required": [
+              "ai_answer"
+            ]
+          }
+          ]
+        },
+        "config": {
+          "type": "object",
+          "properties": {}
+        },
+        "capabilities": {
+          "threads": false,
+          "interrupts": true,
+          "callbacks": false
+        },
+        "interrupts": [{
+          "interrupt_type": "first_interrupt",
+          "interrupt_payload": {
+            "type": "object",
+            "title": "First interrupt",
+            "description": "First interrupt the agent is asking",
+            "properties": {
+              "joke": {
+                "title": "Joke",
+                "description": "The joke to be reviewed",
+                "type": "string"
+              },
+              "first_question": {
+                "title": "First question",
+                "description": "Natural language question that is going to be asked by this agent",
+                "type": "string"
+              },
+              "needs_answer": {
+                "title": "Whether the agent needs an answer for this question",
+                "description": "True if the agent needs a non-empty answer for this question, False otherwise",
+                "type": "boolean"
+              }
+            },
+            "required": [
+              "joke",
+              "first_question",
+              "needs_answer"
+            ]
+          },
+          "resume_payload": {
+            "type": "object",
+            "title": "First interrupt answer",
+            "description": "Answer to the first interrupt the agent asked",
+            "properties": {
+              "answer": {
+                "title": "Answer",
+                "description": "Text of the answer",
+                "type": "string"
+              }
+            },
+            "required": [
+              "answer"
+            ]
+          }
+        }
+        ]
+      },
+      "dependencies": [],
+      "deployment": {
+        "dependencies": [],
+        "deployment_options": [{
+          "type": "source_code",
+          "name": "source_code_local",
+          "url": "file://.",
+          "framework_config": {
+            "framework_type": "llamaindex",
+            "path": "JokeReviewer",
+            "interrupts": {
+              "first_interrupt": {
+                "interrupt_ref": "tests.agents.jokereviewer:FirstInterruptEvent",
+                "resume_ref": "tests.agents.jokereviewer:InterruptResponseEvent"
+              }
+            }
+          }
+        }
+        ]
+      }
+    },
+    "version": "v0.2.2"
+  }
+  ]
+}
\ No newline at end of file
diff --git a/tests/test_llamaindex_adapter.py b/tests/test_llamaindex_adapter.py
new file mode 100644
index 0000000..c7247aa
--- /dev/null
+++ b/tests/test_llamaindex_adapter.py
@@ -0,0 +1,321 @@
+# Copyright AGNTCY Contributors (https://github.com/agntcy)
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import logging
+from uuid import uuid4
+
+import pytest
+
+from agent_workflow_server.agents.adapters.llamaindex import (
+    LlamaIndexAdapter,
+    LlamaIndexAgent,
+)
+from agent_workflow_server.storage.models import Interrupt, Run
+from agent_workflow_server.storage.storage import DB
+from tests.agents.jokeflow import JokeFlow
+from tests.agents.jokereviewer import JokeReviewer
+from tests.tools import _read_manifest
+
+logger = logging.getLogger(__name__)
+
+
+MOCK_AGENT_ID = "3f1e2549-5799-4321-91ae-2a4881d55526"
+MOCK_RUN_INPUT = {"topic": "pirates"}
+MOCK_RUN_OUTPUT = {
+    "result": "this is a critique of the joke: this is a joke about pirates"
+}
+EVENTS_CONTENT_KEYS = [
+    "accepted_events",
+    "broker_log",
+    "event_buffers",
+    "globals",
+    "in_progress",
+    "is_running",
+    "queues",
+    "stepwise",
+    "streaming_queue",
+]
+
+
+@pytest.mark.asyncio
+async def test_llamaindex_astream():
+    expected = [{}, "this is a critique of the joke: this is a joke about pirates"]
+
+    w = JokeFlow(timeout=60, verbose=False)
+    manifest_path = "tests/agents/jokeflow_manifest.json"
+    manifest = _read_manifest(manifest_path)
+    assert manifest
+
+    adapter = LlamaIndexAdapter()
+    agent = adapter.load_agent(
+        agent=w,
+        manifest=manifest.extensions[0].data.deployment,
+        set_thread_persistance_flag=DB.set_persist_threads,
+    )
+    assert isinstance(agent, LlamaIndexAgent)
+
+    thread_id = str(uuid4())
+    new_run = Run(
+        agent_id=MOCK_AGENT_ID,
+        input=MOCK_RUN_INPUT,
+        thread_id=thread_id,
+        config=None,
+    )
+
+    events = []
+    async for result in agent.astream(new_run):
+        events.append(result.data)
+
+    assert len(events) == len(expected), (
+        "Unexpected number of events generated during the run"
+    )
+    for i, event in enumerate(events):
+        assert event == expected[i], (
+            f"Event {i} does not match expected value: {event} != {expected[i]}"
+        )
+
+
+@pytest.mark.asyncio
+async def test_llamaindex_get_agent_state():
+    w = JokeFlow(timeout=60, verbose=False)
+    manifest_path = "tests/agents/jokeflow_manifest.json"
+    manifest = _read_manifest(manifest_path)
+    assert manifest
+
+    adapter = LlamaIndexAdapter()
+    agent = adapter.load_agent(
+        agent=w,
+        manifest=manifest.extensions[0].data.deployment,
+        set_thread_persistance_flag=DB.set_persist_threads,
+    )
+    assert isinstance(agent, LlamaIndexAgent)
+
+    thread_id = str(uuid4())
+    new_run = Run(
+        agent_id=MOCK_AGENT_ID,
+        input=MOCK_RUN_INPUT,
+        thread_id=thread_id,
+        config=None,
+    )
+
+    checkpoints = []
+    async for _ in agent.astream(new_run):
+        state = await agent.get_agent_state(thread_id)
+        assert state is not None, "Agent state should not be None"
+        assert isinstance(state, dict), "Agent state should be a dictionary"
+        assert "checkpoint_id" in state, "Checkpoint ID should be present in the state"
+        assert all(key in state["values"] for key in EVENTS_CONTENT_KEYS), (
+            "State values should contain expected keys"
+        )
+        assert state["checkpoint_id"] not in checkpoints, (
+            "Checkpoint ID should be unique"
+        )
+        checkpoints.append(state["checkpoint_id"])
+
+
+@pytest.mark.asyncio
+async def test_llamaindex_get_history():
+    w = JokeFlow(timeout=60, verbose=False)
+    manifest_path = "tests/agents/jokeflow_manifest.json"
+    manifest = _read_manifest(manifest_path)
+    assert manifest
+
+    adapter = LlamaIndexAdapter()
+    agent = adapter.load_agent(
+        agent=w,
+        manifest=manifest.extensions[0].data.deployment,
+        set_thread_persistance_flag=DB.set_persist_threads,
+    )
+    assert isinstance(agent, LlamaIndexAgent)
+
+    thread_id = str(uuid4())
+    new_run = Run(
+        agent_id=MOCK_AGENT_ID,
+        input=MOCK_RUN_INPUT,
+        thread_id=thread_id,
+        config=None,
+    )
+
+    events = []
+    async for result in agent.astream(new_run):
+        events.append(result.data)
+
+    history = await agent.get_history(thread_id, None, None)
+    assert history is not None, "History should not be None"
+    assert isinstance(history, list), "History should be a list"
+    assert len(events) == len(history)
+
+    checkpoints = []
+    for item in history:
+        print(json.dumps(item, indent=4, sort_keys=True))
+        assert isinstance(item, dict), "History item should be a dictionary"
+        assert "checkpoint_id" in item, "Checkpoint ID should be present"
+        assert all(key in item["values"] for key in EVENTS_CONTENT_KEYS), (
+            "History values should contain expected keys"
+        )
+        assert item["checkpoint_id"] not in checkpoints, (
+            "Checkpoint ID should be unique"
+        )
+        checkpoints.append(item["checkpoint_id"])
+
+
+@pytest.mark.asyncio
+async def test_llamaindex_update_agent_state():
+    expected = [{}, "this is a critique of the joke: this is a joke about pirates"]
+
+    w = JokeFlow(timeout=60, verbose=False)
+    manifest_path = "tests/agents/jokeflow_manifest.json"
+    manifest = _read_manifest(manifest_path)
+    assert manifest
+
+    adapter = LlamaIndexAdapter()
+    agent = adapter.load_agent(
+        agent=w,
+        manifest=manifest.extensions[0].data.deployment,
+        set_thread_persistance_flag=DB.set_persist_threads,
+    )
+    assert isinstance(agent, LlamaIndexAgent)
+
+    thread_id = str(uuid4())
+    new_run = Run(
+        agent_id=MOCK_AGENT_ID,
+        input=MOCK_RUN_INPUT,
+        thread_id=thread_id,
+        config=None,
+    )
+
+    events = []
+    async for result in agent.astream(new_run):
+        events.append(result.data)
+    assert len(events) == len(expected), (
+        "Unexpected number of events generated during the run"
+    )
+    for i, event in enumerate(events):
+        assert event == expected[i], (
+            f"Event {i} does not match expected value: {event} != {expected[i]}"
+        )
+
+    history = await agent.get_history(thread_id, None, None)
+    assert history is not None
+    assert isinstance(history, list)
+    assert len(history) == 2
+
+    checkpoint_id = str(uuid4())
+    new_checkpoint = {
+        "checkpoint_id": checkpoint_id,
+        "values": {
+            "accepted_events": [
+                ["generate_joke", "StartEvent"],
+                ["critique_joke", "JokeEvent"],
+            ],
+            "broker_log": [
+                '{"__is_pydantic": true, "value": {"_data": {"topic": "duck"}}, "qualified_name": "llama_index.core.workflow.events.StartEvent"}',
+                '{"__is_pydantic": true, "value": {"joke": "this is a joke about duck"}, "qualified_name": "tests.agents.jokeflow.JokeEvent"}',
+                '{"__is_pydantic": true, "value": {}, "qualified_name": "llama_index.core.workflow.events.StopEvent"}',
+            ],
+            "event_buffers": {},
+            "globals": {},
+            "in_progress": {"_done": [], "critique_joke": [], "generate_joke": []},
+            "is_running": False,
+            "queues": {"_done": "[]", "critique_joke": "[]", "generate_joke": "[]"},
+            "stepwise": False,
+            "streaming_queue": "[]",
+        },
+    }
+    await agent.update_agent_state(thread_id, new_checkpoint)
+
+    history = await agent.get_history(thread_id, None, None)
+    assert history is not None
+    assert isinstance(history, list)
+    assert len(history) == 3
+    print(json.dumps(history[2], indent=4, sort_keys=True))
+    new_checkpoint["values"]["in_progress"] = {}  # the stored state has no in-progress steps
+    assert history[2]["values"] == new_checkpoint["values"], (
+        "Updated checkpoint does not match the expected values"
+    )
+
+
+@pytest.mark.asyncio
+async def test_llamaindex_astream_with_interrupt():
+    expected_run_1 = [
+        {
+            "joke": "this is a joke about pirates",
+            "first_question": "What is your review about the Joke 'this is a joke about pirates'?",
+            "needs_answer": True,
+        }
+    ]
+    expected_run_2 = [
+        {},
+        {
+            "joke": "this is a joke about pirates",
+            "review": "this is a review for the joke: this is a joke about pirates\nReceived human answer: user_value",
+        },
+    ]
+    w = JokeReviewer(timeout=60, verbose=False)
+
+    manifest_path = "tests/agents/jokereviewer_manifest.json"
+    manifest = _read_manifest(manifest_path)
+    assert manifest
+
+    adapter = LlamaIndexAdapter()
+    agent = adapter.load_agent(
+        agent=w,
+        manifest=manifest.extensions[0].data.deployment,
+        set_thread_persistance_flag=DB.set_persist_threads,
+    )
+    assert isinstance(agent, LlamaIndexAgent)
+
+    thread_id = str(uuid4())
+    new_run = Run(
+        agent_id=MOCK_AGENT_ID,
+        input=MOCK_RUN_INPUT,
+        thread_id=thread_id,
+        config=None,
+        interrupt=Interrupt(
+            event="interrupt_event",
+            name="first_interrupt",
+            ai_data={"key": "value"},
+        ),
+    )
+
+    print("First run with interrupt")
+    events = []
+    async for result in agent.astream(new_run):
+        events.append(result.data)
+        if all(key in result.data for key in ["joke", "first_question", "needs_answer"]):
+            # this is the FirstInterruptEvent, in dict form
+            break
+    assert len(events) == len(expected_run_1), (
+        "Unexpected number of events generated during the run"
+    )
+    for i, event in enumerate(events):
+        assert event == expected_run_1[i], (
+            f"Event {i} does not match expected value: {event} != {expected_run_1[i]}"
+        )
+
+    new_run = Run(
+        agent_id=MOCK_AGENT_ID,
+        input=MOCK_RUN_INPUT,
+        thread_id=thread_id,
+        config=None,
+        interrupt=Interrupt(
+            event="interrupt_event",
+            name="first_interrupt",
+            ai_data={"key": "value"},
+            user_data={"answer": "user_value"},
+        ),
+    )
+
+    print("Second run")
+    events = []
+    async for result in agent.astream(new_run):
+        events.append(result.data)
+
+    assert len(events) == len(expected_run_2), (
+        "Unexpected number of events generated during the run"
+    )
+    for i, event in enumerate(events):
+        assert event == expected_run_2[i], (
+            f"Event {i} does not match expected value: {event} != {expected_run_2[i]}"
+        )
diff --git a/tests/tools.py b/tests/tools.py
new file mode 100644
index 0000000..82f81cd
--- /dev/null
+++ b/tests/tools.py
@@ -0,0 +1,25 @@
+# Copyright AGNTCY Contributors (https://github.com/agntcy)
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import logging
+import os
+from typing import Optional
+
+from agent_workflow_server.generated.manifest.models.agent_manifest import AgentManifest
+
+logger = logging.getLogger(__name__)
+
+
+def _read_manifest(path: str) -> Optional[AgentManifest]:
+    if os.path.isfile(path):
+        with open(path, "r") as file:
+            try:
+                manifest_data = json.load(file)
+            except json.JSONDecodeError as e:
+                raise ValueError(
+                    f"Invalid JSON format in manifest file: {path}. Error: {e}"
+                ) from e
+        logger.info(f"Loaded Agent Manifest from {os.path.abspath(path)}")
+        return AgentManifest.model_validate(manifest_data)
+    return None