Skip to content

Commit fa9d136

Browse files
committed
Fix logic and use human_response.tool
1 parent d1e23ba commit fa9d136

File tree

8 files changed

+133
-18
lines changed

8 files changed

+133
-18
lines changed

debug_gym/agents/guided_agent.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22

3+
from termcolor import colored
4+
35
from debug_gym.agents.base_agent import register_agent
46
from debug_gym.agents.rewrite_agent import RewriteAgent
57
from debug_gym.llms.base import LLM
@@ -12,6 +14,9 @@ class GuidedRewriteAgent(RewriteAgent):
1214

1315
def try_rewrite(self, task_name):
1416
# make a copy of the env for the llm
17+
from ipdb import set_trace
18+
19+
set_trace()
1520
cloned_env = self.env.clone()
1621

1722
# Only keep the rewrite tool in the cloned env
@@ -33,7 +38,10 @@ def try_rewrite(self, task_name):
3338
return info.done
3439

3540
def run(self, task_name=None, debug=False):
36-
self.llm.logger = DebugGymLogger(name="LLM", level=logging.ERROR)
41+
self.logger.level = logging.DEBUG
42+
self.llm.logger = DebugGymLogger(
43+
name="LLM", level=logging.ERROR, log_dir=self.logger.log_file.parent
44+
)
3745
self.human = LLM.instantiate(llm_name="human", logger=self.logger)
3846

3947
self.history.reset()
@@ -55,10 +63,12 @@ def run(self, task_name=None, debug=False):
5563

5664
llm_done = self.try_rewrite(task_name)
5765
if llm_done:
58-
self.logger.info(
59-
f"*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***"
60-
)
66+
msg = f"*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***"
67+
self.logger.info(colored(msg, "green"))
6168
break
69+
else:
70+
msg = f"*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***"
71+
self.logger.info(colored(msg, "red"))
6272

6373
# If the LLM did not manage to solve the task, we continue with the guided approach.
6474
prompt = self.build_prompt(info)
@@ -68,7 +78,7 @@ def run(self, task_name=None, debug=False):
6878
breakpoint()
6979

7080
# step the environment with the human response
71-
info = self.env.step(human_response.response)
81+
info = self.env.step(human_response.tool)
7282
# log the human response
7383
self.history.step(info, human_response)
7484

debug_gym/agents/solution_agent.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
import subprocess
2-
31
from debug_gym.agents.base_agent import BaseAgent, register_agent
4-
from debug_gym.gym.envs.swe_bench import SWEBenchEnv
5-
from debug_gym.gym.envs.swe_smith import SWESmithEnv
62
from debug_gym.gym.tools.tool import ToolCall
73

84

debug_gym/gym/envs/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import logging
2+
13
from debug_gym.gym.envs.aider import AiderBenchmarkEnv
24
from debug_gym.gym.envs.env import RepoEnv, TooledEnv
35
from debug_gym.gym.envs.mini_nightmare import MiniNightmareEnv
4-
from debug_gym.gym.envs.swe_bench import SWEBenchEnv
5-
from debug_gym.gym.envs.swe_smith import SWESmithEnv
66

77

88
def select_env(env_type: str = None) -> type[RepoEnv]:
@@ -12,8 +12,14 @@ def select_env(env_type: str = None) -> type[RepoEnv]:
1212
case "aider":
1313
return AiderBenchmarkEnv
1414
case "swebench":
15+
from debug_gym.gym.envs.swe_bench import SWEBenchEnv
16+
17+
logging.getLogger("httpx").setLevel(logging.WARNING)
1518
return SWEBenchEnv
1619
case "swesmith":
20+
from debug_gym.gym.envs.swe_smith import SWESmithEnv
21+
22+
logging.getLogger("httpx").setLevel(logging.WARNING)
1723
return SWESmithEnv
1824
case "mini_nightmare":
1925
return MiniNightmareEnv

debug_gym/llms/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
import os
32
from abc import ABC, abstractmethod
43
from dataclasses import dataclass
@@ -20,9 +19,6 @@
2019
from debug_gym.llms.utils import print_messages
2120
from debug_gym.logger import DebugGymLogger
2221

23-
# Set logging level down to WARNING for endpoint queries.
24-
logging.getLogger("httpx").setLevel(logging.WARNING)
25-
2622

