Skip to content

Commit 16cb97c

Browse files
committed
feat: adjust GenInputFn
feat: change NewPlan to PlanFactory feat: adjust comments
1 parent fba3e63 commit 16cb97c

File tree

2 files changed

+99
-84
lines changed

2 files changed

+99
-84
lines changed

adk/prebuilt/planexecute/plan_execute.go

Lines changed: 87 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ type Plan interface {
4848
json.Unmarshaler
4949
}
5050

51-
// NewPlan is a function type that creates a new Plan instance.
52-
type NewPlan func(ctx context.Context) Plan
51+
// PlanFactory is a function type that creates a new Plan instance.
52+
type PlanFactory func(ctx context.Context) Plan
5353

54-
// defaultPlan is the default implementation of the Plan interface.
54+
// DefaultPlan is the default implementation of the Plan interface.
5555
//
5656
// JSON Schema:
5757
//
@@ -68,27 +68,27 @@ type NewPlan func(ctx context.Context) Plan
6868
// },
6969
// "required": ["steps"]
7070
// }
71-
type defaultPlan struct {
71+
type DefaultPlan struct {
7272
// Steps contains the ordered list of actions to be taken.
7373
// Each step should be clear, actionable, and arranged in a logical sequence.
7474
Steps []string `json:"steps"`
7575
}
7676

7777
// FirstStep returns the first step in the plan or an empty string if no steps exist.
78-
func (p *defaultPlan) FirstStep() string {
78+
func (p *DefaultPlan) FirstStep() string {
7979
if len(p.Steps) == 0 {
8080
return ""
8181
}
8282
return p.Steps[0]
8383
}
8484

85-
func (p *defaultPlan) MarshalJSON() ([]byte, error) {
86-
type planTyp defaultPlan
85+
func (p *DefaultPlan) MarshalJSON() ([]byte, error) {
86+
type planTyp DefaultPlan
8787
return sonic.Marshal((*planTyp)(p))
8888
}
8989

90-
func (p *defaultPlan) UnmarshalJSON(bytes []byte) error {
91-
type planTyp defaultPlan
90+
func (p *DefaultPlan) UnmarshalJSON(bytes []byte) error {
91+
type planTyp DefaultPlan
9292
return sonic.Unmarshal(bytes, (*planTyp)(p))
9393
}
9494

@@ -265,24 +265,23 @@ type PlannerConfig struct {
265265
// Optional. If not provided, PlanToolInfo will be used as the default.
266266
ToolInfo *schema.ToolInfo
267267

268-
// GenInputFn is a function that generates the input messages for the planner.
269-
// Optional. If not provided, defaultGenPlannerInputFn will be used.
268+
// GenInputFn generates input messages for the planner.
269+
// Optional. Defaults to using PlannerPrompt as the template to render model input messages.
270270
GenInputFn GenPlannerModelInputFn
271271

272-
// NewPlan creates a new Plan instance for JSON.
273-
// The returned Plan will be used to unmarshal the model-generated JSON output.
274-
// Optional. If not provided, defaultNewPlan will be used.
275-
NewPlan NewPlan
272+
// Factory creates Plan instances for JSON unmarshaling.
273+
// Optional. Defaults to creating DefaultPlan instances.
274+
Factory PlanFactory
276275
}
277276

278277
// GenPlannerModelInputFn is a function type that generates input messages for the planner.
279-
type GenPlannerModelInputFn func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error)
278+
type GenPlannerModelInputFn func(ctx context.Context, userInput []adk.Message, cfg *PlannerConfig) ([]adk.Message, error)
280279

281-
func defaultNewPlan(ctx context.Context) Plan {
282-
return &defaultPlan{}
280+
func defaultPlanFactory(ctx context.Context) Plan {
281+
return &DefaultPlan{}
283282
}
284283

285-
func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) {
284+
func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message, _ *PlannerConfig) ([]adk.Message, error) {
286285
msgs, err := PlannerPrompt.Format(ctx, map[string]any{
287286
"input": userInput,
288287
})
@@ -293,10 +292,11 @@ func defaultGenPlannerInputFn(ctx context.Context, userInput []adk.Message) ([]a
293292
}
294293

295294
type planner struct {
295+
cfg *PlannerConfig
296296
toolCall bool
297297
chatModel model.BaseChatModel
298298
genInputFn GenPlannerModelInputFn
299-
newPlan NewPlan
299+
factory PlanFactory
300300
}
301301

302302
func (p *planner) Name(_ context.Context) string {
@@ -333,7 +333,7 @@ func (p *planner) Run(ctx context.Context, input *adk.AgentInput,
333333
generator.Close()
334334
}()
335335

336-
msgs, err := p.genInputFn(ctx, input.Messages)
336+
msgs, err := p.genInputFn(ctx, input.Messages, p.cfg)
337337
if err != nil {
338338
generator.Send(&adk.AgentEvent{Err: err})
339339
return
@@ -401,7 +401,7 @@ func (p *planner) Run(ctx context.Context, input *adk.AgentInput,
401401
} else {
402402
planJSON = msg.Content
403403
}
404-
plan := p.newPlan(ctx)
404+
plan := p.factory(ctx)
405405
err = plan.UnmarshalJSON([]byte(planJSON))
406406
if err != nil {
407407
err = fmt.Errorf("unmarshal plan error: %w", err)
@@ -440,34 +440,34 @@ func NewPlanner(_ context.Context, cfg *PlannerConfig) (adk.Agent, error) {
440440
return nil, err
441441
}
442442
}
443-
444443
inputFn := cfg.GenInputFn
445444
if inputFn == nil {
446445
inputFn = defaultGenPlannerInputFn
447446
}
448447

449-
planParser := cfg.NewPlan
450-
if planParser == nil {
451-
planParser = defaultNewPlan
448+
factory := cfg.Factory
449+
if factory == nil {
450+
factory = defaultPlanFactory
452451
}
453452

454453
return &planner{
454+
cfg: cfg,
455455
toolCall: toolCall,
456456
chatModel: chatModel,
457457
genInputFn: inputFn,
458-
newPlan: planParser,
458+
factory: factory,
459459
}, nil
460460
}
461461

462-
// ExecutionContext is the input information for the executor and the planner.
462+
// ExecutionContext is the input information for the executor and re-planner.
463463
type ExecutionContext struct {
464464
UserInput []adk.Message
465465
Plan Plan
466466
ExecutedSteps []ExecutedStep
467467
}
468468

469-
// GenModelInputFn is a function that generates the input messages for the executor and the planner.
470-
type GenModelInputFn func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error)
469+
// GenExecutorModelInputFn is a function that generates the input messages for the executor.
470+
type GenExecutorModelInputFn func(ctx context.Context, in *ExecutionContext, cfg *ExecutorConfig) ([]adk.Message, error)
471471

472472
// ExecutorConfig provides configuration options for creating an executor agent.
473473
type ExecutorConfig struct {
@@ -482,9 +482,9 @@ type ExecutorConfig struct {
482482
// Optional. Defaults to 20.
483483
MaxIterations int
484484

485-
// GenInputFn generates the input messages for the Executor.
486-
// Optional. If not provided, defaultGenExecutorInputFn will be used.
487-
GenInputFn GenModelInputFn
485+
// GenInputFn generates input messages for the executor.
486+
// Optional. Defaults to using ExecutorPrompt as the template to render model input messages.
487+
GenInputFn GenExecutorModelInputFn
488488
}
489489

490490
type ExecutedStep struct {
@@ -525,7 +525,7 @@ func NewExecutor(ctx context.Context, cfg *ExecutorConfig) (adk.Agent, error) {
525525
ExecutedSteps: executedSteps_,
526526
}
527527

528-
msgs, err := genInputFn(ctx, in)
528+
msgs, err := genInputFn(ctx, in, cfg)
529529
if err != nil {
530530
return nil, err
531531
}
@@ -549,7 +549,7 @@ func NewExecutor(ctx context.Context, cfg *ExecutorConfig) (adk.Agent, error) {
549549
return agent, nil
550550
}
551551

552-
func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) {
552+
func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext, _ *ExecutorConfig) ([]adk.Message, error) {
553553

554554
planContent, err := in.Plan.MarshalJSON()
555555
if err != nil {
@@ -570,14 +570,18 @@ func defaultGenExecutorInputFn(ctx context.Context, in *ExecutionContext) ([]adk
570570
}
571571

572572
type replanner struct {
573+
cfg *ReplannerConfig
573574
chatModel model.ToolCallingChatModel
574575
planTool *schema.ToolInfo
575576
respondTool *schema.ToolInfo
576577

577-
genInputFn GenModelInputFn
578-
newPlan NewPlan
578+
genInputFn GenReplannerModelInputFn
579+
factory PlanFactory
579580
}
580581

582+
// GenReplannerModelInputFn is a function that generates the input messages for the re-planner.
583+
type GenReplannerModelInputFn func(ctx context.Context, in *ExecutionContext, conf *ReplannerConfig) ([]adk.Message, error)
584+
581585
type ReplannerConfig struct {
582586
// ChatModel is the model that supports tool calling capabilities.
583587
// It will be configured with PlanTool and RespondTool to generate updated plans or responses.
@@ -591,14 +595,13 @@ type ReplannerConfig struct {
591595
// Optional. If not provided, the default RespondToolInfo will be used.
592596
RespondTool *schema.ToolInfo
593597

594-
// GenInputFn generates the input messages for the Replanner.
595-
// Optional. If not provided, buildGenReplannerInputFn will be used.
596-
GenInputFn GenModelInputFn
598+
// GenInputFn generates input messages for the re-planner.
599+
// Optional. Defaults to using ReplannerPrompt as the template to render model input messages.
600+
GenInputFn GenReplannerModelInputFn
597601

598-
// NewPlan creates a new Plan instance.
599-
// The returned Plan will be used to unmarshal the model-generated JSON output from PlanTool.
600-
// Optional. If not provided, defaultNewPlan will be used.
601-
NewPlan NewPlan
602+
// Factory creates Plan instances for JSON unmarshaling.
603+
// Optional. Defaults to creating DefaultPlan instances.
604+
Factory PlanFactory
602605
}
603606

604607
// formatInput formats the input messages into a string.
@@ -667,11 +670,8 @@ func (r *replanner) genInput(ctx context.Context) ([]adk.Message, error) {
667670
Plan: plan_,
668671
ExecutedSteps: executedSteps_,
669672
}
670-
genInputFn := r.genInputFn
671-
if genInputFn == nil {
672-
genInputFn = buildGenReplannerInputFn(r.planTool.Name, r.respondTool.Name)
673-
}
674-
msgs, err := genInputFn(ctx, in)
673+
674+
msgs, err := r.genInputFn(ctx, in, r.cfg)
675675
if err != nil {
676676
return nil, err
677677
}
@@ -787,7 +787,7 @@ func (r *replanner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.Age
787787
return
788788
}
789789

790-
plan_ := r.newPlan(ctx)
790+
plan_ := r.factory(ctx)
791791
err = plan_.UnmarshalJSON([]byte(planMsg.ToolCalls[0].Function.Arguments))
792792
if err != nil {
793793
err = fmt.Errorf("unmarshal plan error: %w", err)
@@ -801,25 +801,34 @@ func (r *replanner) Run(ctx context.Context, input *adk.AgentInput, _ ...adk.Age
801801
return iterator
802802
}
803803

804-
func buildGenReplannerInputFn(planToolName, respondToolName string) GenModelInputFn {
805-
return func(ctx context.Context, in *ExecutionContext) ([]adk.Message, error) {
806-
planContent, err := in.Plan.MarshalJSON()
807-
if err != nil {
808-
return nil, err
809-
}
810-
msgs, err := ReplannerPrompt.Format(ctx, map[string]any{
811-
"plan": string(planContent),
812-
"input": formatInput(in.UserInput),
813-
"executed_steps": formatExecutedSteps(in.ExecutedSteps),
814-
"plan_tool": planToolName,
815-
"respond_tool": respondToolName,
816-
})
817-
if err != nil {
818-
return nil, err
819-
}
804+
func defaultGenReplannerInputFn(ctx context.Context, in *ExecutionContext, cfg *ReplannerConfig) ([]adk.Message, error) {
805+
planContent, err := in.Plan.MarshalJSON()
806+
if err != nil {
807+
return nil, err
808+
}
820809

821-
return msgs, nil
810+
planToolName := PlanToolInfo.Name
811+
if cfg.PlanTool != nil {
812+
planToolName = cfg.PlanTool.Name
822813
}
814+
815+
respondToolName := RespondToolInfo.Name
816+
if cfg.RespondTool != nil {
817+
respondToolName = cfg.RespondTool.Name
818+
}
819+
820+
msgs, err := ReplannerPrompt.Format(ctx, map[string]any{
821+
"plan": string(planContent),
822+
"input": formatInput(in.UserInput),
823+
"executed_steps": formatExecutedSteps(in.ExecutedSteps),
824+
"plan_tool": planToolName,
825+
"respond_tool": respondToolName,
826+
})
827+
if err != nil {
828+
return nil, err
829+
}
830+
831+
return msgs, nil
823832
}
824833

825834
func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) {
@@ -838,17 +847,23 @@ func NewReplanner(_ context.Context, cfg *ReplannerConfig) (adk.Agent, error) {
838847
return nil, err
839848
}
840849

841-
planParser := cfg.NewPlan
842-
if planParser == nil {
843-
planParser = defaultNewPlan
850+
factory := cfg.Factory
851+
if factory == nil {
852+
factory = defaultPlanFactory
853+
}
854+
855+
genInputFn := cfg.GenInputFn
856+
if genInputFn == nil {
857+
genInputFn = defaultGenReplannerInputFn
844858
}
845859

846860
return &replanner{
861+
cfg: cfg,
847862
chatModel: chatModel,
848863
planTool: planTool,
849864
respondTool: respondTool,
850-
genInputFn: cfg.GenInputFn,
851-
newPlan: planParser,
865+
genInputFn: genInputFn,
866+
factory: factory,
852867
}, nil
853868
}
854869

0 commit comments

Comments
 (0)