Skip to content

Commit e5abb91

Browse files
committed
Move guided agent into its own file.
1 parent dd32232 commit e5abb91

File tree

7 files changed

+84
-82
lines changed

7 files changed

+84
-82
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,6 @@ logs/
180180

181181
/data
182182

183-
.vscode/
184-
185183
vscode/out
186184
vscode/node_modules
187185
vscode/package-lock.json

debug_gym/agents/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent
2+
from debug_gym.agents.guided_agent import GuidedRewriteAgent
23
from debug_gym.agents.rewrite_agent import RewriteAgent

debug_gym/agents/debug_agent.py

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from debug_gym.agents.base_agent import BaseAgent, register_agent
2-
from debug_gym.llms.base import LLM
32

43

54
@register_agent
@@ -63,79 +62,3 @@ def run(self, task_name=None, debug=False):
6362
break
6463

6564
return info.done
66-
67-
68-
@register_agent
69-
class DebugHumanInTheLoop(DebugAgent):
70-
name: str = "debug_human"
71-
72-
def run(self, task_name=None, debug=False):
73-
# instantiate the human in the loop
74-
self.human = LLM.instantiate(
75-
llm_name="human",
76-
llm_config_file_path=self.config.get("llm_config_file_path"),
77-
logger=self.logger,
78-
)
79-
80-
self.history.reset()
81-
info = self.env.reset(options={"task_name": task_name})
82-
# initial state does not have prompt and response
83-
self.history.step(info, None)
84-
85-
if info.done is True:
86-
# msg = "Environment started with entrypoint passing without errors."
87-
return True
88-
89-
highscore = info.score
90-
91-
for step in self.logger.tqdm(range(self.config["max_steps"])):
92-
highscore = max(highscore, info.score)
93-
self.logger.info(
94-
f"Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
95-
)
96-
97-
prompt = self.build_prompt(info)
98-
99-
human_response = self.human(prompt, info.tools)
100-
101-
if debug:
102-
breakpoint()
103-
104-
# make a copy of the env for the llm
105-
self.cloned_env = self.env.clone()
106-
# remove the pdb tool from the cloned env
107-
if self.cloned_env.has_tool("pdb"):
108-
self.cloned_env.remove_tool("pdb")
109-
llm_info = self.cloned_env.reset(options={"task_name": task_name})
110-
# replay the history up to the current step
111-
for step in self.history.get_all():
112-
if step.done:
113-
break
114-
llm_info = self.cloned_env.step(step.action)
115-
116-
# step the environment with the human response
117-
info = self.env.step(human_response.response)
118-
# log the human response
119-
self.history.step(info, human_response)
120-
121-
if info.done or info.rewrite_counter >= self.config["max_rewrite_steps"]:
122-
self.logger.info(
123-
f"Score (human): {info.score}/{info.max_score} ({info.score/info.max_score:.1%})"
124-
)
125-
break
126-
127-
# call the llm with the cloned environment
128-
prompt = self.build_prompt(llm_info)
129-
llm_response = self.llm(prompt, llm_info.tools)
130-
llm_info = self.cloned_env.step(llm_response.response)
131-
132-
if (
133-
llm_info.done
134-
or llm_info.rewrite_counter >= self.config["max_rewrite_steps"]
135-
):
136-
self.logger.info(
137-
f"Score (llm): {llm_info.score}/{llm_info.max_score} ({llm_info.score/llm_info.max_score:.1%})"
138-
)
139-
break
140-
141-
return info.done