2723
def retry_on_rate_limit(
2824
func, is_rate_limit_error_func, multiplier=1, max_wait=40, max_attempts=100

debug_gym/llms/human.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23
import re
34
import sys
45
from typing import Any, Dict, List, Optional, Tuple
@@ -470,6 +471,12 @@ def __init__(
470471
if prompt_toolkit_available:
471472
self._history = InMemoryHistory()
472473

474+
# Warn if self.logger.level is not set at least to INFO, as this is a human interface.
475+
if self.logger.level > logging.INFO:
476+
self.logger.warning(
477+
"Human Mode should have logger level set to at least INFO (using -v) for better interaction."
478+
)
479+
473480
def tokenize(self, text: str) -> list[str]:
    """Split ``text`` into tokens on runs of whitespace.

    Args:
        text: The string to tokenize.

    Returns:
        The whitespace-separated pieces of ``text``; an empty list when
        ``text`` is empty or contains only whitespace.
    """
    tokens = text.split()
    return tokens

scripts/run.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import json
2+
import logging # Set logging level down to WARNING for endpoint queries.
23
import os
34
import uuid
45
from concurrent.futures import ThreadPoolExecutor, as_completed
5-
from itertools import groupby
66
from pathlib import Path
77

88
from termcolor import colored
99
from tqdm import tqdm
1010

1111
from debug_gym.agents.base_agent import AGENT_REGISTRY, create_agent
12+
13+
logging.getLogger("httpx").setLevel(logging.WARNING)
14+
1215
from debug_gym.agents.utils import load_config
1316
from debug_gym.gym.envs import select_env
1417
from debug_gym.gym.terminal import select_terminal
@@ -104,7 +107,7 @@ def main():
104107

105108
# Figure out which problems to solve.
106109
problems = config.get("problems", ["custom"])
107-
if type(problems) == str and "benchmark" in config:
110+
if type(problems) is str and "benchmark" in config:
108111
env = create_env(config, logger=logger)
109112
if problems == "all":
110113
problems = sorted(env.dataset.keys()) # all tasks

tests/agents/test_example_agent.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from unittest.mock import MagicMock
2+
3+
from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent
4+
from debug_gym.agents.rewrite_agent import RewriteAgent
5+
from debug_gym.llms.base import LLMResponse, TokenUsage
6+
7+
8+
def test_build_question_prompt(agent_setup):
    """Check that DebugAgent's question prompt is a single 'continue your debugging' message."""
    # agent_setup yields an (agent, env, llm) triple; only the agent is needed here.
    agent, _env, _llm = next(agent_setup(DebugAgent))

    prompt = agent.build_question_prompt()

    assert len(prompt) == 1
    first_message = prompt[0]
    assert "continue your debugging" in first_message["content"]
13+
14+
15+
def test_build_prompt(agent_setup, build_env_info):
    """Check that DebugAgent.build_prompt returns a non-empty message list."""
    agent, _env, _llm = next(agent_setup(DebugAgent))

    # Env-info fields passed to the build_env_info fixture.
    info_fields = {
        "instructions": "Test instructions",
        "dir_tree": "Test dir tree",
        "current_breakpoints": "Test breakpoints",
        "step_observation": "Test last run obs",
    }
    env_info = build_env_info(**info_fields)

    prompt = agent.build_prompt(env_info)

    assert len(prompt) > 0
25+
26+
27+
def test_run(agent_setup, build_env_info):
    """Check that DebugAgent.run returns a truthy result when the first step reports done."""
    agent, env, llm = next(agent_setup(DebugAgent))

    # Fields shared by the reset and step observations; only done/score differ.
    common_fields = {
        "max_score": 10,
        "instructions": "Test instructions",
        "dir_tree": "Test dir tree",
        "current_breakpoints": "Test breakpoints",
        "step_observation": "Test last run obs",
    }
    unsolved_info = build_env_info(done=False, score=0, **common_fields)
    env.reset.return_value = unsolved_info
    solved_info = build_env_info(done=True, score=10, **common_fields)
    env.step.return_value = solved_info

    llm.return_value = LLMResponse("Prompt", "Expected answer", TokenUsage(2, 4))

    outcome = agent.run(task_name="test_task", debug=False)
    assert outcome
50+
51+
52+
def test_build_system_prompt_rewrite_agent(agent_setup, build_env_info):
    """Check that RewriteAgent's system prompt is one message mentioning 'Overall task'."""
    agent, _env, _llm = next(agent_setup(RewriteAgent))

    # Env-info fields passed to the build_env_info fixture.
    info_fields = {
        "instructions": "Test instructions",
        "dir_tree": "Test dir tree",
        "current_breakpoints": "Test breakpoints",
        "step_observation": "Test last run obs",
    }
    env_info = build_env_info(**info_fields)

    system_prompt = agent.build_system_prompt(env_info)

    assert len(system_prompt) == 1
    first_message = system_prompt[0]
    assert "Overall task" in first_message["content"]
63+
64+
65+
def test_build_question_prompt_rewrite_agent(agent_setup):
    """Check that RewriteAgent's question prompt is a single 'continue your debugging' message."""
    # agent_setup yields an (agent, env, llm) triple; only the agent is needed here.
    agent, _env, _llm = next(agent_setup(RewriteAgent))

    prompt = agent.build_question_prompt()

    assert len(prompt) == 1
    first_message = prompt[0]
    assert "continue your debugging" in first_message["content"]
70+
71+
72+
def test_run_debug_5_agent(agent_setup, build_env_info):
    """Check that Debug_5_Agent.run returns a truthy result when the first step reports done."""
    agent, env, llm = next(agent_setup(Debug_5_Agent))

    # Fields shared by the reset and step observations; only done/score differ.
    common_fields = {
        "max_score": 10,
        "rewrite_counter": 0,
        "instructions": "Test instructions",
        "dir_tree": "Test dir tree",
        "current_breakpoints": "Test breakpoints",
        "step_observation": "Test last run obs",
    }
    unsolved_info = build_env_info(done=False, score=0, **common_fields)
    env.reset.return_value = unsolved_info
    solved_info = build_env_info(done=True, score=10, **common_fields)
    env.step.return_value = solved_info

    llm.return_value = LLMResponse("Prompt", "Expected answer", TokenUsage(2, 4))
    # Give the mocked env a "pdb" tool entry before running the agent.
    env.tools = {"pdb": MagicMock()}

    outcome = agent.run(task_name="test_task", debug=False)
    assert outcome

tests/agents/test_pdb_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import MagicMock
1+
from unittest.mock import MagicMock, patch
22

33
from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent
44
from debug_gym.agents.rewrite_agent import RewriteAgent

0 commit comments

Comments
 (0)