Skip to content

Commit bc9dcfd

Browse files
committed
Fix logic and use human_response.tool
1 parent d1e23ba commit bc9dcfd

File tree

8 files changed

+117
-94
lines changed

8 files changed

+117
-94
lines changed

debug_gym/agents/guided_agent.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
import logging
22

3+
from termcolor import colored
4+
35
from debug_gym.agents.base_agent import register_agent
46
from debug_gym.agents.rewrite_agent import RewriteAgent
57
from debug_gym.llms.base import LLM
@@ -12,6 +14,9 @@ class GuidedRewriteAgent(RewriteAgent):
1214

1315
def try_rewrite(self, task_name):
1416
# make a copy of the env for the llm
17+
from ipdb import set_trace
18+
19+
set_trace()
1520
cloned_env = self.env.clone()
1621

1722
# Only keep the rewrite tool in the cloned env
@@ -33,7 +38,10 @@ def try_rewrite(self, task_name):
3338
return info.done
3439

3540
def run(self, task_name=None, debug=False):
36-
self.llm.logger = DebugGymLogger(name="LLM", level=logging.ERROR)
41+
self.logger.level = logging.DEBUG
42+
self.llm.logger = DebugGymLogger(
43+
name="LLM", level=logging.ERROR, log_dir=self.logger.log_file.parent
44+
)
3745
self.human = LLM.instantiate(llm_name="human", logger=self.logger)
3846

3947
self.history.reset()
@@ -55,10 +63,12 @@ def run(self, task_name=None, debug=False):
5563

5664
llm_done = self.try_rewrite(task_name)
5765
if llm_done:
58-
self.logger.info(
59-
f"*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***"
60-
)
66+
msg = f"*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***"
67+
self.logger.info(colored(msg, "green"))
6168
break
69+
else:
70+
msg = f"*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***"
71+
self.logger.info(colored(msg, "red"))
6272

6373
# If the LLM did not manage to solve the task, we continue with the guided approach.
6474
prompt = self.build_prompt(info)
@@ -68,7 +78,7 @@ def run(self, task_name=None, debug=False):
6878
breakpoint()
6979

7080
# step the environment with the human response
71-
info = self.env.step(human_response.response)
81+
info = self.env.step(human_response.tool)
7282
# log the human response
7383
self.history.step(info, human_response)
7484

debug_gym/agents/solution_agent.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,4 @@
1-
import subprocess
2-
31
from debug_gym.agents.base_agent import BaseAgent, register_agent
4-
from debug_gym.gym.envs.swe_bench import SWEBenchEnv
5-
from debug_gym.gym.envs.swe_smith import SWESmithEnv
62
from debug_gym.gym.tools.tool import ToolCall
73

84

debug_gym/gym/envs/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
1+
import logging
2+
13
from debug_gym.gym.envs.aider import AiderBenchmarkEnv
24
from debug_gym.gym.envs.env import RepoEnv, TooledEnv
35
from debug_gym.gym.envs.mini_nightmare import MiniNightmareEnv
4-
from debug_gym.gym.envs.swe_bench import SWEBenchEnv
5-
from debug_gym.gym.envs.swe_smith import SWESmithEnv
66

77

88
def select_env(env_type: str = None) -> type[RepoEnv]:
@@ -12,8 +12,14 @@ def select_env(env_type: str = None) -> type[RepoEnv]:
1212
case "aider":
1313
return AiderBenchmarkEnv
1414
case "swebench":
15+
from debug_gym.gym.envs.swe_bench import SWEBenchEnv
16+
17+
logging.getLogger("httpx").setLevel(logging.WARNING)
1518
return SWEBenchEnv
1619
case "swesmith":
20+
from debug_gym.gym.envs.swe_smith import SWESmithEnv
21+
22+
logging.getLogger("httpx").setLevel(logging.WARNING)
1723
return SWESmithEnv
1824
case "mini_nightmare":
1925
return MiniNightmareEnv

