Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 60 additions & 33 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from ..toolsets.combined import CombinedToolset
from ..toolsets.function import FunctionToolset
from ..toolsets.prepared import PreparedToolset
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT
from .wrapper import WrapperAgent

if TYPE_CHECKING:
Expand Down Expand Up @@ -137,8 +137,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
_output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
_output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
_instructions: str | None = dataclasses.field(repr=False)
_instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
_instructions: list[str | _system_prompt.SystemPromptFunc[AgentDepsT]] = dataclasses.field(repr=False)
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
Expand All @@ -164,10 +163,7 @@ def __init__(
model: models.Model | models.KnownModelName | str | None = None,
*,
output_type: OutputSpec[OutputDataT] = str,
instructions: str
| _system_prompt.SystemPromptFunc[AgentDepsT]
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
| None = None,
instructions: Instructions[AgentDepsT] = None,
system_prompt: str | Sequence[str] = (),
deps_type: type[AgentDepsT] = NoneType,
name: str | None = None,
Expand All @@ -193,10 +189,7 @@ def __init__(
model: models.Model | models.KnownModelName | str | None = None,
*,
output_type: OutputSpec[OutputDataT] = str,
instructions: str
| _system_prompt.SystemPromptFunc[AgentDepsT]
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
| None = None,
instructions: Instructions[AgentDepsT] = None,
system_prompt: str | Sequence[str] = (),
deps_type: type[AgentDepsT] = NoneType,
name: str | None = None,
Expand All @@ -220,10 +213,7 @@ def __init__(
model: models.Model | models.KnownModelName | str | None = None,
*,
output_type: OutputSpec[OutputDataT] = str,
instructions: str
| _system_prompt.SystemPromptFunc[AgentDepsT]
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
| None = None,
instructions: Instructions[AgentDepsT] = None,
system_prompt: str | Sequence[str] = (),
deps_type: type[AgentDepsT] = NoneType,
name: str | None = None,
Expand Down Expand Up @@ -322,16 +312,7 @@ def __init__(
self._output_schema = _output.OutputSchema[OutputDataT].build(output_type, default_mode=default_output_mode)
self._output_validators = []

self._instructions = ''
self._instructions_functions = []
if isinstance(instructions, str | Callable):
instructions = [instructions]
for instruction in instructions or []:
if isinstance(instruction, str):
self._instructions += instruction + '\n'
else:
self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction))
self._instructions = self._instructions.strip() or None
self._instructions = self._normalize_instructions(instructions)

self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
self._system_prompt_functions = []
Expand Down Expand Up @@ -371,6 +352,9 @@ def __init__(
self._override_tools: ContextVar[
_utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]]
] = ContextVar('_override_tools', default=None)
self._override_instructions: ContextVar[
_utils.Option[list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]]
] = ContextVar('_override_instructions', default=None)

self._enter_lock = Lock()
self._entered_count = 0
Expand Down Expand Up @@ -593,10 +577,12 @@ async def main():
model_settings = merge_model_settings(merged_settings, model_settings)
usage_limits = usage_limits or _usage.UsageLimits()

instructions_literal, instructions_functions = self._get_instructions()

async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
parts = [
self._instructions,
*[await func.run(run_context) for func in self._instructions_functions],
instructions_literal,
*[await func.run(run_context) for func in instructions_functions],
]

model_profile = model_used.profile
Expand Down Expand Up @@ -634,11 +620,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
get_instructions=get_instructions,
instrumentation_settings=instrumentation_settings,
)

