Skip to content

Commit f3f8490

Browse files
committed
WIP
1 parent 5353750 commit f3f8490

File tree

2 files changed

+106
-46
lines changed

2 files changed

+106
-46
lines changed

debug_gym/agents/guided_agent.py

Lines changed: 101 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
class GuidedRewriteAgent(RewriteAgent):
1313
name: str = "guided_agent"
1414

15+
def __init__(self, *args, **kwargs):
16+
super().__init__(*args, **kwargs)
17+
self.logger.set_no_live()
18+
1519
def try_rewrite(self, task_name):
1620
# make a copy of the env for the llm
1721
from ipdb import set_trace
@@ -38,54 +42,106 @@ def try_rewrite(self, task_name):
3842
return info.done
3943

4044
def run(self, task_name=None, debug=False):
41-
self.logger.level = logging.DEBUG
42-
self.llm.logger = DebugGymLogger(
43-
name="LLM", level=logging.ERROR, log_dir=self.logger.log_file.parent
44-
)
45-
self.human = LLM.instantiate(llm_name="human", logger=self.logger)
46-
47-
self.history.reset()
48-
info = self.env.reset(options={"task_name": task_name})
49-
# initial state does not have prompt and response
50-
self.history.step(info, None)
51-
52-
if info.done is True:
53-
# msg = "Environment started with entrypoint passing without errors."
54-
return True
55-
56-
highscore = info.score
57-
58-
for step in self.logger.tqdm(range(self.config["max_steps"])):
59-
highscore = max(highscore, info.score)
60-
self.logger.info(
61-
f"Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
45+
step = 0
46+
max_steps = self.config["max_steps"]
47+
try:
48+
self.logger.level = logging.DEBUG
49+
self.logger.icon = "👤"
50+
self.llm.logger = DebugGymLogger(
51+
name="LLM", level=logging.ERROR, log_dir=self.logger.log_file.parent
6252
)
53+
self.llm.logger.icon = "🤖"
54+
self.human = LLM.instantiate(llm_name="human", logger=self.logger)
55+
56+
self.history.reset()
57+
info = self.env.reset(options={"task_name": task_name})
58+
# initial state does not have prompt and response
59+
self.history.step(info, None)
60+
61+
if info.done is True:
62+
# msg = "Environment started with entrypoint passing without errors."self.logger.report_progress(
63+
self.logger.report_progress(
64+
problem_id=task_name,
65+
step=1,
66+
total_steps=1,
67+
score=info.score,
68+
max_score=info.max_score,
69+
status="resolved",
70+
)
71+
return True
6372

64-
llm_done = self.try_rewrite(task_name)
65-
if llm_done:
66-
msg = f"*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***"
67-
self.logger.info(colored(msg, "green"))
68-
break
69-
else:
70-
msg = f"*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***"
71-
self.logger.info(colored(msg, "red"))
72-
73-
# If the LLM did not manage to solve the task, we continue with the guided approach.
74-
prompt = self.build_prompt(info)
75-
human_response = self.human(prompt, info.tools)
76-
77-
if debug:
78-
breakpoint()
79-
80-
# step the environment with the human response
81-
info = self.env.step(human_response.tool)
82-
# log the human response
83-
self.history.step(info, human_response)
73+
highscore = info.score
8474

85-
if info.done:
75+
for step in range(max_steps):
76+
self.logger.info(f"\n{'='*20} STEP {step+1} {'='*20}\n")
77+
highscore = max(highscore, info.score)
8678
self.logger.info(
87-
"You managed to provide the patch that solves the task before the LLM. Congrats!"
79+
f"Step: {step} | Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
8880
)
89-
break
9081

91-
return info.done
82+
llm_done = self.try_rewrite(task_name)
83+
if llm_done:
84+
msg = f"[green]*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***[/green]"
85+
self.logger.info(msg)
86+
break
87+
else:
88+
msg = f"[red]*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***[/red]"
89+
self.logger.info(msg)
90+
91+
# If the LLM did not manage to solve the task, we continue with the guided approach.
92+
prompt = self.build_prompt(info)
93+
human_response = self.human(prompt, info.tools)
94+
95+
if debug:
96+
breakpoint()
97+
98+
# step the environment with the human response
99+
info = self.env.step(human_response.tool)
100+
# log the human response
101+
self.history.step(info, human_response)
102+
103+
if info.done:
104+
self.logger.info(
105+
"You managed to provide the patch that solves the task before the LLM. Congrats!"
106+
)
107+
# early stop, set current step and total steps to be the same
108+
self.logger.report_progress(
109+
problem_id=task_name,
110+
step=step + 1,
111+
total_steps=step + 1,
112+
score=info.score,
113+
max_score=info.max_score,
114+
status="resolved" if info.done else "unresolved",
115+
)
116+
break
117+
# keep progress bar running until max_steps is reached
118+
self.logger.report_progress(
119+
problem_id=task_name,
120+
step=step + 1,
121+
total_steps=max_steps + 1,
122+
score=info.score,
123+
max_score=info.max_score,
124+
status="running",
125+
)
126+
# max_steps was reached, task was either resolved or unresolved
127+
self.logger.report_progress(
128+
problem_id=task_name,
129+
step=step + 1,
130+
total_steps=step + 1,
131+
score=info.score,
132+
max_score=info.max_score,
133+
status="resolved" if info.done else "unresolved",
134+
)
135+
136+
return info.done
137+
except Exception:
138+
# report any error that happens during the run
139+
self.logger.report_progress(
140+
problem_id=task_name,
141+
step=step + 1,
142+
total_steps=step + 1,
143+
score=info.score if info else 0,
144+
max_score=info.max_score if info else 1,
145+
status="error",
146+
)
147+
raise

debug_gym/logger.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def __init__(
420420
log_dir: str | None = None,
421421
level: str | int = logging.INFO,
422422
mode: str = "a",
423+
icon: str = "🐸",
423424
):
424425
super().__init__(name)
425426
# If var env "DEBUG_GYM_DEBUG" is set, turn on debug mode
@@ -428,6 +429,7 @@ def __init__(
428429

429430
# Prevent the log messages from being propagated to the root logger
430431
self.propagate = False
432+
self.icon = icon # Icon to use in log messages
431433

432434
self.setLevel(level) # Set logger level, might be overridden by file handler
433435
self.log_file = None # File handler for logging to a file
@@ -451,7 +453,9 @@ def _initialize_main_logger(self, level):
451453
rich_tracebacks=True,
452454
markup=True,
453455
)
454-
rich_handler.setFormatter(logging.Formatter("🐸 [%(name)-12s]: %(message)s"))
456+
rich_handler.setFormatter(
457+
logging.Formatter(f"{self.icon} [%(name)-12s]: %(message)s")
458+
)
455459
rich_handler.setLevel(level)
456460
self.addHandler(rich_handler)
457461

0 commit comments

Comments
 (0)