Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 71 additions & 6 deletions agent_assembly/adapters/pydantic_ai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
``Tool._run`` hook on <0.3.0 and the ``AbstractToolset.call_tool``
hook on >=0.3.0. When neither hook point exists, this is a no-op that
returns ``False`` instead of raising ``AttributeError``.

On >=0.3.0 the abstract base patch is shadowed by concrete toolsets
(e.g. ``FunctionToolset``) that override ``call_tool`` without calling
``super().call_tool(...)``. Such concrete classes are patched directly
in addition to the base so function-tool governance still fires.
"""
set_process_agent_id(self.process_agent_id)

Expand All @@ -52,6 +57,8 @@
toolset_cls = _load_pydantic_ai_toolset_class()
if toolset_cls is not None:
tool_hooked = _apply_toolset_call_tool_patch(toolset_cls, self.callback_handler)
for concrete_cls in _load_pydantic_ai_concrete_toolset_classes(toolset_cls):
_apply_toolset_call_tool_patch(concrete_cls, self.callback_handler)

if not tool_hooked:
set_process_agent_id(None)
Expand All @@ -72,6 +79,8 @@
_revert_tool_run_patch(tool_cls)
toolset_cls = _load_pydantic_ai_toolset_class()
if toolset_cls is not None:
for concrete_cls in _load_pydantic_ai_concrete_toolset_classes(toolset_cls):
_revert_toolset_call_tool_patch(concrete_cls)
_revert_toolset_call_tool_patch(toolset_cls)
set_process_agent_id(None)
return None
Expand Down Expand Up @@ -133,6 +142,54 @@
return None


def _import_module_or_none(module_name: str) -> Any | None:
    """Import ``module_name`` and return it, or ``None`` when the import fails."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


def _is_call_tool_overrider(candidate: Any, base_toolset_cls: type[Any]) -> bool:
    """Tell whether ``candidate`` is a strict subclass of the base with its OWN ``call_tool``.

    "Own" means ``call_tool`` is present in ``vars(candidate)``; a subclass that
    merely inherits ``call_tool`` is not an overrider.
    """
    return (
        isinstance(candidate, type)
        and candidate is not base_toolset_cls
        and issubclass(candidate, base_toolset_cls)
        and "call_tool" in vars(candidate)
    )


def _load_pydantic_ai_concrete_toolset_classes(base_toolset_cls: type[Any]) -> list[type[Any]]:
    """Return concrete ``AbstractToolset`` subclasses that OVERRIDE ``call_tool``.

    Function tools (``@agent.tool_plain`` / ``@agent.tool``) execute through
    ``pydantic_ai.toolsets.function.FunctionToolset.call_tool``, which overrides
    the abstract base WITHOUT calling ``super().call_tool(...)``. A patch on
    ``AbstractToolset.call_tool`` is therefore shadowed and never runs for them.

    ``FunctionToolset`` is looked up explicitly first; afterwards every
    attribute of ``pydantic_ai.toolsets`` is considered so that any other
    toolset defining its own ``call_tool`` (in ``vars(cls)``) is covered too.

    Args:
        base_toolset_cls: The abstract toolset base class. It is never part of
            the result itself.

    Returns:
        The matching classes in discovery order (``FunctionToolset`` first),
        each listed once. Empty when neither module can be imported, so the
        function stays fail-soft when Pydantic AI is absent.
    """
    candidates: list[Any] = []

    function_module = _import_module_or_none("pydantic_ai.toolsets.function")
    if function_module is not None:
        candidates.append(getattr(function_module, "FunctionToolset", None))

    toolsets_module = _import_module_or_none("pydantic_ai.toolsets")
    if toolsets_module is not None:
        candidates.extend(vars(toolsets_module).values())

    discovered: list[type[Any]] = []
    # De-duplicate by identity: the same class is usually reachable from both
    # modules (e.g. re-exported by the package).
    seen: set[int] = set()
    for candidate in candidates:
        if id(candidate) in seen or not _is_call_tool_overrider(candidate, base_toolset_cls):
            continue
        seen.add(id(candidate))
        discovered.append(candidate)
    return discovered


