diff --git a/agent_assembly/adapters/google_adk/patch.py b/agent_assembly/adapters/google_adk/patch.py index 01514fe1..81b34954 100644 --- a/agent_assembly/adapters/google_adk/patch.py +++ b/agent_assembly/adapters/google_adk/patch.py @@ -33,12 +33,19 @@ class GoogleADKPatch: process_agent_id: str | None = None def apply(self) -> bool: - """Apply patch wiring and return whether Google ADK is available.""" + """Apply patch wiring and return whether Google ADK is available. + + Patches ``BaseTool.run_async`` and every concrete tool class that + overrides ``run_async`` (e.g. ``FunctionTool``), so interception still + runs for ADK 1.x tools whose subclass shadows the base method. + """ set_process_agent_id(self.process_agent_id) tool_cls = _load_google_adk_base_tool_class() if tool_cls is None: return False _apply_tool_run_async_patch(tool_cls, self.callback_handler) + for concrete_cls in _load_google_adk_concrete_tool_classes(tool_cls): + _apply_tool_run_async_patch(concrete_cls, self.callback_handler) agent_cls = _load_google_adk_base_agent_class() if agent_cls is not None: _apply_agent_run_async_patch(agent_cls, self.process_agent_id) @@ -51,6 +58,8 @@ def revert(self) -> None: _revert_agent_run_async_patch(agent_cls) tool_cls = _load_google_adk_base_tool_class() if tool_cls is not None: + for concrete_cls in _load_google_adk_concrete_tool_classes(tool_cls): + _revert_tool_run_async_patch(concrete_cls) _revert_tool_run_async_patch(tool_cls) set_process_agent_id(None) return None @@ -68,6 +77,32 @@ def _load_google_adk_base_tool_class() -> type[Any] | None: return None +def _load_google_adk_concrete_tool_classes(base_tool_cls: type[Any]) -> list[type[Any]]: + """Return concrete ADK tool classes that OVERRIDE ``run_async``. + + Concrete ADK 1.x tools (e.g. ``FunctionTool``) define their own + ``run_async`` on the subclass, so a patch on ``BaseTool.run_async`` never + runs for them. Discover such classes in ``google.adk.tools`` so they can be + patched directly. + """ + try: + module = importlib.import_module("google.adk.tools") + except ImportError: + return [] + + concrete: list[type[Any]] = [] + for attr_value in vars(module).values(): + if not isinstance(attr_value, type): + continue + if attr_value is base_tool_cls: + continue + if not issubclass(attr_value, base_tool_cls): + continue + if "run_async" in vars(attr_value): + concrete.append(attr_value) + return concrete + + def _load_google_adk_base_agent_class() -> type[Any] | None: try: module = importlib.import_module("google.adk.agents") @@ -124,7 +159,9 @@ def _revert_agent_run_async_patch(agent_cls: type[Any]) -> None: def _apply_tool_run_async_patch(tool_cls: type[Any], callback_handler: Any) -> None: - if getattr(tool_cls, _TOOLS_PATCHED_FLAG, False): + # Check the class's OWN dict, not inherited state — a concrete subclass that + # overrides run_async must be patched even when its base is already patched. + if vars(tool_cls).get(_TOOLS_PATCHED_FLAG, False): return None original_run_async = tool_cls.run_async @@ -190,16 +227,17 @@ async def patched_run_async(self: Any, *, args: Any, tool_context: Any, **kwargs def _revert_tool_run_async_patch(tool_cls: type[Any]) -> None: - if not getattr(tool_cls, _TOOLS_PATCHED_FLAG, False): + # Inspect OWN dict so reverting one class never acts on inherited state. + if not vars(tool_cls).get(_TOOLS_PATCHED_FLAG, False): return None - original_run_async = getattr(tool_cls, _ORIGINAL_TOOL_RUN_ASYNC, None) + original_run_async = vars(tool_cls).get(_ORIGINAL_TOOL_RUN_ASYNC) if callable(original_run_async): tool_cls.run_async = original_run_async - if hasattr(tool_cls, _ORIGINAL_TOOL_RUN_ASYNC): + if _ORIGINAL_TOOL_RUN_ASYNC in vars(tool_cls): delattr(tool_cls, _ORIGINAL_TOOL_RUN_ASYNC) - if hasattr(tool_cls, _TOOLS_PATCHED_FLAG): + if _TOOLS_PATCHED_FLAG in vars(tool_cls): delattr(tool_cls, _TOOLS_PATCHED_FLAG) return None diff --git a/agent_assembly/adapters/pydantic_ai/patch.py b/agent_assembly/adapters/pydantic_ai/patch.py index 0cf80ccf..33b04182 100644 --- a/agent_assembly/adapters/pydantic_ai/patch.py +++ b/agent_assembly/adapters/pydantic_ai/patch.py @@ -18,6 +18,7 @@ from agent_assembly.core.spawn import _SPAWN_CTX, SpawnContext, spawn_context_scope _ORIGINAL_TOOL_RUN = "_agent_assembly_original_pydantic_ai_tool_run" +_ORIGINAL_TOOLSET_CALL_TOOL = "_agent_assembly_original_pydantic_ai_toolset_call_tool" _TOOLS_PATCHED_FLAG = "_agent_assembly_pydantic_ai_tools_patched" _ORIGINAL_AGENT_RUN = "_agent_assembly_original_pydantic_ai_agent_run" _ORIGINAL_AGENT_RUN_SYNC = "_agent_assembly_original_pydantic_ai_agent_run_sync" @@ -34,12 +35,28 @@ class PydanticAIPatch: process_agent_id: str | None = None def apply(self) -> bool: - """Apply patch wiring and return whether Pydantic AI is available.""" + """Apply patch wiring and return whether a tool hook was installed. + + Detects the tool-execution hook across Pydantic AI versions: the + ``Tool._run`` hook on <0.3.0 and the ``AbstractToolset.call_tool`` + hook on >=0.3.0. When neither hook point exists, this is a no-op that + returns ``False`` instead of raising ``AttributeError``. + """ set_process_agent_id(self.process_agent_id) + + tool_hooked = False tool_cls = _load_pydantic_ai_tool_class() - if tool_cls is None: + if tool_cls is not None: + tool_hooked = _apply_tool_run_patch(tool_cls, self.callback_handler) + if not tool_hooked: + toolset_cls = _load_pydantic_ai_toolset_class() + if toolset_cls is not None: + tool_hooked = _apply_toolset_call_tool_patch(toolset_cls, self.callback_handler) + + if not tool_hooked: + set_process_agent_id(None) return False - _apply_tool_run_patch(tool_cls, self.callback_handler) + agent_cls = _load_pydantic_ai_agent_class() if agent_cls is not None: _apply_agent_run_patch(agent_cls, self.process_agent_id) @@ -53,6 +70,9 @@ def revert(self) -> None: tool_cls = _load_pydantic_ai_tool_class() if tool_cls is not None: _revert_tool_run_patch(tool_cls) + toolset_cls = _load_pydantic_ai_toolset_class() + if toolset_cls is not None: + _revert_toolset_call_tool_patch(toolset_cls) set_process_agent_id(None) return None @@ -96,6 +116,23 @@ def _load_pydantic_ai_tool_class() -> type[Any] | None: return None +def _load_pydantic_ai_toolset_class() -> type[Any] | None: + """Load ``AbstractToolset`` — the >=0.3.0 tool-execution hook point. + + In Pydantic AI >=0.3.0 tool execution routes through + ``AbstractToolset.call_tool`` rather than ``Tool._run``. + """ + try: + module = importlib.import_module("pydantic_ai.toolsets") + except ImportError: + return None + + toolset_cls = getattr(module, "AbstractToolset", None) + if isinstance(toolset_cls, type): + return toolset_cls + return None + + def _load_pydantic_ai_agent_class() -> type[Any] | None: try: module = importlib.import_module("pydantic_ai") @@ -170,11 +207,14 @@ def _revert_agent_run_patch(agent_cls: type[Any]) -> None: return None -def _apply_tool_run_patch(tool_cls: type[Any], callback_handler: Any) -> None: +def _apply_tool_run_patch(tool_cls: type[Any], callback_handler: Any) -> bool: + """Patch ``Tool._run`` (the <0.3.0 hook); no-op if it is unavailable.""" if getattr(tool_cls, _TOOLS_PATCHED_FLAG, False): - return None + return True - original_run = tool_cls._run + original_run = getattr(tool_cls, "_run", None) + if not callable(original_run): + return False @wraps(original_run) async def patched_run(self: Any, ctx: Any, args: Any, **kwargs: Any) -> Any: @@ -233,7 +273,7 @@ async def patched_run(self: Any, ctx: Any, args: Any, **kwargs: Any) -> Any: setattr(tool_cls, _ORIGINAL_TOOL_RUN, original_run) tool_cls._run = patched_run setattr(tool_cls, _TOOLS_PATCHED_FLAG, True) - return None + return True def _revert_tool_run_patch(tool_cls: type[Any]) -> None: @@ -251,6 +291,90 @@ def _revert_tool_run_patch(tool_cls: type[Any]) -> None: return None +def _apply_toolset_call_tool_patch(toolset_cls: type[Any], callback_handler: Any) -> bool: + """Patch ``AbstractToolset.call_tool`` (the >=0.3.0 hook); no-op if absent.""" + if getattr(toolset_cls, _TOOLS_PATCHED_FLAG, False): + return True + + original_call_tool = getattr(toolset_cls, "call_tool", None) + if not callable(original_call_tool): + return False + + @wraps(original_call_tool) + async def patched_call_tool(self: Any, name: Any, tool_args: Any, ctx: Any, tool: Any, **kwargs: Any) -> Any: + tool_name = str(name) + serialized_args = _serialize_tool_args(tool_args) + agent_id = _resolve_agent_id(ctx) + run_id = _resolve_run_id(ctx) + + decision = await _invoke_async_tool_check( + callback_handler, + tool_name=tool_name, + tool_args=serialized_args, + agent_id=agent_id, + run_id=run_id, + ) + status, reason = _normalize_decision(decision) + is_pending_flow = False + if status == "pending": + is_pending_flow = True + timeout_seconds = _get_pending_tool_approval_timeout_seconds(callback_handler) + final_decision = await _wait_for_async_tool_approval( + callback_handler, + tool_name=tool_name, + timeout_seconds=timeout_seconds, + tool_args=serialized_args, + agent_id=agent_id, + run_id=run_id, + ) + status, reason = _normalize_decision(final_decision) + + if status == "deny": + if is_pending_flow: + raise _build_pending_rejected_error(tool_name, reason) + raise _build_denied_error(tool_name, reason) + + spawn_ctx = SpawnContext( + parent_agent_id=agent_id or "", + depth=_current_spawn_depth(), + spawned_by_tool=tool_name, + delegation_reason=f"tool:{tool_name}", + ) + with spawn_context_scope(spawn_ctx): + result = original_call_tool(self, name, tool_args, ctx, tool, **kwargs) + if inspect.isawaitable(result): + result = await result + + await _record_async_tool_result( + callback_handler, + tool_name=tool_name, + result=result, + agent_id=agent_id, + run_id=run_id, + ) + return result + + setattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL, original_call_tool) + toolset_cls.call_tool = patched_call_tool + setattr(toolset_cls, _TOOLS_PATCHED_FLAG, True) + return True + + +def _revert_toolset_call_tool_patch(toolset_cls: type[Any]) -> None: + if not getattr(toolset_cls, _TOOLS_PATCHED_FLAG, False): + return None + + original_call_tool = getattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL, None) + if callable(original_call_tool): + toolset_cls.call_tool = original_call_tool + + if hasattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL): + delattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL) + if hasattr(toolset_cls, _TOOLS_PATCHED_FLAG): + delattr(toolset_cls, _TOOLS_PATCHED_FLAG) + return None + + def set_process_agent_id(agent_id: str | None) -> None: global _PROCESS_AGENT_ID _PROCESS_AGENT_ID = agent_id diff --git a/test/unit/adapters/google_adk/test_google_adk_patch.py b/test/unit/adapters/google_adk/test_google_adk_patch.py index eb77d3c7..31b9dd78 100644 --- a/test/unit/adapters/google_adk/test_google_adk_patch.py +++ b/test/unit/adapters/google_adk/test_google_adk_patch.py @@ -362,3 +362,102 @@ def fake_import_module(module_name: str) -> object: patcher.revert() assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False + + +def _install_fake_google_adk_with_function_tool( + monkeypatch: pytest.MonkeyPatch, +) -> tuple[type[Any], type[Any]]: + """Model ADK 1.x: a concrete ``FunctionTool`` subclass that OVERRIDES + ``run_async``, so a patch on ``BaseTool.run_async`` alone never runs for it. + """ + + class FakeBaseTool: + name = "base_tool" + + async def run_async(self, *, args: Any, tool_context: Any, **kwargs: Any) -> dict[str, object]: + del kwargs + return {"who": "base", "args": args, "tool_context": tool_context} + + class FakeFunctionTool(FakeBaseTool): + name = "function_tool" + + # Concrete subclass overrides run_async (its own __dict__ entry). + async def run_async(self, *, args: Any, tool_context: Any, **kwargs: Any) -> dict[str, object]: + del kwargs + return {"who": "function", "args": args, "tool_context": tool_context} + + fake_module = SimpleNamespace(BaseTool=FakeBaseTool, FunctionTool=FakeFunctionTool) + + def fake_import_module(module_name: str) -> object: + if module_name == "google.adk.tools": + return fake_module + raise ImportError(module_name) + + monkeypatch.setattr(google_adk_patch.importlib, "import_module", fake_import_module) + return FakeBaseTool, FakeFunctionTool + + +def test_load_concrete_tool_classes_finds_run_async_overrides( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeBaseTool, FakeFunctionTool = _install_fake_google_adk_with_function_tool(monkeypatch) + + concrete = google_adk_patch._load_google_adk_concrete_tool_classes(FakeBaseTool) + + # The base class itself is excluded; the overriding subclass is included. + assert FakeFunctionTool in concrete + assert FakeBaseTool not in concrete + + +@pytest.mark.asyncio +async def test_apply_patches_concrete_function_tool_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeBaseTool, FakeFunctionTool = _install_fake_google_adk_with_function_tool(monkeypatch) + + patcher = google_adk_patch.GoogleADKPatch(_AllowInterceptor()) + assert patcher.apply() is True + + # Both the base and the concrete override are patched. + assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is True + assert getattr(FakeFunctionTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is True + + +@pytest.mark.asyncio +async def test_function_tool_override_is_intercepted_on_deny( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _, FakeFunctionTool = _install_fake_google_adk_with_function_tool(monkeypatch) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "deny", "reason": "blocked subclass"} + + patcher = google_adk_patch.GoogleADKPatch(Interceptor()) + assert patcher.apply() is True + + tool = FakeFunctionTool() + tool_context = SimpleNamespace(invocation_context=None) + # Governance runs on the SUBCLASS run_async, not just the base. + with pytest.raises(PolicyViolationError, match="blocked by governance policy: blocked subclass"): + await tool.run_async(args={"step": 1}, tool_context=tool_context) + + +@pytest.mark.asyncio +async def test_revert_restores_concrete_function_tool_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeBaseTool, FakeFunctionTool = _install_fake_google_adk_with_function_tool(monkeypatch) + original_base = FakeBaseTool.run_async + original_function = FakeFunctionTool.run_async + + patcher = google_adk_patch.GoogleADKPatch(_AllowInterceptor()) + assert patcher.apply() is True + assert FakeFunctionTool.run_async is not original_function + + patcher.revert() + assert FakeBaseTool.run_async is original_base + assert FakeFunctionTool.run_async is original_function + assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False + assert getattr(FakeFunctionTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False diff --git a/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py b/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py index ff88a760..e022ac2c 100644 --- a/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py +++ b/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py @@ -373,3 +373,129 @@ async def on_tool_end(self, **kwargs: object) -> None: run_id="run-z", ) assert observed_outputs == ["result-value"] + + +def _install_fake_pydantic_ai_v030_modules( + monkeypatch: pytest.MonkeyPatch, +) -> type[Any]: + """Model Pydantic AI >=0.3.0: ``Tool`` has no ``_run`` and tool execution + routes through ``AbstractToolset.call_tool(self, name, tool_args, ctx, tool)``. + """ + + class FakeTool: + # Note: no `_run` method — mirrors the >=0.3.0 restructure. + name = "fake_tool" + + class FakeAbstractToolset: + async def call_tool(self, name: Any, tool_args: Any, ctx: Any, tool: Any) -> dict[str, object]: + return {"name": name, "tool_args": tool_args, "ctx": ctx, "tool": tool} + + fake_tools_module = SimpleNamespace(Tool=FakeTool) + fake_toolsets_module = SimpleNamespace(AbstractToolset=FakeAbstractToolset) + + def fake_import_module(module_name: str) -> object: + if module_name == "pydantic_ai.tools": + return fake_tools_module + if module_name == "pydantic_ai.toolsets": + return fake_toolsets_module + raise ImportError(module_name) + + monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", fake_import_module) + return FakeAbstractToolset + + +@pytest.mark.asyncio +async def test_apply_patches_toolset_call_tool_for_v030(monkeypatch: pytest.MonkeyPatch) -> None: + FakeAbstractToolset = _install_fake_pydantic_ai_v030_modules(monkeypatch) + + patcher = pydantic_ai_patch.PydanticAIPatch(_RecordingInterceptor()) + assert patcher.apply() is True + assert getattr(FakeAbstractToolset, pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is True + + first_ref = FakeAbstractToolset.call_tool + # Re-applying is idempotent. + assert patcher.apply() is True + assert FakeAbstractToolset.call_tool is first_ref + + +@pytest.mark.asyncio +async def test_v030_toolset_allow_flow_runs_and_records(monkeypatch: pytest.MonkeyPatch) -> None: + FakeAbstractToolset = _install_fake_pydantic_ai_v030_modules(monkeypatch) + + recorded: list[dict[str, object]] = [] + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + async def record_result(self, **kwargs: object) -> None: + recorded.append(dict(kwargs)) + + patcher = pydantic_ai_patch.PydanticAIPatch(Interceptor()) + assert patcher.apply() is True + + toolset = FakeAbstractToolset() + ctx = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="agent-x"), run_id="run-x") + result = await toolset.call_tool("fake_tool", {"q": "hi"}, ctx, object()) + + assert result["name"] == "fake_tool" + assert len(recorded) == 1 + assert recorded[0]["tool_name"] == "fake_tool" + assert recorded[0]["agent_id"] == "agent-x" + + +@pytest.mark.asyncio +async def test_v030_toolset_deny_flow_raises_policy_violation(monkeypatch: pytest.MonkeyPatch) -> None: + FakeAbstractToolset = _install_fake_pydantic_ai_v030_modules(monkeypatch) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "deny", "reason": "blocked v030"} + + patcher = pydantic_ai_patch.PydanticAIPatch(Interceptor()) + assert patcher.apply() is True + + toolset = FakeAbstractToolset() + ctx = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="agent-x"), run_id="run-x") + with pytest.raises(PolicyViolationError, match="blocked by governance policy: blocked v030"): + await toolset.call_tool("fake_tool", {"q": "hi"}, ctx, object()) + + +def test_v030_revert_restores_toolset_call_tool(monkeypatch: pytest.MonkeyPatch) -> None: + FakeAbstractToolset = _install_fake_pydantic_ai_v030_modules(monkeypatch) + original_call_tool = FakeAbstractToolset.call_tool + + patcher = pydantic_ai_patch.PydanticAIPatch(_RecordingInterceptor()) + assert patcher.apply() is True + assert FakeAbstractToolset.call_tool is not original_call_tool + + patcher.revert() + assert FakeAbstractToolset.call_tool is original_call_tool + assert getattr(FakeAbstractToolset, pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is False + + +def test_apply_false_when_no_known_tool_hook_exists(monkeypatch: pytest.MonkeyPatch) -> None: + """Pydantic AI present but exposing neither ``Tool._run`` nor + ``AbstractToolset.call_tool`` must no-op (return False), never raise. + """ + + class FakeTool: + name = "fake_tool" # no `_run` + + class FakeAbstractToolset: + pass # no `call_tool` + + def fake_import_module(module_name: str) -> object: + if module_name == "pydantic_ai.tools": + return SimpleNamespace(Tool=FakeTool) + if module_name == "pydantic_ai.toolsets": + return SimpleNamespace(AbstractToolset=FakeAbstractToolset) + raise ImportError(module_name) + + monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", fake_import_module) + + patcher = pydantic_ai_patch.PydanticAIPatch(_RecordingInterceptor()) + assert patcher.apply() is False + assert pydantic_ai_patch._get_process_agent_id() is None