Skip to content

Commit efaf6be

Browse files
authored
Support contextually overriding agent instructions (#2926)
1 parent aedbcd6 commit efaf6be

File tree

6 files changed

+335
-42
lines changed

6 files changed

+335
-42
lines changed

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from ..toolsets.combined import CombinedToolset
6767
from ..toolsets.function import FunctionToolset
6868
from ..toolsets.prepared import PreparedToolset
69-
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
69+
from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT
7070
from .wrapper import WrapperAgent
7171

7272
if TYPE_CHECKING:
@@ -137,8 +137,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
137137
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
138138
_output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
139139
_output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
140-
_instructions: str | None = dataclasses.field(repr=False)
141-
_instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
140+
_instructions: list[str | _system_prompt.SystemPromptFunc[AgentDepsT]] = dataclasses.field(repr=False)
142141
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
143142
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
144143
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
@@ -164,10 +163,7 @@ def __init__(
164163
model: models.Model | models.KnownModelName | str | None = None,
165164
*,
166165
output_type: OutputSpec[OutputDataT] = str,
167-
instructions: str
168-
| _system_prompt.SystemPromptFunc[AgentDepsT]
169-
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
170-
| None = None,
166+
instructions: Instructions[AgentDepsT] = None,
171167
system_prompt: str | Sequence[str] = (),
172168
deps_type: type[AgentDepsT] = NoneType,
173169
name: str | None = None,
@@ -193,10 +189,7 @@ def __init__(
193189
model: models.Model | models.KnownModelName | str | None = None,
194190
*,
195191
output_type: OutputSpec[OutputDataT] = str,
196-
instructions: str
197-
| _system_prompt.SystemPromptFunc[AgentDepsT]
198-
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
199-
| None = None,
192+
instructions: Instructions[AgentDepsT] = None,
200193
system_prompt: str | Sequence[str] = (),
201194
deps_type: type[AgentDepsT] = NoneType,
202195
name: str | None = None,
@@ -220,10 +213,7 @@ def __init__(
220213
model: models.Model | models.KnownModelName | str | None = None,
221214
*,
222215
output_type: OutputSpec[OutputDataT] = str,
223-
instructions: str
224-
| _system_prompt.SystemPromptFunc[AgentDepsT]
225-
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
226-
| None = None,
216+
instructions: Instructions[AgentDepsT] = None,
227217
system_prompt: str | Sequence[str] = (),
228218
deps_type: type[AgentDepsT] = NoneType,
229219
name: str | None = None,
@@ -322,16 +312,7 @@ def __init__(
322312
self._output_schema = _output.OutputSchema[OutputDataT].build(output_type, default_mode=default_output_mode)
323313
self._output_validators = []
324314

325-
self._instructions = ''
326-
self._instructions_functions = []
327-
if isinstance(instructions, str | Callable):
328-
instructions = [instructions]
329-
for instruction in instructions or []:
330-
if isinstance(instruction, str):
331-
self._instructions += instruction + '\n'
332-
else:
333-
self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction))
334-
self._instructions = self._instructions.strip() or None
315+
self._instructions = self._normalize_instructions(instructions)
335316

336317
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
337318
self._system_prompt_functions = []
@@ -371,6 +352,9 @@ def __init__(
371352
self._override_tools: ContextVar[
372353
_utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]]
373354
] = ContextVar('_override_tools', default=None)
355+
self._override_instructions: ContextVar[
356+
_utils.Option[list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]]
357+
] = ContextVar('_override_instructions', default=None)
374358

375359
self._enter_lock = Lock()
376360
self._entered_count = 0
@@ -593,10 +577,12 @@ async def main():
593577
model_settings = merge_model_settings(merged_settings, model_settings)
594578
usage_limits = usage_limits or _usage.UsageLimits()
595579

580+
instructions_literal, instructions_functions = self._get_instructions()
581+
596582
async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
597583
parts = [
598-
self._instructions,
599-
*[await func.run(run_context) for func in self._instructions_functions],
584+
instructions_literal,
585+
*[await func.run(run_context) for func in instructions_functions],
600586
]
601587

602588
model_profile = model_used.profile
@@ -634,11 +620,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
634620
get_instructions=get_instructions,
635621
instrumentation_settings=instrumentation_settings,
636622
)
623+
637624
start_node = _agent_graph.UserPromptNode[AgentDepsT](
638625
user_prompt=user_prompt,
639626
deferred_tool_results=deferred_tool_results,
640-
instructions=self._instructions,
641-
instructions_functions=self._instructions_functions,
627+
instructions=instructions_literal,
628+
instructions_functions=instructions_functions,
642629
system_prompts=self._system_prompts,
643630
system_prompt_functions=self._system_prompt_functions,
644631
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
@@ -690,6 +677,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
690677
def _run_span_end_attributes(
691678
self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings
692679
):
680+
literal_instructions, _ = self._get_instructions()
681+
693682
if settings.version == 1:
694683
attrs = {
695684
'all_messages_events': json.dumps(
@@ -702,7 +691,7 @@ def _run_span_end_attributes(
702691
else:
703692
attrs = {
704693
'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(state.message_history)),
705-
**settings.system_instructions_attributes(self._instructions),
694+
**settings.system_instructions_attributes(literal_instructions),
706695
}
707696

708697
return {
@@ -727,8 +716,9 @@ def override(
727716
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
728717
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
729718
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
719+
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
730720
) -> Iterator[None]:
731-
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
721+
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.
732722
733723
This is particularly useful when testing.
734724
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
@@ -738,6 +728,7 @@ def override(
738728
model: The model to use instead of the model passed to the agent run.
739729
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
740730
tools: The tools to use instead of the tools registered with the agent.
731+
instructions: The instructions to use instead of the instructions registered with the agent.
741732
"""
742733
if _utils.is_set(deps):
743734
deps_token = self._override_deps.set(_utils.Some(deps))
@@ -759,6 +750,12 @@ def override(
759750
else:
760751
tools_token = None
761752

753+
if _utils.is_set(instructions):
754+
normalized_instructions = self._normalize_instructions(instructions)
755+
instructions_token = self._override_instructions.set(_utils.Some(normalized_instructions))
756+
else:
757+
instructions_token = None
758+
762759
try:
763760
yield
764761
finally:
@@ -770,6 +767,8 @@ def override(
770767
self._override_toolsets.reset(toolsets_token)
771768
if tools_token is not None:
772769
self._override_tools.reset(tools_token)
770+
if instructions_token is not None:
771+
self._override_instructions.reset(instructions_token)
773772

774773
@overload
775774
def instructions(
@@ -830,12 +829,12 @@ async def async_instructions(ctx: RunContext[str]) -> str:
830829
def decorator(
831830
func_: _system_prompt.SystemPromptFunc[AgentDepsT],
832831
) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
833-
self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_))
832+
self._instructions.append(func_)
834833
return func_
835834

836835
return decorator
837836
else:
838-
self._instructions_functions.append(_system_prompt.SystemPromptRunner(func))
837+
self._instructions.append(func)
839838
return func
840839

841840
@overload
@@ -1276,6 +1275,34 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
12761275
else:
12771276
return deps
12781277

1278+
def _normalize_instructions(
1279+
self,
1280+
instructions: Instructions[AgentDepsT],
1281+
) -> list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]:
1282+
if instructions is None:
1283+
return []
1284+
if isinstance(instructions, str) or callable(instructions):
1285+
return [instructions]
1286+
return list(instructions)
1287+
1288+
def _get_instructions(
1289+
self,
1290+
) -> tuple[str | None, list[_system_prompt.SystemPromptRunner[AgentDepsT]]]:
1291+
override_instructions = self._override_instructions.get()
1292+
instructions = override_instructions.value if override_instructions else self._instructions
1293+
1294+
literal_parts: list[str] = []
1295+
functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []
1296+
1297+
for instruction in instructions:
1298+
if isinstance(instruction, str):
1299+
literal_parts.append(instruction)
1300+
else:
1301+
functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](instruction))
1302+
1303+
literal = '\n'.join(literal_parts).strip() or None
1304+
return literal, functions
1305+
12791306
def _get_toolset(
12801307
self,
12811308
output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from .. import (
1616
_agent_graph,
17+
_system_prompt,
1718
_utils,
1819
exceptions,
1920
messages as _messages,
@@ -60,6 +61,14 @@
6061
"""A function that receives agent [`RunContext`][pydantic_ai.tools.RunContext] and an async iterable of events from the model's streaming response and the agent's execution of tools."""
6162

6263

64+
Instructions = (
65+
str
66+
| _system_prompt.SystemPromptFunc[AgentDepsT]
67+
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
68+
| None
69+
)
70+
71+
6372
class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
6473
"""Abstract superclass for [`Agent`][pydantic_ai.agent.Agent], [`WrapperAgent`][pydantic_ai.agent.WrapperAgent], and your own custom agent implementations."""
6574

@@ -681,8 +690,9 @@ def override(
681690
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
682691
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
683692
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
693+
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
684694
) -> Iterator[None]:
685-
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
695+
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.
686696
687697
This is particularly useful when testing.
688698
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
@@ -692,6 +702,7 @@ def override(
692702
model: The model to use instead of the model passed to the agent run.
693703
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
694704
tools: The tools to use instead of the tools registered with the agent.
705+
instructions: The instructions to use instead of the instructions registered with the agent.
695706
"""
696707
raise NotImplementedError
697708
yield

pydantic_ai_slim/pydantic_ai/agent/wrapper.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
ToolFuncEither,
2121
)
2222
from ..toolsets import AbstractToolset
23-
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
23+
from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT
2424

2525

2626
class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT]):
@@ -214,8 +214,9 @@ def override(
214214
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
215215
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
216216
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
217+
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
217218
) -> Iterator[None]:
218-
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
219+
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.
219220
220221
This is particularly useful when testing.
221222
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
@@ -225,6 +226,13 @@ def override(
225226
model: The model to use instead of the model passed to the agent run.
226227
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
227228
tools: The tools to use instead of the tools registered with the agent.
229+
instructions: The instructions to use instead of the instructions registered with the agent.
228230
"""
229-
with self.wrapped.override(deps=deps, model=model, toolsets=toolsets, tools=tools):
231+
with self.wrapped.override(
232+
deps=deps,
233+
model=model,
234+
toolsets=toolsets,
235+
tools=tools,
236+
instructions=instructions,
237+
):
230238
yield

pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
usage as _usage,
1616
)
1717
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
18+
from pydantic_ai.agent.abstract import Instructions
1819
from pydantic_ai.exceptions import UserError
1920
from pydantic_ai.models import Model
2021
from pydantic_ai.output import OutputDataT, OutputSpec
@@ -704,8 +705,9 @@ def override(
704705
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
705706
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
706707
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
708+
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
707709
) -> Iterator[None]:
708-
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
710+
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.
709711
710712
This is particularly useful when testing.
711713
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
@@ -715,11 +717,18 @@ def override(
715717
model: The model to use instead of the model passed to the agent run.
716718
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
717719
tools: The tools to use instead of the tools registered with the agent.
720+
instructions: The instructions to use instead of the instructions registered with the agent.
718721
"""
719722
if _utils.is_set(model) and not isinstance(model, (DBOSModel)):
720723
raise UserError(
721724
'Non-DBOS model cannot be contextually overridden inside a DBOS workflow, it must be set at agent creation time.'
722725
)
723726

724-
with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools):
727+
with super().override(
728+
deps=deps,
729+
model=model,
730+
toolsets=toolsets,
731+
tools=tools,
732+
instructions=instructions,
733+
):
725734
yield

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
models,
2323
usage as _usage,
2424
)
25-
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
25+
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent
26+
from pydantic_ai.agent.abstract import Instructions, RunOutputDataT
2627
from pydantic_ai.exceptions import UserError
2728
from pydantic_ai.models import Model
2829
from pydantic_ai.output import OutputDataT, OutputSpec
@@ -748,8 +749,9 @@ def override(
748749
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
749750
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
750751
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
752+
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
751753
) -> Iterator[None]:
752-
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
754+
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.
753755
754756
This is particularly useful when testing.
755757
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
@@ -759,6 +761,7 @@ def override(
759761
model: The model to use instead of the model passed to the agent run.
760762
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
761763
tools: The tools to use instead of the tools registered with the agent.
764+
instructions: The instructions to use instead of the instructions registered with the agent.
762765
"""
763766
if workflow.in_workflow():
764767
if _utils.is_set(model):
@@ -774,5 +777,11 @@ def override(
774777
'Tools cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.'
775778
)
776779

777-
with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools):
780+
with super().override(
781+
deps=deps,
782+
model=model,
783+
toolsets=toolsets,
784+
tools=tools,
785+
instructions=instructions,
786+
):
778787
yield

0 commit comments

Comments
 (0)