def _load_pydantic_ai_agent_class() -> type[Any] | None:
try:
module = importlib.import_module("pydantic_ai")
Expand Down Expand Up @@ -292,8 +349,14 @@


def _apply_toolset_call_tool_patch(toolset_cls: type[Any], callback_handler: Any) -> bool:
"""Patch ``AbstractToolset.call_tool`` (the >=0.3.0 hook); no-op if absent."""
if getattr(toolset_cls, _TOOLS_PATCHED_FLAG, False):
"""Patch a toolset class's ``call_tool`` (the >=0.3.0 hook); no-op if absent.

Applies to ``AbstractToolset`` and to each concrete subclass that overrides
``call_tool``. The patched-flag is checked on the class's OWN dict, not
inherited state — a concrete subclass whose base is already patched must
still be patched directly, otherwise its override shadows governance.
"""
if vars(toolset_cls).get(_TOOLS_PATCHED_FLAG, False):
return True

original_call_tool = getattr(toolset_cls, "call_tool", None)
Expand Down Expand Up @@ -361,16 +424,18 @@


def _revert_toolset_call_tool_patch(toolset_cls: type[Any]) -> None:
    """Restore ``toolset_cls.call_tool`` and drop this class's patch bookkeeping.

    Every lookup goes through the class's OWN ``__dict__`` rather than
    ``getattr``/``hasattr``, so reverting the base never clears a concrete
    subclass's flag (and vice versa); each class restores its own
    ``call_tool``. Calling this on a class that is not flagged in its own dict
    is a no-op.

    Args:
        toolset_cls: The toolset class whose patch should be undone.
    """
    # ``vars()`` on a class is a live read-only view, so it stays accurate
    # across the attribute writes and deletions below.
    own_attrs = vars(toolset_cls)
    if not own_attrs.get(_TOOLS_PATCHED_FLAG, False):
        return None

    original_call_tool = own_attrs.get(_ORIGINAL_TOOLSET_CALL_TOOL, None)
    if callable(original_call_tool):
        toolset_cls.call_tool = original_call_tool

    # Only delete what this class itself defines; ``delattr`` on a name that is
    # merely inherited would raise ``AttributeError``.
    if _ORIGINAL_TOOLSET_CALL_TOOL in own_attrs:
        delattr(toolset_cls, _ORIGINAL_TOOLSET_CALL_TOOL)
    if _TOOLS_PATCHED_FLAG in own_attrs:
        delattr(toolset_cls, _TOOLS_PATCHED_FLAG)
    return None

Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ test = [
"pytest-rerunfailures>=14.0,<17",
"pytest-asyncio>=0.23.0,<2",
"pytest-benchmark>=4.0.0,<6",
# AAASM-2943: dev/test-only (NOT a runtime dependency). Installing the
# framework lets the `importorskip`-guarded Pydantic AI integration tests
# run in CI (the `dev` group includes `test`, and the integration-test job
# installs `dev`), so the function-tool governance regression is actually
# exercised instead of skipped.
"pydantic-ai>=0.3.0",
]
pre-commit-ci = [
"pre-commit>=3.5.0,<5",
Expand Down
48 changes: 47 additions & 1 deletion test/integration/test_pydantic_ai_interception_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,52 @@ async def check_tool_start(self, **kwargs: object) -> dict[str, str]:
del kwargs
return {"status": "allow"}

patcher = pydantic_ai_patch.PydanticAIPatch(Interceptor())
try:
assert patcher.apply() is True
# The installed version determines which hook is set: ``Tool._run`` on
# <0.3.0, else ``AbstractToolset.call_tool`` on >=0.3.0. Assert that the
# version-appropriate hook carries the patched flag on its own class.
toolset_cls = pydantic_ai_patch._load_pydantic_ai_toolset_class()
tool_flagged = vars(tool_cls).get(pydantic_ai_patch._TOOLS_PATCHED_FLAG, False)
toolset_flagged = toolset_cls is not None and vars(toolset_cls).get(
pydantic_ai_patch._TOOLS_PATCHED_FLAG, False
)
assert tool_flagged or toolset_flagged
finally:
patcher.revert()


@pytest.mark.integration
def test_pydantic_ai_real_function_tool_deny_raises_after_apply() -> None:
    """End-to-end proof against the real library (AAASM-2943).

    Function tools (``@agent.tool_plain``) execute through
    ``FunctionToolset.call_tool``, which overrides ``AbstractToolset.call_tool``
    WITHOUT ``super()``. Before the concrete-class patch, a denied function tool
    ran without raising. After ``apply()`` patches the concrete class too, the
    deny path raises ``PolicyViolationError``.
    """
    pytest.importorskip("pydantic_ai")
    from pydantic_ai import Agent
    from pydantic_ai.models.test import TestModel

    class Interceptor:
        async def check_tool_start(self, **kwargs: object) -> dict[str, str]:
            # Deny only the tool under test; everything else is allowed.
            if kwargs.get("tool_name") == "blocked_tool":
                return {"status": "deny", "reason": "blocked by policy"}
            return {"status": "allow"}

    patcher = pydantic_ai_patch.PydanticAIPatch(Interceptor())
    # ``apply()`` runs inside the ``try`` so a failed assertion still reverts
    # the patch and does not leak into later tests.
    try:
        assert patcher.apply() is True
        agent = Agent(TestModel(call_tools=["blocked_tool"]))

        @agent.tool_plain
        def blocked_tool() -> str:
            return "ran"

        with pytest.raises(PolicyViolationError, match="blocked by governance policy: blocked by policy"):
            agent.run_sync("invoke the tool")
    finally:
        patcher.revert()
127 changes: 127 additions & 0 deletions test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,133 @@ def test_v030_revert_restores_toolset_call_tool(monkeypatch: pytest.MonkeyPatch)
assert getattr(FakeAbstractToolset, pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is False


def _install_fake_pydantic_ai_function_toolset_modules(
    monkeypatch: pytest.MonkeyPatch,
) -> tuple[type[Any], type[Any]]:
    """Install fake Pydantic AI modules that reproduce the >=0.3.0 shadowing bug.

    The fake ``FunctionToolset`` subclasses the fake ``AbstractToolset`` and
    overrides ``call_tool`` WITHOUT calling ``super()``, so a patch on the base
    alone never runs for function tools; the concrete class has to be
    discovered and patched directly.

    ``importlib.import_module`` (as seen by ``pydantic_ai_patch``) is replaced
    so that only the three fake modules resolve; any other name raises
    ``ImportError``.

    Returns:
        The ``(AbstractToolset, FunctionToolset)`` fake classes.
    """

    class FakeTool:
        name = "fake_tool"  # no `_run` — mirrors the >=0.3.0 restructure

    class FakeAbstractToolset:
        async def call_tool(self, name: Any, tool_args: Any, ctx: Any, tool: Any) -> dict[str, object]:
            return {"src": "abstract", "name": name, "tool_args": tool_args, "ctx": ctx, "tool": tool}

    class FakeFunctionToolset(FakeAbstractToolset):
        # Overrides call_tool WITHOUT super() — the shadowing that hides the
        # base-class patch from function tools.
        async def call_tool(self, name: Any, tool_args: Any, ctx: Any, tool: Any) -> dict[str, object]:
            return {"src": "function", "name": name, "tool_args": tool_args, "ctx": ctx, "tool": tool}

    # Module name -> stand-in module object handed out by the fake importer.
    fake_modules: dict[str, object] = {
        "pydantic_ai.tools": SimpleNamespace(Tool=FakeTool),
        "pydantic_ai.toolsets": SimpleNamespace(AbstractToolset=FakeAbstractToolset),
        "pydantic_ai.toolsets.function": SimpleNamespace(FunctionToolset=FakeFunctionToolset),
    }

    def fake_import_module(module_name: str) -> object:
        if module_name not in fake_modules:
            raise ImportError(module_name)
        return fake_modules[module_name]

    monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", fake_import_module)
    return FakeAbstractToolset, FakeFunctionToolset


def test_concrete_toolset_discovery_finds_function_toolset(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Discovery returns the overriding concrete toolset but never the base."""
    base_cls, function_cls = _install_fake_pydantic_ai_function_toolset_modules(monkeypatch)

    found = pydantic_ai_patch._load_pydantic_ai_concrete_toolset_classes(base_cls)

    assert function_cls in found
    # The abstract base is the reference point, not a "concrete overrider".
    assert base_cls not in found


@pytest.mark.asyncio
async def test_apply_patches_both_base_and_concrete_function_toolset(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``apply()`` flags the base AND the concrete overrider, each on its own dict."""
    FakeAbstractToolset, FakeFunctionToolset = _install_fake_pydantic_ai_function_toolset_modules(monkeypatch)

    patcher = pydantic_ai_patch.PydanticAIPatch(_RecordingInterceptor())
    # Always revert: ``apply()`` also sets process-wide state that only
    # ``revert()`` clears, so skipping it would leak into later tests.
    try:
        assert patcher.apply() is True

        # Both classes carry their OWN patched flag — a patched base must not
        # mask the concrete subclass.
        assert vars(FakeAbstractToolset).get(pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is True
        assert vars(FakeFunctionToolset).get(pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is True
    finally:
        patcher.revert()


@pytest.mark.asyncio
async def test_denied_tool_raises_when_invoked_via_concrete_call_tool(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A deny verdict raises when the tool runs through the concrete override."""
    _, FakeFunctionToolset = _install_fake_pydantic_ai_function_toolset_modules(monkeypatch)

    class Interceptor:
        async def check_tool_start(self, **kwargs: object) -> dict[str, str]:
            del kwargs
            return {"status": "deny", "reason": "blocked function tool"}

    patcher = pydantic_ai_patch.PydanticAIPatch(Interceptor())
    # Always revert: ``apply()`` also sets process-wide state that only
    # ``revert()`` clears, so skipping it would leak into later tests.
    try:
        assert patcher.apply() is True

        toolset = FakeFunctionToolset()
        ctx = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="agent-fn"), run_id="run-fn")

        # Governance fires on the concrete override, not just the (shadowed) base.
        with pytest.raises(PolicyViolationError, match="blocked by governance policy: blocked function tool"):
            await toolset.call_tool("fake_tool", {"q": "secret"}, ctx, object())
    finally:
        patcher.revert()


@pytest.mark.asyncio
async def test_revert_restores_concrete_and_base_call_tool(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``revert()`` restores both original ``call_tool`` functions and is idempotent."""
    FakeAbstractToolset, FakeFunctionToolset = _install_fake_pydantic_ai_function_toolset_modules(monkeypatch)
    original_base = FakeAbstractToolset.call_tool
    original_concrete = FakeFunctionToolset.call_tool

    patcher = pydantic_ai_patch.PydanticAIPatch(_RecordingInterceptor())
    try:
        assert patcher.apply() is True
        assert FakeAbstractToolset.call_tool is not original_base
        assert FakeFunctionToolset.call_tool is not original_concrete

        patcher.revert()
        assert FakeAbstractToolset.call_tool is original_base
        assert FakeFunctionToolset.call_tool is original_concrete
        assert vars(FakeAbstractToolset).get(pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is False
        assert vars(FakeFunctionToolset).get(pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is False

        # Revert is idempotent.
        patcher.revert()
        assert FakeFunctionToolset.call_tool is original_concrete
    finally:
        # Safety net for an assertion failing before the explicit reverts
        # above; harmless otherwise because revert is idempotent.
        patcher.revert()


def test_concrete_toolset_discovery_fail_soft_without_pydantic_ai(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With every import failing, discovery yields an empty list instead of raising."""

    def always_missing(module_name: str) -> object:
        raise ImportError(module_name)

    monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", always_missing)

    class FakeBase:
        pass

    discovered = pydantic_ai_patch._load_pydantic_ai_concrete_toolset_classes(FakeBase)
    assert discovered == []


def test_apply_false_when_no_known_tool_hook_exists(monkeypatch: pytest.MonkeyPatch) -> None:
"""Pydantic AI present but exposing neither ``Tool._run`` nor
``AbstractToolset.call_tool`` must no-op (return False), never raise.
Expand Down
Loading
Loading