diff --git a/debug_gym/agents/__init__.py b/debug_gym/agents/__init__.py index 83161b49..5814ce22 100644 --- a/debug_gym/agents/__init__.py +++ b/debug_gym/agents/__init__.py @@ -1,3 +1,4 @@ from debug_gym.agents.debug_agent import Debug_5_Agent, DebugAgent +from debug_gym.agents.guided_agent import GuidedRewriteAgent from debug_gym.agents.rewrite_agent import RewriteAgent from debug_gym.agents.solution_agent import AgentSolution diff --git a/debug_gym/agents/guided_agent.py b/debug_gym/agents/guided_agent.py new file mode 100644 index 00000000..4e967ff4 --- /dev/null +++ b/debug_gym/agents/guided_agent.py @@ -0,0 +1,184 @@ +import logging + +from debug_gym.agents.base_agent import register_agent +from debug_gym.agents.history_tracker import build_history_prompt +from debug_gym.agents.rewrite_agent import RewriteAgent +from debug_gym.gym.entities import Event +from debug_gym.gym.tools.tool import ToolCall +from debug_gym.gym.tools.toolbox import Toolbox +from debug_gym.llms.base import LLM +from debug_gym.logger import DebugGymLogger + + +@register_agent +class GuidedRewriteAgent(RewriteAgent): + name: str = "guided_agent" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Initialize the different LLM rewriters. + self.llms = [ + LLM.instantiate( + llm_name=llm_name, + logger=DebugGymLogger( + name=llm_name, + level=logging.DEBUG, + log_dir=self.logger.log_file.parent, + icon="🤖", + ), + ) + for llm_name in self.config["llms"] + ] + + # Create logger for the main guide, e.g. (a human). + self.llm.logger = DebugGymLogger( + name=self.config["llm_name"], + level=logging.DEBUG, + log_dir=self.logger.log_file.parent, + icon="👤", + ) + + def build_prompt(self, info, llm): + messages = [] + messages.extend(self.build_system_prompt(info)) + messages.extend(self.build_history_prompt(llm)) + messages.extend(self.build_question_prompt()) + return messages + + def build_history_prompt(self, llm): + messages = build_history_prompt( + self.history, + llm, + self.config["reset_prompt_history_after_rewrite"], + ) + return messages + + def try_rewrite_and_rollback(self, llm, last_info): + prompt = self.build_prompt(last_info, llm) + + # Git commit the current state before trying to rewrite. + self.env.terminal.run("git add . && git commit -m 'Before rewrite attempt'") + + # Remove all tools except the rewrite tool. + tools = [tool for tool in last_info.tools if tool.name == "rewrite"] + response = llm(prompt, tools) + llm.logger.info(f"LLM response: {response.response}") + llm.logger.info(f"LLM tool: {response.tool}") + + # Temporarily disable the REWRITE_SUCCESS event. + self.env.event_hooks.mute(Event.REWRITE_SUCCESS) + info_after_rewrite = self.env.step(response.tool) + llm_info = self.env.step(ToolCall(id="eval", name="eval", arguments={})) + self.env.event_hooks.unmute(Event.REWRITE_SUCCESS) + + llm.logger.info(f"LLM observation: {llm_info.eval_observation.observation}.") + + if not llm_info.done: + # Rollback any changes made by the LLM if it hasn't solved the task yet. + self.env.terminal.run("git reset --hard HEAD") + + return llm_info + + def run(self, task_name=None, debug=False): + step = 0 + max_steps = self.config["max_steps"] + info = None + llm_done = False + try: + self.history.reset() + info = self.env.reset(options={"task_name": task_name}) + # initial state does not have prompt and response + self.history.step(info, None) + + # First make sure git is setup correctly. + self.env.terminal.run( + "git init && git config user.name 'debug-gym' && git config user.email '<>'" + ) + + if info.done is True: + self.logger.report_progress( + problem_id=task_name, + step=1, + total_steps=1, + score=info.score, + max_score=info.max_score, + status="resolved", + ) + return True + + highscore = info.score + + for step in range(max_steps): + self.logger.info(f"\n{'='*20} STEP {step+1} {'='*20}\n") + highscore = max(highscore, info.score) + self.logger.info( + f"Step: {step} | Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]" + ) + + solved = None + for llm in self.llms: + llm_info = self.try_rewrite_and_rollback(llm, info) + if llm_info.done: + solved = llm_info + msg = f"[green] ✅ The rewrite-only agent with {llm.model_name} managed to solve the task with the current context. ✅ [/green]" + llm.logger.info(msg) + else: + msg = f"[red] ❌ The rewrite-only agent with {llm.model_name} failed to solve the task with the current context. ❌ [/red]" + llm.logger.info(msg) + + if solved is not None: + llm_info = solved + break + + # If the LLM did not manage to solve the task, we continue with the guided approach. + prompt = self.build_prompt(info, self.llm) + guide_response = self.llm(prompt, info.tools) + + if debug: + breakpoint() + + # step the environment with the guide response + info = self.env.step(guide_response.tool) + # log the guide response + self.history.step(info, guide_response) + + if info.done: + self.logger.info( + "You managed to provide the patch that solves the task before the LLM. Congrats!" + ) + break + + # keep progress bar running until max_steps is reached + self.logger.report_progress( + problem_id=task_name, + step=step + 1, + total_steps=max_steps + 1, + score=info.score, + max_score=info.max_score, + status="running", + ) + + # max_steps was reached, task was either resolved or unresolved + # self.logger.report_progress( + # problem_id=task_name, + # step=step + 1, + # total_steps=step + 1, + # score=info.score, + # max_score=info.max_score, + # status="resolved" if info.done or llm_info.done else "unresolved", + # ) + + return info.done or llm_info.done + except Exception as e: + # report any error that happens during the run + if info: + self.logger.report_progress( + problem_id=task_name, + step=step + 1, + total_steps=step + 1, + score=info.score if info else 0, + max_score=info.max_score if info else 1, + status="error", + ) + raise diff --git a/debug_gym/gym/envs/__init__.py b/debug_gym/gym/envs/__init__.py index fb4000b1..769c7a97 100644 --- a/debug_gym/gym/envs/__init__.py +++ b/debug_gym/gym/envs/__init__.py @@ -1,8 +1,8 @@ +import logging + from debug_gym.gym.envs.aider import AiderBenchmarkEnv from debug_gym.gym.envs.env import RepoEnv, TooledEnv from debug_gym.gym.envs.mini_nightmare import MiniNightmareEnv -from debug_gym.gym.envs.swe_bench import SWEBenchEnv -from debug_gym.gym.envs.swe_smith import SWESmithEnv def select_env(env_type: str = None) -> type[RepoEnv]: @@ -12,8 +12,14 @@ def select_env(env_type: str = None) -> type[RepoEnv]: case "aider": return AiderBenchmarkEnv case "swebench": + from debug_gym.gym.envs.swe_bench import SWEBenchEnv + + logging.getLogger("httpx").setLevel(logging.WARNING) return SWEBenchEnv case "swesmith": + from debug_gym.gym.envs.swe_smith import SWESmithEnv + + logging.getLogger("httpx").setLevel(logging.WARNING) return SWESmithEnv case "mini_nightmare": return MiniNightmareEnv diff --git a/debug_gym/gym/envs/aider.py b/debug_gym/gym/envs/aider.py index f47be39d..3fc1d188 100644 --- a/debug_gym/gym/envs/aider.py +++ b/debug_gym/gym/envs/aider.py @@ -5,12 +5,24 @@ from debug_gym.constants import DEBUG_GYM_CACHE_DIR from debug_gym.gym.entities import EvalOutput from debug_gym.gym.envs.env import RepoEnv +from debug_gym.gym.terminal import DockerTerminal, Terminal class AiderBenchmarkEnv(RepoEnv): REPO_URL = "https://github.com/exercism/python" REPO_PATH = DEBUG_GYM_CACHE_DIR / "exercism" + def __init__( + self, + terminal: Terminal | None = None, + **kwargs, + ): + terminal = terminal or DockerTerminal(logger=kwargs.get("logger")) + if not isinstance(terminal, DockerTerminal): + raise ValueError("AiderBenchmarkEnv only supports DockerTerminal.") + + super().__init__(terminal=terminal, **kwargs) + @property def instructions(self) -> str: return self.current_sample["instructions"] @@ -31,11 +43,29 @@ def eval(self, **kwargs) -> EvalOutput: self.last_eval = EvalOutput(success, output) return self.last_eval + def setup_terminal(self): + self.logger.info(f"Configuring docker container: {self.terminal.container}") + + self.terminal.run("git init") + self.terminal.run("git config user.name 'debug-gym'") + self.terminal.run("git config user.email '<>'") + + self.terminal.run("git add *.py") + self.terminal.run("git commit -am 'Init'") + + self.terminal.run("git add .debugignore") + self.terminal.run("git add .debugreadonly") + self.terminal.run("git commit -am 'Add debug-gym ignore and read-only files'") + def reset(self, *, options: dict = None): options = options or {} self.current_sample = self.dataset[options["task_name"]] directory = self.current_sample["base_directory"] self.setup_workspace(directory, entrypoint=self.entrypoint) + from ipdb import set_trace + + set_trace() + self.setup_terminal() infos = super().reset(options=options) return infos diff --git a/debug_gym/gym/envs/env.py b/debug_gym/gym/envs/env.py index 6cc419db..8f25891a 100644 --- a/debug_gym/gym/envs/env.py +++ b/debug_gym/gym/envs/env.py @@ -37,6 +37,7 @@ class EnvInfo: class EventHooks: def __init__(self): self.event_listeners = {event: [] for event in Event} + self.event_listeners_muted = {event: [] for event in Event} def subscribe(self, event: Event, tool: "Tool"): if event not in self.event_listeners: @@ -50,6 +51,20 @@ def subscribe(self, event: Event, tool: "Tool"): def unsubscribe(self, event: Event, tool): self.event_listeners[event].remove(tool) + def mute(self, event: Event): + """Mute all tools for the given event.""" + if event not in self.event_listeners_muted: + raise ValueError(f"Unknown event type: {event}") + self.event_listeners_muted[event] = self.event_listeners[event][:] + self.event_listeners[event] = [] + + def unmute(self, event: Event): + """Unmute all tools for the given event.""" + if event not in self.event_listeners_muted: + raise ValueError(f"Unknown event type: {event}") + self.event_listeners[event] = self.event_listeners_muted[event][:] + self.event_listeners_muted[event] = [] + def notify( self, environment, event: Event, source=None, **kwargs ) -> list[Observation]: @@ -500,10 +515,12 @@ def current_breakpoints(self): @property def patch(self): - command = ["git", "diff", "--no-index", self.path, self.working_dir] - result = subprocess.run(command, text=True, capture_output=True) - patch = result.stdout.replace(str(self.working_dir), str(self.path)) - return patch + success, output = self.terminal.run("git diff") + if not success: + self.logger.error("Failed to get git diff. {output}") + return None + + return output def apply_gold_patch(self): raise NotImplementedError( diff --git a/debug_gym/gym/envs/mini_nightmare.py b/debug_gym/gym/envs/mini_nightmare.py index 6ba98fec..54811d70 100644 --- a/debug_gym/gym/envs/mini_nightmare.py +++ b/debug_gym/gym/envs/mini_nightmare.py @@ -1,9 +1,11 @@ import os +import subprocess from os.path import join as pjoin import debug_gym.gym.utils as utils from debug_gym.gym.entities import EvalOutput from debug_gym.gym.envs.env import RepoEnv +from debug_gym.gym.terminal import DockerTerminal, Terminal class MiniNightmareEnv(RepoEnv): @@ -21,6 +23,17 @@ class MiniNightmareEnv(RepoEnv): "tomorrow_date", ] + def __init__( + self, + terminal: Terminal | None = None, + **kwargs, + ): + terminal = terminal or DockerTerminal(logger=kwargs.get("logger")) + if not isinstance(terminal, DockerTerminal): + raise ValueError("MiniNightmareEnv only supports DockerTerminal.") + + super().__init__(terminal=terminal, **kwargs) + @property def instructions(self) -> str: return self.current_sample["instructions"] @@ -41,11 +54,27 @@ def eval(self, **kwargs) -> EvalOutput: self.last_eval = EvalOutput(success, output) return self.last_eval + def setup_terminal(self): + self.logger.info(f"Configuring {self.terminal.container}...") + + self.terminal.run("git init") + self.terminal.run("git config user.name 'debug-gym'") + self.terminal.run("git config user.email '<>'") + + self.terminal.run("git add *.py") + self.terminal.run("git commit -am 'Init'") + + self.terminal.run("git add .debugignore") + self.terminal.run("git add .debugreadonly") + self.terminal.run("git commit -am 'Add debug-gym ignore and read-only files'") + def reset(self, *, options: dict = None): options = options or {} self.current_sample = self.dataset[options["task_name"]] directory = self.current_sample["base_directory"] self.setup_workspace(directory, entrypoint=self.entrypoint) + self.setup_terminal() + infos = super().reset(options=options) return infos diff --git a/debug_gym/gym/envs/swe_bench.py b/debug_gym/gym/envs/swe_bench.py index 719cf556..873a9e58 100644 --- a/debug_gym/gym/envs/swe_bench.py +++ b/debug_gym/gym/envs/swe_bench.py @@ -147,15 +147,6 @@ def setup_task(self, task_name): self.test_spec, docker.from_env(), logger=None, nocache=False ) - @property - def patch(self): - command = "git diff" - result = subprocess.run( - command.split(), cwd=self.working_dir, text=True, capture_output=True - ) - # patch = result.stdout.replace(str(self.working_dir), str(self.path)) - return result.stdout - def apply_gold_patch(self): self.logger.info(f"Applying gold patch to {self.working_dir}.") command = self.git_apply_cmd + f" <<'EOF'\n{self.gold_patch}\nEOF" diff --git a/debug_gym/gym/terminal.py b/debug_gym/gym/terminal.py index 64fe6ea4..1fb75c87 100644 --- a/debug_gym/gym/terminal.py +++ b/debug_gym/gym/terminal.py @@ -457,7 +457,7 @@ def setup_container(self) -> docker.models.containers.Container: container.rename(container_name) container.reload() self._run_setup_commands(container) - self.logger.debug(f"Container {container_name} started successfully.") + self.logger.debug(f"{container} ({container_name}) started successfully.") atexit.register(self.clean_up) return container @@ -465,7 +465,7 @@ def _run_setup_commands(self, container): """Run setup commands if any. If commands fail, stop the container.""" if self.setup_commands: setup_commands = " && ".join(self.setup_commands) - self.logger.debug(f"Running setup commands: {setup_commands}") + self.logger.debug(f"{container} Running setup commands: {setup_commands}") status, output = container.exec_run( ["/bin/bash", "-c", setup_commands], user="root", # Run as root to allow installations diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index 2682db3d..50582f1c 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -1,4 +1,3 @@ -import logging import os from abc import ABC, abstractmethod from dataclasses import dataclass @@ -19,9 +18,6 @@ from debug_gym.llms.utils import print_messages from debug_gym.logger import DebugGymLogger -# Set logging level down to WARNING for endpoint queries. -logging.getLogger("httpx").setLevel(logging.WARNING) - def retry_on_exception( func, exception_filter_func, multiplier=1, max_wait=40, max_attempts=100 diff --git a/debug_gym/llms/human.py b/debug_gym/llms/human.py index 088ac49d..a1e53e3e 100644 --- a/debug_gym/llms/human.py +++ b/debug_gym/llms/human.py @@ -1,4 +1,5 @@ import json +import logging import re import sys from typing import Any, Dict, List, Optional, Tuple @@ -471,6 +472,12 @@ def __init__( if prompt_toolkit_available: self._history = InMemoryHistory() + # Warn if self.logger.level is not set at least to INFO, as this is a human interface. + if self.logger.level > logging.INFO: + self.logger.warning( + "Human Mode should have logger level set to at least INFO (using -v) for better interaction." + ) + def tokenize(self, text: str) -> list[str]: """Tokenizes a text by splitting it by spaces.""" return text.split() diff --git a/debug_gym/logger.py b/debug_gym/logger.py index a0525c34..d41e0c80 100644 --- a/debug_gym/logger.py +++ b/debug_gym/logger.py @@ -405,6 +405,20 @@ def _status_listener(self): self.logger.debug("Status listener thread exiting...") +class IconFilter(logging.Filter): + def __init__(self, *args, icon="🐸", **kwargs): + super().__init__(*args, **kwargs) + self.icon = icon + + def filter(self, record): + if not hasattr(record, "icon"): + # If the record does not have an icon attribute, set it + # This allows the icon to be used in log messages + record.icon = self.icon + + return True + + class DebugGymLogger(logging.Logger): """A multiprocess friendly logger that integrates with Rich for progress reporting. Multiprocess workers can use this logger to log messages and report progress via @@ -420,6 +434,7 @@ def __init__( log_dir: str | None = None, level: str | int = logging.INFO, mode: str = "a", + icon: str = "🐸", ): super().__init__(name) # If var env "DEBUG_GYM_DEBUG" is set, turn on debug mode @@ -428,6 +443,8 @@ def __init__( # Prevent the log messages from being propagated to the root logger self.propagate = False + self.icon_filter = IconFilter(icon=icon) + self.addFilter(self.icon_filter) self.setLevel(level) # Set logger level, might be overridden by file handler self.log_file = None # File handler for logging to a file @@ -443,6 +460,16 @@ def __init__( if log_dir: self._initialize_file_handler(name, log_dir, mode) + @property + def icon(self): + """Get the icon used in log messages.""" + return self.icon_filter.icon + + @icon.setter + def icon(self, icon: str): + """Set the icon for the logger. This will update the icon used in log messages.""" + self.icon_filter.icon = icon + def _initialize_main_logger(self, level): self._live = Live(transient=True, refresh_per_second=2) rich_handler = RichHandler( @@ -451,7 +478,9 @@ def _initialize_main_logger(self, level): rich_tracebacks=True, markup=True, ) - rich_handler.setFormatter(logging.Formatter("🐸 [%(name)-12s]: %(message)s")) + rich_handler.setFormatter( + logging.Formatter(r"%(icon)s \[%(name)-12s]: %(message)s") + ) rich_handler.setLevel(level) self.addHandler(rich_handler) @@ -481,6 +510,7 @@ def handle(self, record): record into the log queue for the main process to display logs through Rich.""" if self._is_worker: + # record.args.append(self.icon) self.LOG_QUEUE.put(record) super().handle(record) diff --git a/scripts/config_aider.yaml b/scripts/config_aider.yaml index 9a5ec0ee..82d9cb56 100644 --- a/scripts/config_aider.yaml +++ b/scripts/config_aider.yaml @@ -20,7 +20,7 @@ base: # session_commands define commands that are always executed before starting a shell session or running a single command in the terminal. # session_commands:["conda activate aider"], # setup_commands define commands that are executed only once when the terminal is created. This is only supported for Docker terminal. - setup_commands: ["pip install pytest"], + setup_commands: ["apt update", "apt install -y git", "pip install pytest"], } # LLM configs diff --git a/scripts/config_mini_nightmare.yaml b/scripts/config_mini_nightmare.yaml index 747c2332..db930e85 100644 --- a/scripts/config_mini_nightmare.yaml +++ b/scripts/config_mini_nightmare.yaml @@ -20,7 +20,7 @@ base: # session_commands define commands that are always executed before starting a shell session or running a single command in the terminal. # session_commands:["conda activate aider"], # setup_commands define commands that are executed only once when the terminal is created. This is only supported for Docker terminal. - setup_commands: ["pip install pytest pandas"], + setup_commands: ["apt update", "apt install -y git", "pip install pytest pandas"], } # LLM configs @@ -45,3 +45,8 @@ debug_agent: debug_5_agent: n_rewrites_before_pdb: 5 tools: ["pdb", "view", "rewrite", "eval"] + +guided_agent: + llm_name: "gpt-4o" + tools: ["pdb", "view", "rewrite", "eval"] + llms: ["gpt-4o", "gpt-4o"] diff --git a/scripts/run.py b/scripts/run.py index 925a6091..1ef28843 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -1,5 +1,6 @@ import datetime import json +import logging # Set logging level down to WARNING for endpoint queries. import os import signal import subprocess @@ -9,6 +10,9 @@ from debug_gym import version as dg_version from debug_gym.agents.base_agent import AGENT_REGISTRY, create_agent + +logging.getLogger("httpx").setLevel(logging.WARNING) + from debug_gym.agents.utils import load_config from debug_gym.gym.envs import select_env from debug_gym.gym.terminal import select_terminal diff --git a/tests/agents/test_guided_agent.py b/tests/agents/test_guided_agent.py new file mode 100644 index 00000000..7e7b6721 --- /dev/null +++ b/tests/agents/test_guided_agent.py @@ -0,0 +1,82 @@ +from unittest.mock import MagicMock, patch + +from debug_gym.agents import GuidedRewriteAgent +from debug_gym.llms import Human +from debug_gym.llms.base import LLMResponse, TokenUsage + + +@patch.object( + Human, + "__call__", + return_value=LLMResponse( + "Prompt", + '{"id": "pdb-267437", "name": "pdb", "arguments": {"command": "c"}}', + TokenUsage(2, 4), + ), +) +def test_human_in_the_loop(human, agent_setup, build_env_info): + agent, env, llm = next(agent_setup(GuidedRewriteAgent)) + env.reset.return_value = build_env_info( + done=False, + score=0, + max_score=10, + rewrite_counter=0, + instructions="Test instructions", + dir_tree="Test dir tree", + current_breakpoints="Test breakpoints", + step_observation="Test last run obs", + ) + env.step.return_value = build_env_info( + done=False, + score=10, + max_score=10, + rewrite_counter=0, + instructions="Test instructions", + dir_tree="Test dir tree", + current_breakpoints="Test breakpoints", + step_observation="Test last run obs", + ) + + env.clone.return_value = MagicMock() + llm.return_value = LLMResponse("Prompt", "Expected answer", TokenUsage(2, 4)) + env.tools = {"pdb": MagicMock()} + + env.clone().step.return_value = build_env_info( + done=True, + score=10, + max_score=10, + rewrite_counter=0, + instructions="Test instructions", + dir_tree="Test dir tree", + current_breakpoints="Test breakpoints", + step_observation="Test last run obs", + ) + result = agent.run(task_name="test_task", debug=False) + + assert result is False + # test that llm actions were executed + assert env.step.called + env.step.assert_called_with(human().response) + assert env.step().done is False + + # test that llm actions were logged + _history, _prompt_response_pairs = agent.history.get() + assert [[], [human()]] == _prompt_response_pairs + + # test that env was cloned + assert env.clone.called + assert env.clone().reset.called + + # assert that cloned env was called with history steps + env.clone().step.assert_has_calls( + [ + call(agent.history.get_all()[0].action), + ] + ) + + # test that human action was executed + assert env.clone().step.called + env.clone().step.assert_called_with(llm().response) + + # ensure that human action was not recorded in history + assert env.clone().step() not in agent.history.get_all() diff --git a/tests/gym/envs/test_env.py b/tests/gym/envs/test_env.py index 4ff5a034..7662c25a 100644 --- a/tests/gym/envs/test_env.py +++ b/tests/gym/envs/test_env.py @@ -746,3 +746,22 @@ def test_has_breakpoint_relative_path(tmp_path): assert env.has_breakpoint("foo.py", 6) is False # Should return False for non-existent file assert env.has_breakpoint("bar.py", line_number) is False + + +def test_clone(env): + cloned_env = env.clone() + + # Check that the cloned environment is a different instance + assert id(env) != id(cloned_env) + + # Check that the properties are the same + assert env.path == cloned_env.path + assert env.entrypoint == cloned_env.entrypoint + assert env.debug_entrypoint == cloned_env.debug_entrypoint + assert env.max_score == cloned_env.max_score + assert env.run_timeout == cloned_env.run_timeout + assert env.dir_tree_depth == cloned_env.dir_tree_depth + assert env.logger == cloned_env.logger + + # Check that the terminal is not the same instance + assert id(env.terminal) != id(cloned_env.terminal)