12
12
class GuidedRewriteAgent (RewriteAgent ):
13
13
name : str = "guided_agent"
14
14
15
+ def __init__ (self , * args , ** kwargs ):
16
+ super ().__init__ (* args , ** kwargs )
17
+ self .logger .set_no_live ()
18
+
15
19
def try_rewrite (self , task_name ):
16
20
# make a copy of the env for the llm
17
21
from ipdb import set_trace
@@ -38,54 +42,106 @@ def try_rewrite(self, task_name):
38
42
return info .done
39
43
40
44
def run (self , task_name = None , debug = False ):
41
- self .logger .level = logging .DEBUG
42
- self .llm .logger = DebugGymLogger (
43
- name = "LLM" , level = logging .ERROR , log_dir = self .logger .log_file .parent
44
- )
45
- self .human = LLM .instantiate (llm_name = "human" , logger = self .logger )
46
-
47
- self .history .reset ()
48
- info = self .env .reset (options = {"task_name" : task_name })
49
- # initial state does not have prompt and response
50
- self .history .step (info , None )
51
-
52
- if info .done is True :
53
- # msg = "Environment started with entrypoint passing without errors."
54
- return True
55
-
56
- highscore = info .score
57
-
58
- for step in self .logger .tqdm (range (self .config ["max_steps" ])):
59
- highscore = max (highscore , info .score )
60
- self .logger .info (
61
- f"Score: { info .score } /{ info .max_score } ({ info .score / info .max_score :.1%} ) [Best: { highscore } ]"
45
+ step = 0
46
+ max_steps = self .config ["max_steps" ]
47
+ try :
48
+ self .logger .level = logging .DEBUG
49
+ self .logger .icon = "👤"
50
+ self .llm .logger = DebugGymLogger (
51
+ name = "LLM" , level = logging .ERROR , log_dir = self .logger .log_file .parent
62
52
)
53
+ self .llm .logger .icon = "🤖"
54
+ self .human = LLM .instantiate (llm_name = "human" , logger = self .logger )
55
+
56
+ self .history .reset ()
57
+ info = self .env .reset (options = {"task_name" : task_name })
58
+ # initial state does not have prompt and response
59
+ self .history .step (info , None )
60
+
61
+ if info .done is True :
62
+ # msg = "Environment started with entrypoint passing without errors."self.logger.report_progress(
63
+ self .logger .report_progress (
64
+ problem_id = task_name ,
65
+ step = 1 ,
66
+ total_steps = 1 ,
67
+ score = info .score ,
68
+ max_score = info .max_score ,
69
+ status = "resolved" ,
70
+ )
71
+ return True
63
72
64
- llm_done = self .try_rewrite (task_name )
65
- if llm_done :
66
- msg = f"*** The rewrite-only agent with { self .llm .model_name } managed to solve the task with the current context. ***"
67
- self .logger .info (colored (msg , "green" ))
68
- break
69
- else :
70
- msg = f"*** The rewrite-only agent with { self .llm .model_name } failed to solve the task with the current context. ***"
71
- self .logger .info (colored (msg , "red" ))
72
-
73
- # If the LLM did not manage to solve the task, we continue with the guided approach.
74
- prompt = self .build_prompt (info )
75
- human_response = self .human (prompt , info .tools )
76
-
77
- if debug :
78
- breakpoint ()
79
-
80
- # step the environment with the human response
81
- info = self .env .step (human_response .tool )
82
- # log the human response
83
- self .history .step (info , human_response )
73
+ highscore = info .score
84
74
85
- if info .done :
75
+ for step in range (max_steps ):
76
+ self .logger .info (f"\n { '=' * 20 } STEP { step + 1 } { '=' * 20 } \n " )
77
+ highscore = max (highscore , info .score )
86
78
self .logger .info (
87
- "You managed to provide the patch that solves the task before the LLM. Congrats! "
79
+ f"Step: { step } | Score: { info . score } / { info . max_score } ( { info . score / info . max_score :.1% } ) [Best: { highscore } ] "
88
80
)
89
- break
90
81
91
- return info .done
82
+ llm_done = self .try_rewrite (task_name )
83
+ if llm_done :
84
+ msg = f"[green]*** The rewrite-only agent with { self .llm .model_name } managed to solve the task with the current context. ***[/green]"
85
+ self .logger .info (msg )
86
+ break
87
+ else :
88
+ msg = f"[red]*** The rewrite-only agent with { self .llm .model_name } failed to solve the task with the current context. ***[/red]"
89
+ self .logger .info (msg )
90
+
91
+ # If the LLM did not manage to solve the task, we continue with the guided approach.
92
+ prompt = self .build_prompt (info )
93
+ human_response = self .human (prompt , info .tools )
94
+
95
+ if debug :
96
+ breakpoint ()
97
+
98
+ # step the environment with the human response
99
+ info = self .env .step (human_response .tool )
100
+ # log the human response
101
+ self .history .step (info , human_response )
102
+
103
+ if info .done :
104
+ self .logger .info (
105
+ "You managed to provide the patch that solves the task before the LLM. Congrats!"
106
+ )
107
+ # early stop, set current step and total steps to be the same
108
+ self .logger .report_progress (
109
+ problem_id = task_name ,
110
+ step = step + 1 ,
111
+ total_steps = step + 1 ,
112
+ score = info .score ,
113
+ max_score = info .max_score ,
114
+ status = "resolved" if info .done else "unresolved" ,
115
+ )
116
+ break
117
+ # keep progress bar running until max_steps is reached
118
+ self .logger .report_progress (
119
+ problem_id = task_name ,
120
+ step = step + 1 ,
121
+ total_steps = max_steps + 1 ,
122
+ score = info .score ,
123
+ max_score = info .max_score ,
124
+ status = "running" ,
125
+ )
126
+ # max_steps was reached, task was either resolved or unresolved
127
+ self .logger .report_progress (
128
+ problem_id = task_name ,
129
+ step = step + 1 ,
130
+ total_steps = step + 1 ,
131
+ score = info .score ,
132
+ max_score = info .max_score ,
133
+ status = "resolved" if info .done else "unresolved" ,
134
+ )
135
+
136
+ return info .done
137
+ except Exception :
138
+ # report any error that happens during the run
139
+ self .logger .report_progress (
140
+ problem_id = task_name ,
141
+ step = step + 1 ,
142
+ total_steps = step + 1 ,
143
+ score = info .score if info else 0 ,
144
+ max_score = info .max_score if info else 1 ,
145
+ status = "error" ,
146
+ )
147
+ raise
0 commit comments