debug_gym/agents/guided_agent.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import logging
2+
3+
from debug_gym.agents.base_agent import register_agent
4+
from debug_gym.agents.rewrite_agent import RewriteAgent
5+
from debug_gym.llms.base import LLM
6+
from debug_gym.logger import DebugGymLogger
7+
8+
9+
@register_agent
class GuidedRewriteAgent(RewriteAgent):
    """Agent where a human drives the episode while an LLM attempts rewrite-only fixes.

    At every step the LLM is first given a chance to solve the task on a
    cloned environment that exposes only the ``rewrite`` tool. If that attempt
    does not finish the task, the human is prompted for the next action, which
    is applied to the real environment and appended to the history.
    """

    name: str = "guided_agent"

    def try_rewrite(self, task_name):
        """Let the LLM attempt one rewrite on a clone of the environment.

        The clone is stripped of every tool except ``rewrite``, reset on
        ``task_name``, and brought up to date by replaying the recorded
        history. The LLM is then queried once and its response is stepped in
        the clone; the real environment is not stepped here.

        Args:
            task_name: Name of the task, forwarded to the cloned environment's
                ``reset`` through ``options``.

        Returns:
            The ``done`` flag of the cloned environment after the LLM's step.
        """
        # make a copy of the env for the llm
        cloned_env = self.env.clone()

        # Only keep the rewrite tool in the cloned env.
        # Iterate over a snapshot: remove_tool() is called while looping, and
        # mutating the collection being iterated could skip tools or raise.
        for tool in list(cloned_env.tools):
            if tool.name != "rewrite":
                cloned_env.remove_tool(tool.name)

        # Reset the cloned environment and replay the history.
        info = cloned_env.reset(options={"task_name": task_name})
        # replay the history up to the current step
        for step in self.history.get_all():
            # run() stops as soon as a step is done, so no recorded step
            # should be terminal when we get here.
            assert not step.done
            info = cloned_env.step(step.action)

        prompt = self.build_prompt(info)
        response = self.llm(prompt, info.tools)
        info = cloned_env.step(response.response)

        return info.done

    def run(self, task_name=None, debug=False):
        """Run the guided loop for at most ``config["max_steps"]`` iterations.

        Args:
            task_name: Name of the task, forwarded to the environment's
                ``reset`` through ``options``.
            debug: When true, drop into ``breakpoint()`` after each human
                response and before it is applied to the environment.

        Returns:
            ``True`` if the environment is already done right after reset;
            otherwise the ``done`` flag of the last info obtained from the
            real (human-driven) environment.
        """
        # Replace the LLM's logger with one at ERROR level.
        # NOTE(review): presumably to keep LLM output from cluttering the
        # human's interactive session - confirm.
        self.llm.logger = DebugGymLogger(name="LLM", level=logging.ERROR)
        self.human = LLM.instantiate(llm_name="human", logger=self.logger)

        self.history.reset()
        info = self.env.reset(options={"task_name": task_name})
        # initial state does not have prompt and response
        self.history.step(info, None)

        if info.done is True:
            # msg = "Environment started with entrypoint passing without errors."
            return True

        highscore = info.score

        for _ in self.logger.tqdm(range(self.config["max_steps"])):
            highscore = max(highscore, info.score)
            self.logger.info(
                f"Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
            )

            # Give the rewrite-only LLM a chance with the context gathered so far.
            llm_done = self.try_rewrite(task_name)
            if llm_done:
                self.logger.info(
                    f"*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***"
                )
                # NOTE(review): the value returned below is still the real
                # environment's done flag, not the LLM's success on the
                # clone - confirm this is intended.
                break

            # If the LLM did not manage to solve the task, we continue with the guided approach.
            prompt = self.build_prompt(info)
            human_response = self.human(prompt, info.tools)

            if debug:
                breakpoint()

            # step the environment with the human response
            info = self.env.step(human_response.response)
            # log the human response
            self.history.step(info, human_response)

            if info.done:
                self.logger.info(
                    "You managed to provide the patch that solves the task before the LLM. Congrats!"
                )
                break

        return info.done

debug_gym/agents/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def load_config():
159159
nargs="+",
160160
action="extend",
161161
metavar="my.setting=value",
162-
action="extend",
163162
default=[],
164163
help="override params of the config file,"
165164
" e.g. -p 'rewrite_only.random_seed=123'",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ dev = [
3030
"pytest-xdist",
3131
"pytest-timeout",
3232
"pytest-env",
33-
]
33+
]

scripts/config_mini_nightmare.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@ debug_5_agent:
4343
n_rewrites_before_pdb: 5
4444
tools: ["pdb", "view", "rewrite", "eval"]
4545

46-
debug_human:
46+
guided_agent:
4747
tools: ["pdb", "view", "rewrite", "eval"]

0 commit comments

Comments
 (0)