Skip to content

Commit 8997abc

Browse files
committed
Add helper function
1 parent 993f13c commit 8997abc

File tree

3 files changed

+41
-14
lines changed

3 files changed

+41
-14
lines changed

dbostest.sqlite

184 KB
Binary file not shown.

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,14 @@ def __init__(
359359
self._entered_count = 0
360360
self._exit_stack = None
361361

362+
def _get_instructions_literal_and_functions(
363+
self,
364+
) -> tuple[str | None, list[_system_prompt.SystemPromptRunner[AgentDepsT]]]:
365+
instructions, instructions_functions = self._instructions, self._instructions_functions
366+
if override_instructions := self._override_instructions.get():
367+
instructions, instructions_functions = self._instructions_literal_and_functions(override_instructions.value)
368+
return instructions, instructions_functions
369+
362370
def _instructions_literal_and_functions(
363371
self,
364372
instructions: InstructionsInput,
@@ -597,10 +605,7 @@ async def main():
597605
usage_limits = usage_limits or _usage.UsageLimits()
598606

599607
async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
600-
literal, functions = self._instructions, self._instructions_functions
601-
if override := self._override_instructions.get():
602-
literal, functions = self._instructions_literal_and_functions(override.value)
603-
608+
literal, functions = self._get_instructions_literal_and_functions()
604609
parts = [
605610
literal,
606611
*[await func.run(run_context) for func in functions],
@@ -642,12 +647,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
642647
instrumentation_settings=instrumentation_settings,
643648
)
644649

645-
instructions_for_node, instructions_functions_for_node = self._instructions, self._instructions_functions
646-
if override_instructions := self._override_instructions.get():
647-
instructions_for_node, instructions_functions_for_node = self._instructions_literal_and_functions(
648-
override_instructions.value
649-
)
650-
650+
instructions_for_node, instructions_functions_for_node = self._get_instructions_literal_and_functions()
651651
start_node = _agent_graph.UserPromptNode[AgentDepsT](
652652
user_prompt=user_prompt,
653653
deferred_tool_results=deferred_tool_results,
@@ -771,9 +771,9 @@ def override(
771771
tools_token = None
772772

773773
if _utils.is_set(instructions):
774-
ins_token = self._override_instructions.set(_utils.Some(instructions))
774+
instructions_token = self._override_instructions.set(_utils.Some(instructions))
775775
else:
776-
ins_token = None
776+
instructions_token = None
777777

778778
try:
779779
yield
@@ -786,8 +786,8 @@ def override(
786786
self._override_toolsets.reset(toolsets_token)
787787
if tools_token is not None:
788788
self._override_tools.reset(tools_token)
789-
if ins_token is not None:
790-
self._override_instructions.reset(ins_token)
789+
if instructions_token is not None:
790+
self._override_instructions.reset(instructions_token)
791791

792792
@overload
793793
def instructions(

tests/test_override_instructions.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ModelMessage,
1111
ModelRequest,
1212
ModelRequestPart,
13+
ModelResponse,
1314
SystemPromptPart,
1415
)
1516
from pydantic_ai.models.test import TestModel
@@ -29,6 +30,20 @@ def _system_prompt_texts(parts: Sequence[ModelRequestPart]) -> list[str]:
2930
return [p.content for p in parts if isinstance(p, SystemPromptPart)]
3031

3132

33+
def test_first_request_skips_non_requests():
34+
"""Helper ignores non-request messages until it finds a request."""
35+
response = ModelResponse(parts=())
36+
request = ModelRequest(parts=())
37+
assert _first_request([response, request]) is request
38+
39+
40+
def test_first_request_raises_without_model_request():
41+
"""Helper raises when no model request is present."""
42+
response = ModelResponse(parts=())
43+
with pytest.raises(AssertionError, match='no ModelRequest found'):
44+
_first_request([response])
45+
46+
3247
def test_override_instructions_basic():
3348
"""Test that override can override instructions."""
3449
agent = Agent('test')
@@ -144,6 +159,18 @@ def override_fn() -> str:
144159
assert 'BASE' not in req.instructions
145160

146161

162+
def test_override_instructions_ignores_unknown_types():
163+
"""Override ignores instruction entries it does not understand."""
164+
agent = Agent('test')
165+
166+
with agent.override(instructions=['ONLY_LITERAL', object()]):
167+
with capture_run_messages() as messages:
168+
agent.run_sync('Hi', model=TestModel(custom_output_text='ok'))
169+
170+
req = _first_request(messages)
171+
assert req.instructions == 'ONLY_LITERAL'
172+
173+
147174
@pytest.mark.anyio
148175
async def test_override_concurrent_isolation():
149176
"""Test that concurrent overrides are isolated from each other."""

0 commit comments

Comments
 (0)