Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions agent_assembly/adapters/google_adk/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,19 @@ class GoogleADKPatch:
process_agent_id: str | None = None

def apply(self) -> bool:
"""Apply patch wiring and return whether Google ADK is available."""
"""Apply patch wiring and return whether Google ADK is available.

Patches ``BaseTool.run_async`` and every concrete tool class that
overrides ``run_async`` (e.g. ``FunctionTool``), so interception still
runs for ADK 1.x tools whose subclass shadows the base method.
"""
set_process_agent_id(self.process_agent_id)
tool_cls = _load_google_adk_base_tool_class()
if tool_cls is None:
return False
_apply_tool_run_async_patch(tool_cls, self.callback_handler)
for concrete_cls in _load_google_adk_concrete_tool_classes(tool_cls):
_apply_tool_run_async_patch(concrete_cls, self.callback_handler)
agent_cls = _load_google_adk_base_agent_class()
if agent_cls is not None:
_apply_agent_run_async_patch(agent_cls, self.process_agent_id)
Expand All @@ -51,6 +58,8 @@ def revert(self) -> None:
_revert_agent_run_async_patch(agent_cls)
tool_cls = _load_google_adk_base_tool_class()
if tool_cls is not None:
for concrete_cls in _load_google_adk_concrete_tool_classes(tool_cls):
_revert_tool_run_async_patch(concrete_cls)
_revert_tool_run_async_patch(tool_cls)
set_process_agent_id(None)
return None
Expand All @@ -68,6 +77,32 @@ def _load_google_adk_base_tool_class() -> type[Any] | None:
return None


def _load_google_adk_concrete_tool_classes(base_tool_cls: type[Any]) -> list[type[Any]]:
"""Return concrete ADK tool classes that OVERRIDE ``run_async``.

Concrete ADK 1.x tools (e.g. ``FunctionTool``) define their own
``run_async`` on the subclass, so a patch on ``BaseTool.run_async`` never
runs for them. Discover such classes in ``google.adk.tools`` so they can be
patched directly.
"""
try:
module = importlib.import_module("google.adk.tools")
except ImportError:
return []

concrete: list[type[Any]] = []
for attr_value in vars(module).values():
if not isinstance(attr_value, type):
continue
if attr_value is base_tool_cls:
continue
if not issubclass(attr_value, base_tool_cls):
continue
if "run_async" in vars(attr_value):
concrete.append(attr_value)
return concrete


def _load_google_adk_base_agent_class() -> type[Any] | None:
try:
module = importlib.import_module("google.adk.agents")
Expand Down Expand Up @@ -124,7 +159,9 @@ def _revert_agent_run_async_patch(agent_cls: type[Any]) -> None:


def _apply_tool_run_async_patch(tool_cls: type[Any], callback_handler: Any) -> None:
if getattr(tool_cls, _TOOLS_PATCHED_FLAG, False):
# Check the class's OWN dict, not inherited state β€” a concrete subclass that
# overrides run_async must be patched even when its base is already patched.
if vars(tool_cls).get(_TOOLS_PATCHED_FLAG, False):
return None