debug_gym/llms/base.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import logging
21
import os
32
from abc import ABC, abstractmethod
43
from dataclasses import dataclass
@@ -20,9 +19,6 @@
2019
from debug_gym.llms.utils import print_messages
2120
from debug_gym.logger import DebugGymLogger
2221

23-
# Set logging level down to WARNING for endpoint queries.
24-
logging.getLogger("httpx").setLevel(logging.WARNING)
25-
2622

2723
def retry_on_rate_limit(
2824
func, is_rate_limit_error_func, multiplier=1, max_wait=40, max_attempts=100

debug_gym/llms/human.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23
import re
34
import sys
45
from typing import Any, Dict, List, Optional, Tuple
@@ -470,6 +471,12 @@ def __init__(
470471
if prompt_toolkit_available:
471472
self._history = InMemoryHistory()
472473

474+
# Warn if self.logger.level is not set at least to INFO, as this is a human interface.
475+
if self.logger.level > logging.INFO:
476+
self.logger.warning(
477+
"Human Mode should have logger level set to at least INFO (using -v) for better interaction."
478+
)
479+
473480
def tokenize(self, text: str) -> list[str]:
474481
"""Tokenizes a text by splitting it by spaces."""
475482
return text.split()

scripts/run.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,17 @@
11
import json
2+
import logging # Set logging level down to WARNING for endpoint queries.
23
import os
34
import uuid
45
from concurrent.futures import ThreadPoolExecutor, as_completed
5-
from itertools import groupby
66
from pathlib import Path
77

88
from termcolor import colored
99
from tqdm import tqdm
1010

1111
from debug_gym.agents.base_agent import AGENT_REGISTRY, create_agent
12+
13+
logging.getLogger("httpx").setLevel(logging.WARNING)
14+
1215
from debug_gym.agents.utils import load_config
1316
from debug_gym.gym.envs import select_env
1417
from debug_gym.gym.terminal import select_terminal
@@ -104,7 +107,7 @@ def main():
104107

105108
# Figure out which problems to solve.
106109
problems = config.get("problems", ["custom"])
107-
if type(problems) == str and "benchmark" in config:
110+
if type(problems) is str and "benchmark" in config:
108111
env = create_env(config, logger=logger)
109112
if problems == "all":
110113
problems = sorted(env.dataset.keys()) # all tasks

tests/agents/test_pdb_agent.py renamed to tests/agents/test_example_agent.py

Lines changed: 0 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -95,80 +95,3 @@ def test_run_debug_5_agent(agent_setup, build_env_info):
9595
env.tools = {"pdb": MagicMock()}
9696
result = agent.run(task_name="test_task", debug=False)
9797
assert result
98-
99-
100-
@patch.object(
101-
Human,
102-
"__call__",
103-
return_value=LLMResponse(
104-
"Prompt",
105-
'{"id": "pdb-267437", "name": "pdb", "arguments": {"command": "c"}}',
106-
TokenUsage(2, 4),
107-
),
108-
)
109-
def test_human_in_the_loop(human, agent_setup, build_env_info):
110-
agent, env, llm = next(agent_setup(DebugHumanInTheLoop))
111-
env.reset.return_value = build_env_info(
112-
done=False,
113-
score=0,
114-
max_score=10,
115-
rewrite_counter=0,
116-
instructions="Test instructions",
117-
dir_tree="Test dir tree",
118-
current_breakpoints="Test breakpoints",
119-
step_observation="Test last run obs",
120-
)
121-
env.step.return_value = build_env_info(
122-
done=False,
123-
score=10,
124-
max_score=10,
125-
rewrite_counter=0,
126-
instructions="Test instructions",
127-
dir_tree="Test dir tree",
128-
current_breakpoints="Test breakpoints",
129-
step_observation="Test last run obs",
130-
)
131-
132-
env.clone.return_value = MagicMock()
133-
llm.return_value = LLMResponse("Prompt", "Expected answer", TokenUsage(2, 4))
134-
env.tools = {"pdb": MagicMock()}
135-
136-
env.clone().step.return_value = build_env_info(
137-
done=True,
138-
score=10,
139-
max_score=10,
140-
rewrite_counter=0,
141-
instructions="Test instructions",
142-
dir_tree="Test dir tree",
143-
current_breakpoints="Test breakpoints",
144-
step_observation="Test last run obs",
145-
)
146-
result = agent.run(task_name="test_task", debug=False)
147-
148-
assert result is False
149-
# test that llm actions were executed
150-
assert env.step.called
151-
env.step.assert_called_with(human().response)
152-
assert env.step().done is False
153-
154-
# test that llm actions were logged
155-
_history, _prompt_response_pairs = agent.history.get()
156-
assert [[], [human()]] == _prompt_response_pairs
157-
158-
# test that env was cloned
159-
assert env.clone.called
160-
assert env.clone().reset.called
161-
162-
# assert that cloned env was called with history steps
163-
env.clone().step.assert_has_calls(
164-
[
165-
call(agent.history.get_all()[0].action),
166-
]
167-
)
168-
169-
# test that human action was executed
170-
assert env.clone().step.called
171-
env.clone().step.assert_called_with(llm().response)
172-
173-
# ensure that human action was not recorded in history
174-
assert env.clone().step() not in agent.history.get_all()

tests/agents/test_guided_agent.py

Lines changed: 82 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,82 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
from debug_gym.agents import GuidedRewriteAgent
4+
from debug_gym.llms import Human
5+
from debug_gym.llms.base import LLMResponse, TokenUsage
6+
7+
8+
@patch.object(
9+
Human,
10+
"__call__",
11+
return_value=LLMResponse(
12+
"Prompt",
13+
'{"id": "pdb-267437", "name": "pdb", "arguments": {"command": "c"}}',
14+
TokenUsage(2, 4),
15+
),
16+
)
17+
def test_human_in_the_loop(human, agent_setup, build_env_info):
18+
agent, env, llm = next(agent_setup(GuidedRewriteAgent))
19+
env.reset.return_value = build_env_info(
20+
done=False,
21+
score=0,
22+
max_score=10,
23+
rewrite_counter=0,
24+
instructions="Test instructions",
25+
dir_tree="Test dir tree",
26+
current_breakpoints="Test breakpoints",
27+
step_observation="Test last run obs",
28+
)
29+
env.step.return_value = build_env_info(
30+
done=False,
31+
score=10,
32+
max_score=10,
33+
rewrite_counter=0,
34+
instructions="Test instructions",
35+
dir_tree="Test dir tree",
36+
current_breakpoints="Test breakpoints",
37+
step_observation="Test last run obs",
38+
)
39+
40+
env.clone.return_value = MagicMock()
41+
llm.return_value = LLMResponse("Prompt", "Expected answer", TokenUsage(2, 4))
42+
env.tools = {"pdb": MagicMock()}
43+
44+
env.clone().step.return_value = build_env_info(
45+
done=True,
46+
score=10,
47+
max_score=10,
48+
rewrite_counter=0,
49+
instructions="Test instructions",
50+
dir_tree="Test dir tree",
51+
current_breakpoints="Test breakpoints",
52+
step_observation="Test last run obs",
53+
)
54+
result = agent.run(task_name="test_task", debug=False)
55+
56+
assert result is False
57+
# test that llm actions were executed
58+
assert env.step.called
59+
env.step.assert_called_with(human().response)
60+
assert env.step().done is False
61+
62+
# test that llm actions were logged
63+
_history, _prompt_response_pairs = agent.history.get()
64+
assert [[], [human()]] == _prompt_response_pairs
65+
66+
# test that env was cloned
67+
assert env.clone.called
68+
assert env.clone().reset.called
69+
70+
# assert that cloned env was called with history steps
71+
env.clone().step.assert_has_calls(
72+
[
73+
call(agent.history.get_all()[0].action),
74+
]
75+
)
76+
77+
# test that human action was executed
78+
assert env.clone().step.called
79+
env.clone().step.assert_called_with(llm().response)
80+
81+
# ensure that human action was not recorded in history
82+
assert env.clone().step() not in agent.history.get_all()

0 commit comments

Comments
 (0)