diff --git a/adk/prebuilt/planexecute/plan_execute.go b/adk/prebuilt/planexecute/plan_execute.go index 06878f15..55c180da 100644 --- a/adk/prebuilt/planexecute/plan_execute.go +++ b/adk/prebuilt/planexecute/plan_execute.go @@ -48,10 +48,10 @@ type Plan interface { json.Unmarshaler } -// NewPlan is a function type that creates a new Plan instance. -type NewPlan func(ctx context.Context) Plan +// PlanFactory is a function type that creates a new Plan instance. +type PlanFactory func(ctx context.Context) Plan -// defaultPlan is the default implementation of the Plan interface. +// DefaultPlan is the default implementation of the Plan interface. // // JSON Schema: // @@ -68,27 +68,27 @@ type NewPlan func(ctx context.Context) Plan // }, // "required": ["steps"] // } -type defaultPlan struct { +type DefaultPlan struct { // Steps contains the ordered list of actions to be taken. // Each step should be clear, actionable, and arranged in a logical sequence. Steps []string `json:"steps"` } // FirstStep returns the first step in the plan or an empty string if no steps exist. -func (p *defaultPlan) FirstStep() string { +func (p *DefaultPlan) FirstStep() string { if len(p.Steps) == 0 { return "" } return p.Steps[0] } -func (p *defaultPlan) MarshalJSON() ([]byte, error) { - type planTyp defaultPlan +func (p *DefaultPlan) MarshalJSON() ([]byte, error) { + type planTyp DefaultPlan return sonic.Marshal((*planTyp)(p)) } -func (p *defaultPlan) UnmarshalJSON(bytes []byte) error { - type planTyp defaultPlan +func (p *DefaultPlan) UnmarshalJSON(bytes []byte) error { + type planTyp DefaultPlan return sonic.Unmarshal(bytes, (*planTyp)(p)) } @@ -265,24 +265,23 @@ type PlannerConfig struct { // Optional. If not provided, PlanToolInfo will be used as the default. ToolInfo *schema.ToolInfo - // GenInputFn is a function that generates the input messages for the planner. - // Optional. If not provided, defaultGenPlannerInputFn will be used. + // GenInputFn generates input messages for the planner. + // Optional. Defaults to using PlannerPrompt as the template to render model input messages. GenInputFn GenPlannerModelInputFn - // NewPlan creates a new Plan instance for JSON. - // The returned Plan will be used to unmarshal the model-generated JSON output. - // Optional. If not provided, defaultNewPlan will be used. - NewPlan NewPlan + // Factory creates Plan instances for JSON unmarshaling. + // Optional. Defaults to creating DefaultPlan instances. + Factory PlanFactory } // GenPlannerModelInputFn is a function type that generates input messages for the planner. -type GenPlannerModelInputFn func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) +type GenPlannerModelInputFn func(ctx context.Context, userInput []adk.Message, cfg *PlannerConfig) ([]adk.Message, error) -func defaultNewPlan(ctx context.Context) Plan { - return &defaultPlan{} +func defaultPlanFactory(ctx context.Context) Plan { + return &DefaultPlan{} } -func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) { +func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message, _ *PlannerConfig) ([]adk.Message, error) { msgs, err := PlannerPrompt.Format(ctx, map[string]any{ "input": userInput, }) @@ -293,10 +292,11 @@ func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message) ([]a } type planner struct { + cfg *PlannerConfig toolCall bool chatModel model.BaseChatModel genInputFn GenPlannerModelInputFn - newPlan NewPlan + factory PlanFactory } func (p *planner) Name(_ context.Context) string { @@ -333,7 +333,7 @@ func (p *planner) Run(ctx context.Context, input *adk.AgentInput, generator.Close() }() - msgs, err := p.genInputFn(ctx, input.Messages) + msgs, err := p.genInputFn(ctx, input.Messages, p.cfg) if err != nil { generator.Send(&adk.AgentEvent{Err: err}) return @@ -401,7 +401,7 @@ func (p *planner) Run(ctx context.Context, input *adk.AgentInput, } else { planJSON = msg.Content } - plan := p.newPlan(ctx) + plan := p.factory(ctx) err = plan.UnmarshalJSON([]byte(planJSON)) if err != nil { err = fmt.Errorf("unmarshal plan error: %w", err) @@ -440,34 +440,34 @@ func NewPlanner(_ context.Context, cfg *PlannerConfig) (adk.Agent, error) { return nil, err } } - inputFn := cfg.GenInputFn if inputFn == nil { inputFn = defaultGenPlannerInputFn } - planParser := cfg.NewPlan - if planParser == nil { - planParser = defaultNewPlan + factory := cfg.Factory + if factory == nil { + factory = defaultPlanFactory } return &planner{ + cfg: cfg, toolCall: toolCall, chatModel: chatModel, genInputFn: inputFn, - newPlan: planParser, + factory: factory, }, nil } -// ExecutionContext is the input information for the executor and the planner. +// ExecutionContext is the input information for the executor and re-planner. type ExecutionContext struct { UserInput []adk.Message Plan Plan ExecutedSteps []ExecutedStep } -// GenModelInputFn is a function that generates the input messages for the executor and the planner. -type GenModelInputFn func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) +// GenExecutorModelInputFn is a function that generates the input messages for the executor. +type GenExecutorModelInputFn func(ctx context.Context, in *ExecutionContext, cfg *ExecutorConfig) ([]adk.Message, error) // ExecutorConfig provides configuration options for creating an executor agent. type ExecutorConfig struct { @@ -482,9 +482,9 @@ type ExecutorConfig struct { // Optional. Defaults to 20. MaxIterations int - // GenInputFn generates the input messages for the Executor. - // Optional. If not provided, defaultGenExecutorInputFn will be used. - GenInputFn GenModelInputFn + // GenInputFn generates input messages for the executor. + // Optional. Defaults to using ExecutorPrompt as the template to render model input messages. + GenInputFn GenExecutorModelInputFn } type ExecutedStep struct { @@ -525,7 +525,7 @@ func NewExecutor(ctx context.Context, cfg *ExecutorConfig) (adk.Agent, error) { ExecutedSteps: executedSteps_, } - msgs, err := genInputFn(ctx, in) + msgs, err := genInputFn(ctx, in, cfg) if err != nil { return nil, err } @@ -549,7 +549,7 @@ func NewExecutor(ctx context.Context, cfg *ExecutorConfig) (adk.Agent, error) { return agent, nil } -func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) { +func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext, _ *ExecutorConfig) ([]adk.Message, error) { planContent, err := in.Plan.MarshalJSON() if err != nil { @@ -570,14 +570,18 @@ func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext) ([]adk } type replanner struct { + cfg *ReplannerConfig chatModel model.ToolCallingChatModel planTool *schema.ToolInfo respondTool *schema.ToolInfo - genInputFn GenModelInputFn - newPlan NewPlan + genInputFn GenReplannerModelInputFn + factory PlanFactory } +// GenReplannerModelInputFn is a function that generates the input messages for the re-planner. +type GenReplannerModelInputFn func(ctx context.Context, in *ExecutionContext, conf *ReplannerConfig) ([]adk.Message, error) + type ReplannerConfig struct { // ChatModel is the model that supports tool calling capabilities. // It will be configured with PlanTool and RespondTool to generate updated plans or responses. @@ -591,14 +595,13 @@ type ReplannerConfig struct { // Optional. If not provided, the default RespondToolInfo will be used. RespondTool *schema.ToolInfo - // GenInputFn generates the input messages for the Replanner. - // Optional. If not provided, buildGenReplannerInputFn will be used. - GenInputFn GenModelInputFn + // GenInputFn generates input messages for the re-planner. + // Optional. Defaults to using ReplannerPrompt as the template to render model input messages. + GenInputFn GenReplannerModelInputFn - // NewPlan creates a new Plan instance. - // The returned Plan will be used to unmarshal the model-generated JSON output from PlanTool. - // Optional. If not provided, defaultNewPlan will be used. - NewPlan NewPlan + // Factory creates Plan instances for JSON unmarshaling. + // Optional. Defaults to creating DefaultPlan instances. + Factory PlanFactory } // formatInput formats the input messages into a string. @@ -667,11 +670,8 @@ func (r *replanner) genInput(ctx context.Context) ([]adk.Message, error) { Plan: plan_, ExecutedSteps: executedSteps_, } - genInputFn := r.genInputFn - if genInputFn == nil { - genInputFn = buildGenReplannerInputFn(r.planTool.Name, r.respondTool.Name) - } - msgs, err := genInputFn(ctx, in) + + msgs, err := r.genInputFn(ctx, in, r.cfg) if err != nil { return nil, err } @@ -787,7 +787,7 @@ func (r *replanner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.Age return } - plan_ := r.newPlan(ctx) + plan_ := r.factory(ctx) err = plan_.UnmarshalJSON([]byte(planMsg.ToolCalls[0].Function.Arguments)) if err != nil { err = fmt.Errorf("unmarshal plan error: %w", err) @@ -801,25 +801,34 @@ func (r *replanner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.Age return iterator } -func buildGenReplannerInputFn(planToolName, respondToolName string) GenModelInputFn { - return func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) { - planContent, err := in.Plan.MarshalJSON() - if err != nil { - return nil, err - } - msgs, err := ReplannerPrompt.Format(ctx, map[string]any{ - "plan": string(planContent), - "input": formatInput(in.UserInput), - "executed_steps": formatExecutedSteps(in.ExecutedSteps), - "plan_tool": planToolName, - "respond_tool": respondToolName, - }) - if err != nil { - return nil, err - } +func defaultGenReplannerInputFn(ctx context.Context, in *ExecutionContext, cfg *ReplannerConfig) ([]adk.Message, error) { + planContent, err := in.Plan.MarshalJSON() + if err != nil { + return nil, err + } - return msgs, nil + planToolName := PlanToolInfo.Name + if cfg.PlanTool != nil { + planToolName = cfg.PlanTool.Name } + + respondToolName := RespondToolInfo.Name + if cfg.RespondTool != nil { + respondToolName = cfg.RespondTool.Name + } + + msgs, err := ReplannerPrompt.Format(ctx, map[string]any{ + "plan": string(planContent), + "input": formatInput(in.UserInput), + "executed_steps": formatExecutedSteps(in.ExecutedSteps), + "plan_tool": planToolName, + "respond_tool": respondToolName, + }) + if err != nil { + return nil, err + } + + return msgs, nil } func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) { @@ -838,17 +847,23 @@ func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) { return nil, err } - planParser := cfg.NewPlan - if planParser == nil { - planParser = defaultNewPlan + factory := cfg.Factory + if factory == nil { + factory = defaultPlanFactory + } + + genInputFn := cfg.GenInputFn + if genInputFn == nil { + genInputFn = defaultGenReplannerInputFn } return &replanner{ + cfg: cfg, chatModel: chatModel, planTool: planTool, respondTool: respondTool, - genInputFn: cfg.GenInputFn, - newPlan: planParser, + genInputFn: genInputFn, + factory: factory, }, nil } diff --git a/adk/prebuilt/planexecute/plan_execute_test.go b/adk/prebuilt/planexecute/plan_execute_test.go index 6b30d65f..1b67845c 100644 --- a/adk/prebuilt/planexecute/plan_execute_test.go +++ b/adk/prebuilt/planexecute/plan_execute_test.go @@ -127,10 +127,10 @@ func TestPlannerRunWithFormattedOutput(t *testing.T) { event, ok = iterator.Next() assert.False(t, ok) - plan := defaultNewPlan(ctx) + plan := defaultPlanFactory(ctx) err = plan.UnmarshalJSON([]byte(msg.Content)) assert.NoError(t, err) - plan_ := plan.(*defaultPlan) + plan_ := plan.(*DefaultPlan) assert.Equal(t, 3, len(plan_.Steps)) assert.Equal(t, "Step 1", plan_.Steps[0]) assert.Equal(t, "Step 2", plan_.Steps[1]) @@ -194,10 +194,10 @@ func TestPlannerRunWithToolCalling(t *testing.T) { _, ok = iterator.Next() assert.False(t, ok) - plan := defaultNewPlan(ctx) + plan := defaultPlanFactory(ctx) err = plan.UnmarshalJSON([]byte(msg.Content)) assert.NoError(t, err) - plan_ := plan.(*defaultPlan) + plan_ := plan.(*DefaultPlan) assert.NoError(t, err) assert.Equal(t, 3, len(plan_.Steps)) assert.Equal(t, "Step 1", plan_.Steps[0]) @@ -244,7 +244,7 @@ func TestExecutorRun(t *testing.T) { mockToolCallingModel := mockModel.NewMockToolCallingChatModel(ctrl) // Store a plan in the session - plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} + plan := &DefaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} adk.AddSessionValue(ctx, PlanSessionKey, plan) // Set up expectations for the mock model @@ -388,7 +388,7 @@ func TestReplannerRunWithPlan(t *testing.T) { assert.NoError(t, err) // Store necessary values in the session - plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} + plan := &DefaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} rp, err = agentOutputSessionKVs(ctx, rp) assert.NoError(t, err) @@ -416,7 +416,7 @@ func TestReplannerRunWithPlan(t *testing.T) { // Verify the updated plan was stored in the session planValue, ok := kvs[PlanSessionKey] assert.True(t, ok) - updatedPlan, ok := planValue.(*defaultPlan) + updatedPlan, ok := planValue.(*DefaultPlan) assert.True(t, ok) assert.Equal(t, 2, len(updatedPlan.Steps)) assert.Equal(t, "Updated Step 1", updatedPlan.Steps[0]) @@ -487,7 +487,7 @@ func TestReplannerRunWithRespond(t *testing.T) { assert.NoError(t, err) // Store necessary values in the session - plan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} + plan := &DefaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} // Run the replanner runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: rp}) @@ -575,9 +575,9 @@ func TestPlanExecuteAgentWithReplan(t *testing.T) { mockReplanner.EXPECT().Description(gomock.Any()).Return("a replanner agent").AnyTimes() // Create a plan - originalPlan := &defaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} + originalPlan := &DefaultPlan{Steps: []string{"Step 1", "Step 2", "Step 3"}} // Create an updated plan with fewer steps (after replanning) - updatedPlan := &defaultPlan{Steps: []string{"Updated Step 2", "Updated Step 3"}} + updatedPlan := &DefaultPlan{Steps: []string{"Updated Step 2", "Updated Step 3"}} // Create execute result originalExecuteResult := "Execution result for Step 1" updatedExecuteResult := "Execution result for Updated Step 2" @@ -613,7 +613,7 @@ func TestPlanExecuteAgentWithReplan(t *testing.T) { iterator, generator := adk.NewAsyncIteratorPair[*adk.AgentEvent]() plan, _ := adk.GetSessionValue(ctx, PlanSessionKey) - currentPlan := plan.(*defaultPlan) + currentPlan := plan.(*DefaultPlan) var msg adk.Message // Check if this is the first replanning (original plan has 3 steps) if len(currentPlan.Steps) == 3 { @@ -639,7 +639,7 @@ func TestPlanExecuteAgentWithReplan(t *testing.T) { // First call: Update the plan // Get the current plan from the session plan, _ := adk.GetSessionValue(ctx, PlanSessionKey) - currentPlan := plan.(*defaultPlan) + currentPlan := plan.(*DefaultPlan) // Check if this is the first replanning (original plan has 3 steps) if len(currentPlan.Steps) == 3 {