original_run_async = tool_cls.run_async
Expand Down Expand Up @@ -190,16 +227,17 @@ async def patched_run_async(self: Any, *, args: Any, tool_context: Any, **kwargs


def _revert_tool_run_async_patch(tool_cls: type[Any]) -> None:
if not getattr(tool_cls, _TOOLS_PATCHED_FLAG, False):
# Inspect OWN dict so reverting one class never acts on inherited state.
if not vars(tool_cls).get(_TOOLS_PATCHED_FLAG, False):
return None

original_run_async = getattr(tool_cls, _ORIGINAL_TOOL_RUN_ASYNC, None)
original_run_async = vars(tool_cls).get(_ORIGINAL_TOOL_RUN_ASYNC)
if callable(original_run_async):
tool_cls.run_async = original_run_async

if hasattr(tool_cls, _ORIGINAL_TOOL_RUN_ASYNC):
if _ORIGINAL_TOOL_RUN_ASYNC in vars(tool_cls):
delattr(tool_cls, _ORIGINAL_TOOL_RUN_ASYNC)
if hasattr(tool_cls, _TOOLS_PATCHED_FLAG):
if _TOOLS_PATCHED_FLAG in vars(tool_cls):
delattr(tool_cls, _TOOLS_PATCHED_FLAG)
return None

Expand Down
138 changes: 131 additions & 7 deletions agent_assembly/adapters/pydantic_ai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from agent_assembly.core.spawn import _SPAWN_CTX, SpawnContext, spawn_context_scope

_ORIGINAL_TOOL_RUN = "_agent_assembly_original_pydantic_ai_tool_run"
_ORIGINAL_TOOLSET_CALL_TOOL = "_agent_assembly_original_pydantic_ai_toolset_call_tool"
_TOOLS_PATCHED_FLAG = "_agent_assembly_pydantic_ai_tools_patched"
_ORIGINAL_AGENT_RUN = "_agent_assembly_original_pydantic_ai_agent_run"
_ORIGINAL_AGENT_RUN_SYNC = "_agent_assembly_original_pydantic_ai_agent_run_sync"
Expand All @@ -34,12 +35,28 @@
process_agent_id: str | None = None

def apply(self) -> bool:
"""Apply patch wiring and return whether Pydantic AI is available."""
"""Apply patch wiring and return whether a tool hook was installed.

Detects the tool-execution hook across Pydantic AI versions: the
``Tool._run`` hook on <0.3.0 and the ``AbstractToolset.call_tool``
hook on >=0.3.0. When neither hook point exists, this is a no-op that
returns ``False`` instead of raising ``AttributeError``.
"""
set_process_agent_id(self.process_agent_id)

tool_hooked = False
tool_cls = _load_pydantic_ai_tool_class()
if tool_cls is None:
if tool_cls is not None:
tool_hooked = _apply_tool_run_patch(tool_cls, self.callback_handler)
if not tool_hooked:
toolset_cls = _load_pydantic_ai_toolset_class()
if toolset_cls is not None:
tool_hooked = _apply_toolset_call_tool_patch(toolset_cls, self.callback_handler)

if not tool_hooked:
set_process_agent_id(None)
return False
_apply_tool_run_patch(tool_cls, self.callback_handler)

agent_cls = _load_pydantic_ai_agent_class()
if agent_cls is not None:
_apply_agent_run_patch(agent_cls, self.process_agent_id)
Expand All @@ -53,6 +70,9 @@
tool_cls = _load_pydantic_ai_tool_class()
if tool_cls is not None:
_revert_tool_run_patch(tool_cls)
toolset_cls = _load_pydantic_ai_toolset_class()
if toolset_cls is not None:
_revert_toolset_call_tool_patch(toolset_cls)
set_process_agent_id(None)
return None

Expand Down Expand Up @@ -96,6 +116,23 @@
return None


def _load_pydantic_ai_toolset_class() -> type[Any] | None:
"""Load ``AbstractToolset`` β€” the >=0.3.0 tool-execution hook point.

In Pydantic AI >=0.3.0 tool execution routes through
``AbstractToolset.call_tool`` rather than ``Tool._run``.
"""
try:
module = importlib.import_module("pydantic_ai.toolsets")
except ImportError:
return None

toolset_cls = getattr(module, "AbstractToolset", None)
if isinstance(toolset_cls, type):
return toolset_cls
return None


def _load_pydantic_ai_agent_class() -> type[Any] | None:
try:
module = importlib.import_module("pydantic_ai")
Expand Down Expand Up @@ -170,11 +207,14 @@
return None


def _apply_tool_run_patch(tool_cls: type[Any], callback_handler: Any) -> None:
def _apply_tool_run_patch(tool_cls: type[Any], callback_handler: Any) -> bool:
"""Patch ``Tool._run`` (the <0.3.0 hook); no-op if it is unavailable."""
if getattr(tool_cls, _TOOLS_PATCHED_FLAG, False):
return None
return True

original_run = tool_cls._run
original_run = getattr(tool_cls, "_run", None)
if not callable(original_run):
return False

@wraps(original_run)
async def patched_run(self: Any, ctx: Any, args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -233,7 +273,7 @@
setattr(tool_cls, _ORIGINAL_TOOL_RUN, original_run)
tool_cls._run = patched_run
setattr(tool_cls, _TOOLS_PATCHED_FLAG, True)
return None
return True


def _revert_tool_run_patch(tool_cls: type[Any]) -> None:
Expand All @@ -251,6 +291,90 @@
return None


def _apply_toolset_call_tool_patch(toolset_cls: type[Any], callback_handler: Any) -> bool:
"""Patch ``AbstractToolset.call_tool`` (the >=0.3.0 hook); no-op if absent."""
if getattr(toolset_cls, _TOOLS_PATCHED_FLAG, False):
return True

original_call_tool = getattr(toolset_cls, "call_tool", None)
if not callable(original_call_tool):
return False

@wraps(original_call_tool)
async def patched_call_tool(self: Any, name: Any, tool_args: Any, ctx: Any, tool: Any, **kwargs: Any) -> Any:
tool_name = str(name)
serialized_args = _serialize_tool_args(tool_args)
agent_id = _resolve_agent_id(ctx)
run_id = _resolve_run_id(ctx)

decision = await _invoke_async_tool_check(
callback_handler,
tool_name=tool_name,
tool_args=serialized_args,
agent_id=agent_id,
run_id=run_id,
)
status, reason = _normalize_decision(decision)
is_pending_flow = False
if status == "pending":
is_pending_flow = True
timeout_seconds = _get_pending_tool_approval_timeout_seconds(callback_handler)
final_decision = await _wait_for_async_tool_approval(
callback_handler,
tool_name=tool_name,
timeout_seconds=timeout_seconds,
tool_args=serialized_args,
agent_id=agent_id,
run_id=run_id,
)
status, reason = _normalize_decision(final_decision)

if status == "deny":
if is_pending_flow:
raise _build_pending_rejected_error(tool_name, reason)

Check warning on line 334 in agent_assembly/adapters/pydantic_ai/patch.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Replace this generic exception class with a more specific one.

See more on https://sonarcloud.io/project/issues?id=AI-agent-assembly_python-sdk&issues=AZ7GCGvYlFS2SonG5K3z&open=AZ7GCGvYlFS2SonG5K3z&pullRequest=119
raise _build_denied_error(tool_name, reason)

Check warning on line 335 in agent_assembly/adapters/pydantic_ai/patch.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Replace this generic exception class with a more specific one.

See more on https://sonarcloud.io/project/issues?id=AI-agent-assembly_python-sdk&issues=AZ7GCGvYlFS2SonG5K30&open=AZ7GCGvYlFS2SonG5K30&pullRequest=119

spawn_ctx = SpawnContext(
parent_agent_id=agent_id or "",
depth=_current_spawn_depth(),
spawned_by_tool=tool_name,
delegation_reason=f"tool:{tool_name}",
)
with spawn_context_scope(spawn_ctx):
result = original_call_tool(self, name, tool_args, ctx, tool, **kwargs)
if inspect.isawaitable(result):
result = await result

await _record_async_tool_result(
callback_handler,
tool_name=tool_name,
result=result,
agent_id=agent_id,
run_id=run_id,
)
return result

setattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL, original_call_tool)
toolset_cls.call_tool = patched_call_tool
setattr(toolset_cls, _TOOLS_PATCHED_FLAG, True)
return True


def _revert_toolset_call_tool_patch(toolset_cls: type[Any]) -> None:
if not getattr(toolset_cls, _TOOLS_PATCHED_FLAG, False):
return None

original_call_tool = getattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL, None)
if callable(original_call_tool):
toolset_cls.call_tool = original_call_tool

if hasattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL):
delattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL)
if hasattr(toolset_cls, _TOOLS_PATCHED_FLAG):
delattr(toolset_cls, _TOOLS_PATCHED_FLAG)
return None


def set_process_agent_id(agent_id: str | None) -> None:
global _PROCESS_AGENT_ID
_PROCESS_AGENT_ID = agent_id
Expand Down
99 changes: 99 additions & 0 deletions test/unit/adapters/google_adk/test_google_adk_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,102 @@ def fake_import_module(module_name: str) -> object:

patcher.revert()
assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False


def _install_fake_google_adk_with_function_tool(
monkeypatch: pytest.MonkeyPatch,
) -> tuple[type[Any], type[Any]]:
"""Model ADK 1.x: a concrete ``FunctionTool`` subclass that OVERRIDES
``run_async``, so a patch on ``BaseTool.run_async`` alone never runs for it.
"""

class FakeBaseTool:
name = "base_tool"

async def run_async(self, *, args: Any, tool_context: Any, **kwargs: Any) -> dict[str, object]:
del kwargs
return {"who": "base", "args": args, "tool_context": tool_context}

class FakeFunctionTool(FakeBaseTool):
name = "function_tool"

# Concrete subclass overrides run_async (its own __dict__ entry).
async def run_async(self, *, args: Any, tool_context: Any, **kwargs: Any) -> dict[str, object]:
del kwargs
return {"who": "function", "args": args, "tool_context": tool_context}

