diff --git a/src/agents/agent.py b/src/agents/agent.py index b67a12c0d..5c6037532 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -221,6 +221,119 @@ class Agent(AgentBase, Generic[TContext]): """Whether to reset the tool choice to the default value after a tool has been called. Defaults to True. This ensures that the agent doesn't enter an infinite loop of tool usage.""" + def __post_init__(self): + from typing import get_origin + + if not isinstance(self.name, str): + raise TypeError(f"Agent name must be a string, got {type(self.name).__name__}") + + if self.handoff_description is not None and not isinstance(self.handoff_description, str): + raise TypeError( + f"Agent handoff_description must be a string or None, " + f"got {type(self.handoff_description).__name__}" + ) + + if not isinstance(self.tools, list): + raise TypeError(f"Agent tools must be a list, got {type(self.tools).__name__}") + + if not isinstance(self.mcp_servers, list): + raise TypeError( + f"Agent mcp_servers must be a list, got {type(self.mcp_servers).__name__}" + ) + + if not isinstance(self.mcp_config, dict): + raise TypeError( + f"Agent mcp_config must be a dict, got {type(self.mcp_config).__name__}" + ) + + if ( + self.instructions is not None + and not isinstance(self.instructions, str) + and not callable(self.instructions) + ): + raise TypeError( + f"Agent instructions must be a string, callable, or None, " + f"got {type(self.instructions).__name__}" + ) + + if ( + self.prompt is not None + and not callable(self.prompt) + and not hasattr(self.prompt, "get") + ): + raise TypeError( + f"Agent prompt must be a Prompt, DynamicPromptFunction, or None, " + f"got {type(self.prompt).__name__}" + ) + + if not isinstance(self.handoffs, list): + raise TypeError(f"Agent handoffs must be a list, got {type(self.handoffs).__name__}") + + if self.model is not None and not isinstance(self.model, str): + from .models.interface import Model + + if not isinstance(self.model, Model): + raise TypeError( + f"Agent model must be a string, Model, or None, got {type(self.model).__name__}" + ) + + if not isinstance(self.model_settings, ModelSettings): + raise TypeError( + f"Agent model_settings must be a ModelSettings instance, " + f"got {type(self.model_settings).__name__}" + ) + + if not isinstance(self.input_guardrails, list): + raise TypeError( + f"Agent input_guardrails must be a list, got {type(self.input_guardrails).__name__}" + ) + + if not isinstance(self.output_guardrails, list): + raise TypeError( + f"Agent output_guardrails must be a list, " + f"got {type(self.output_guardrails).__name__}" + ) + + if self.output_type is not None: + from .agent_output import AgentOutputSchemaBase + + if not ( + isinstance(self.output_type, (type, AgentOutputSchemaBase)) + or get_origin(self.output_type) is not None + ): + raise TypeError( + f"Agent output_type must be a type, AgentOutputSchemaBase, or None, " + f"got {type(self.output_type).__name__}" + ) + + if self.hooks is not None: + from .lifecycle import AgentHooksBase + + if not isinstance(self.hooks, AgentHooksBase): + raise TypeError( + f"Agent hooks must be an AgentHooks instance or None, " + f"got {type(self.hooks).__name__}" + ) + + if ( + not ( + isinstance(self.tool_use_behavior, str) + and self.tool_use_behavior in ["run_llm_again", "stop_on_first_tool"] + ) + and not isinstance(self.tool_use_behavior, dict) + and not callable(self.tool_use_behavior) + ): + raise TypeError( + f"Agent tool_use_behavior must be 'run_llm_again', 'stop_on_first_tool', " + f"StopAtTools dict, or callable, got {type(self.tool_use_behavior).__name__}" + ) + + if not isinstance(self.reset_tool_choice, bool): + raise TypeError( + f"Agent reset_tool_choice must be a boolean, " + f"got {type(self.reset_tool_choice).__name__}" + ) + def clone(self, **kwargs: Any) -> Agent[TContext]: """Make a copy of the agent, with the given arguments changed. For example, you could do: ``` diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py index a985fd60d..5b633b70b 100644 --- a/tests/test_agent_config.py +++ b/tests/test_agent_config.py @@ -2,6 +2,8 @@ from pydantic import BaseModel from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, handoff +from agents.lifecycle import AgentHooksBase +from agents.model_settings import ModelSettings from agents.run import AgentRunner @@ -167,3 +169,58 @@ async def test_agent_final_output(): assert schema.is_strict_json_schema() is True assert schema.json_schema() is not None assert not schema.is_plain_text() + + +class TestAgentValidation: + """Essential validation tests for Agent __post_init__""" + + def test_name_validation_critical_cases(self): + """Test name validation - the original issue that started this PR""" + # This was the original failing case that caused JSON serialization errors + with pytest.raises(TypeError, match="Agent name must be a string, got int"): + Agent(name=1) # type: ignore + + with pytest.raises(TypeError, match="Agent name must be a string, got NoneType"): + Agent(name=None) # type: ignore + + def test_tool_use_behavior_dict_validation(self): + """Test tool_use_behavior accepts StopAtTools dict - fixes existing test failures""" + # This test ensures the existing failing tests now pass + Agent(name="test", tool_use_behavior={"stop_at_tool_names": ["tool1"]}) + + # Invalid cases that should fail + with pytest.raises(TypeError, match="Agent tool_use_behavior must be"): + Agent(name="test", tool_use_behavior=123) # type: ignore + + def test_hooks_validation_python39_compatibility(self): + """Test hooks validation works with Python 3.9 - fixes generic type issues""" + + class MockHooks(AgentHooksBase): + pass + + # Valid case + Agent(name="test", hooks=MockHooks()) # type: ignore + + # Invalid case + with pytest.raises(TypeError, match="Agent hooks must be an AgentHooks instance"): + Agent(name="test", hooks="invalid") # type: ignore + + def test_list_field_validation(self): + """Test critical list fields that commonly get wrong types""" + # These are the most common mistakes users make + with pytest.raises(TypeError, match="Agent tools must be a list"): + Agent(name="test", tools="not_a_list") # type: ignore + + with pytest.raises(TypeError, match="Agent handoffs must be a list"): + Agent(name="test", handoffs="not_a_list") # type: ignore + + def test_model_settings_validation(self): + """Test model_settings validation - prevents runtime errors""" + # Valid case + Agent(name="test", model_settings=ModelSettings()) + + # Invalid case that could cause runtime issues + with pytest.raises( + TypeError, match="Agent model_settings must be a ModelSettings instance" + ): + Agent(name="test", model_settings={}) # type: ignore