start_node = _agent_graph.UserPromptNode[AgentDepsT](
user_prompt=user_prompt,
deferred_tool_results=deferred_tool_results,
instructions=self._instructions,
instructions_functions=self._instructions_functions,
instructions=instructions_literal,
instructions_functions=instructions_functions,
system_prompts=self._system_prompts,
system_prompt_functions=self._system_prompt_functions,
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
Expand Down Expand Up @@ -690,6 +677,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
def _run_span_end_attributes(
self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings
):
literal_instructions, _ = self._get_instructions()

if settings.version == 1:
attrs = {
'all_messages_events': json.dumps(
Expand All @@ -702,7 +691,7 @@ def _run_span_end_attributes(
else:
attrs = {
'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(state.message_history)),
**settings.system_instructions_attributes(self._instructions),
**settings.system_instructions_attributes(literal_instructions),
}

return {
Expand All @@ -727,8 +716,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -738,6 +728,7 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
if _utils.is_set(deps):
deps_token = self._override_deps.set(_utils.Some(deps))
Expand All @@ -759,6 +750,12 @@ def override(
else:
tools_token = None

if _utils.is_set(instructions):
normalized_instructions = self._normalize_instructions(instructions)
instructions_token = self._override_instructions.set(_utils.Some(normalized_instructions))
else:
instructions_token = None

try:
yield
finally:
Expand All @@ -770,6 +767,8 @@ def override(
self._override_toolsets.reset(toolsets_token)
if tools_token is not None:
self._override_tools.reset(tools_token)
if instructions_token is not None:
self._override_instructions.reset(instructions_token)

@overload
def instructions(
Expand Down Expand Up @@ -830,12 +829,12 @@ async def async_instructions(ctx: RunContext[str]) -> str:
def decorator(
func_: _system_prompt.SystemPromptFunc[AgentDepsT],
) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_))
self._instructions.append(func_)
return func_

return decorator
else:
self._instructions_functions.append(_system_prompt.SystemPromptRunner(func))
self._instructions.append(func)
return func

@overload
Expand Down Expand Up @@ -1276,6 +1275,34 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
else:
return deps

def _normalize_instructions(
self,
instructions: Instructions[AgentDepsT],
) -> list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]:
if instructions is None:
return []
if isinstance(instructions, str) or callable(instructions):
return [instructions]
return list(instructions)

def _get_instructions(
self,
) -> tuple[str | None, list[_system_prompt.SystemPromptRunner[AgentDepsT]]]:
override_instructions = self._override_instructions.get()
instructions = override_instructions.value if override_instructions else self._instructions

literal_parts: list[str] = []
functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []

for instruction in instructions:
if isinstance(instruction, str):
literal_parts.append(instruction)
else:
functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](instruction))

literal = '\n'.join(literal_parts).strip() or None
return literal, functions

def _get_toolset(
self,
output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET,
Expand Down
13 changes: 12 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .. import (
_agent_graph,
_system_prompt,
_utils,
exceptions,
messages as _messages,
Expand Down Expand Up @@ -60,6 +61,14 @@
"""A function that receives agent [`RunContext`][pydantic_ai.tools.RunContext] and an async iterable of events from the model's streaming response and the agent's execution of tools."""


Instructions = (
str
| _system_prompt.SystemPromptFunc[AgentDepsT]
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
| None
)


class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
"""Abstract superclass for [`Agent`][pydantic_ai.agent.Agent], [`WrapperAgent`][pydantic_ai.agent.WrapperAgent], and your own custom agent implementations."""

Expand Down Expand Up @@ -681,8 +690,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -692,6 +702,7 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
raise NotImplementedError
yield
Expand Down
14 changes: 11 additions & 3 deletions pydantic_ai_slim/pydantic_ai/agent/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ToolFuncEither,
)
from ..toolsets import AbstractToolset
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT


class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT]):
Expand Down Expand Up @@ -214,8 +214,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -225,6 +226,13 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
with self.wrapped.override(deps=deps, model=model, toolsets=toolsets, tools=tools):
with self.wrapped.override(
deps=deps,
model=model,
toolsets=toolsets,
tools=tools,
instructions=instructions,
):
yield
13 changes: 11 additions & 2 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
usage as _usage,
)
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
from pydantic_ai.agent.abstract import Instructions
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import Model
from pydantic_ai.output import OutputDataT, OutputSpec
Expand Down Expand Up @@ -704,8 +705,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -715,11 +717,18 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
if _utils.is_set(model) and not isinstance(model, (DBOSModel)):
raise UserError(
'Non-DBOS model cannot be contextually overridden inside a DBOS workflow, it must be set at agent creation time.'
)

with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools):
with super().override(
deps=deps,
model=model,
toolsets=toolsets,
tools=tools,
instructions=instructions,
):
yield
15 changes: 12 additions & 3 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
models,
usage as _usage,
)
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent
from pydantic_ai.agent.abstract import Instructions, RunOutputDataT
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import Model
from pydantic_ai.output import OutputDataT, OutputSpec
Expand Down Expand Up @@ -748,8 +749,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -759,6 +761,7 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
if workflow.in_workflow():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if we want to prevent overriding instructions within a workflow? i definitely want to be able to do that since the whole point is we can use an LLM to optimize the system prompt, we're just using override as a way to propagate those instructions.

i'm guessing this is just being defensive about trying to enforce deterministic behavior?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overriding the instructions is fine with Temporal! The other checks are there because those values get "baked into" the TemporalModel through dedicated activities.

if _utils.is_set(model):
Expand All @@ -774,5 +777,11 @@ def override(
'Tools cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.'
)

with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools):
with super().override(
deps=deps,
model=model,
toolsets=toolsets,
tools=tools,
instructions=instructions,
):
yield
Loading