fake_module = SimpleNamespace(BaseTool=FakeBaseTool, FunctionTool=FakeFunctionTool)

def fake_import_module(module_name: str) -> object:
if module_name == "google.adk.tools":
return fake_module
raise ImportError(module_name)

monkeypatch.setattr(google_adk_patch.importlib, "import_module", fake_import_module)
return FakeBaseTool, FakeFunctionTool


def test_load_concrete_tool_classes_finds_run_async_overrides(
monkeypatch: pytest.MonkeyPatch,
) -> None:
FakeBaseTool, FakeFunctionTool = _install_fake_google_adk_with_function_tool(monkeypatch)

concrete = google_adk_patch._load_google_adk_concrete_tool_classes(FakeBaseTool)

# The base class itself is excluded; the overriding subclass is included.
assert FakeFunctionTool in concrete
assert FakeBaseTool not in concrete


@pytest.mark.asyncio
async def test_apply_patches_concrete_function_tool_override(
monkeypatch: pytest.MonkeyPatch,
) -> None:
FakeBaseTool, FakeFunctionTool = _install_fake_google_adk_with_function_tool(monkeypatch)

patcher = google_adk_patch.GoogleADKPatch(_AllowInterceptor())
assert patcher.apply() is True

# Both the base and the concrete override are patched.
assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is True
assert getattr(FakeFunctionTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is True


@pytest.mark.asyncio
async def test_function_tool_override_is_intercepted_on_deny(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_, FakeFunctionTool = _install_fake_google_adk_with_function_tool(monkeypatch)

class Interceptor:
async def check_tool_start(self, **kwargs: object) -> dict[str, str]:
del kwargs
return {"status": "deny", "reason": "blocked subclass"}

patcher = google_adk_patch.GoogleADKPatch(Interceptor())
assert patcher.apply() is True

tool = FakeFunctionTool()
tool_context = SimpleNamespace(invocation_context=None)
# Governance runs on the SUBCLASS run_async, not just the base.
with pytest.raises(PolicyViolationError, match="blocked by governance policy: blocked subclass"):
await tool.run_async(args={"step": 1}, tool_context=tool_context)


@pytest.mark.asyncio
async def test_revert_restores_concrete_function_tool_override(
monkeypatch: pytest.MonkeyPatch,
) -> None:
FakeBaseTool, FakeFunctionTool = _install_fake_google_adk_with_function_tool(monkeypatch)
original_base = FakeBaseTool.run_async
original_function = FakeFunctionTool.run_async

patcher = google_adk_patch.GoogleADKPatch(_AllowInterceptor())
assert patcher.apply() is True
assert FakeFunctionTool.run_async is not original_function

patcher.revert()
assert FakeBaseTool.run_async is original_base
assert FakeFunctionTool.run_async is original_function
assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False
assert getattr(FakeFunctionTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False
Loading