diff --git a/README.md b/README.md index b5807314..b92d4753 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,50 @@ image to select a step and observe the action taken by the agent. **⚠️ Note**: Gradio is still developing, and unexpected behavior has been frequently noticed. Version 5.5 seems to work properly so far. If you're not sure that the proper information is displaying, refresh the page and select your experiment again. +### AgentLab Server and AgentLab Controller + +https://github.com/user-attachments/assets/9a498c99-453a-4d7c-89fc-13e18db8dad6 + +The AgentLab Server and Controller are two components that work together to control and debug an agent deployed in an environment. + +#### Prerequisites + +First, set a `.env` file at the root of the repo with the following content: + +```bash +# LLM Creds (Azure as an example) +AZURE_OPENAI_ENDPOINT= +AZURE_OPENAI_API_KEY= +AZURE_OPENAI_API_VERSION= + +# ServiceNow dev instance creds +SNOW_INSTANCE_URL=https://.service-now.com/ +SNOW_INSTANCE_UNAME="admin" +SNOW_INSTANCE_PWD= + +# MiniWob +MINIWOB_URL="file:///path/to/BrowserGym/miniwob-plusplus/miniwob/html/miniwob/" +``` + +#### Launch the server + +The AgentLab Server is responsible for hosting and enabling interaction with the environment. It is a lightweight FastAPI server that handles the BrowserGym environment and provides a REST API for the controller. + +To launch the server, open a terminal and run (you will need to keep this terminal open for the next step): + +```bash +agentlab-server +``` + +#### Launch the controller + +The AgentLab Controller is a streamlit app responsible for controlling the agent and how it interacts with the environment hosted on the server. + +To launch the controller, open a new terminal and run: + +```bash +agentlab-controller +``` ## 🏆 Leaderboard diff --git a/pyproject.toml b/pyproject.toml index 1292836a..d040e0b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,5 @@ exclude = ''' [project.scripts] agentlab-assistant = "agentlab.ui_assistant:main" agentlab-xray = "agentlab.analyze.agent_xray:main" +agentlab-controller = "agentlab.analyze.run_agentlab_controller:main" +agentlab-server = "agentlab.analyze.server:main" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6322ffd3..d29159ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,4 @@ ray[default] python-slugify pillow gymnasium>=0.27 +streamlit diff --git a/src/agentlab/agents/generic_agent/__init__.py b/src/agentlab/agents/generic_agent/__init__.py index cb5bbb7f..9aecbb5f 100644 --- a/src/agentlab/agents/generic_agent/__init__.py +++ b/src/agentlab/agents/generic_agent/__init__.py @@ -26,6 +26,14 @@ AGENT_o3_MINI, FLAGS_GPT_4o, GenericAgentArgs, + AGENT_AZURE_4o_MINI, + AGENT_AZURE_4o, + AGENT_AZURE_4o_VISION, + AGENT_AZURE_4o_MINI_VISION, + AGENT_AZURE_41, + AGENT_AZURE_41_MINI, + AGENT_AZURE_41_VISION, + AGENT_AZURE_41_MINI_VISION, ) __all__ = [ @@ -46,4 +54,12 @@ "AGENT_4o_VISION", "AGENT_4o_MINI_VISION", "AGENT_CLAUDE_SONNET_35_VISION", + "AGENT_AZURE_4o_MINI", + "AGENT_AZURE_4o", + "AGENT_AZURE_4o_VISION", + "AGENT_AZURE_4o_MINI_VISION", + "AGENT_AZURE_41", + "AGENT_AZURE_41_MINI", + "AGENT_AZURE_41_VISION", + "AGENT_AZURE_41_MINI_VISION", ] diff --git a/src/agentlab/agents/generic_agent/agent_configs.py b/src/agentlab/agents/generic_agent/agent_configs.py index f50367d8..17728a82 100644 --- a/src/agentlab/agents/generic_agent/agent_configs.py +++ b/src/agentlab/agents/generic_agent/agent_configs.py @@ -350,3 +350,41 @@ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"], flags=DEFAULT_RS_FLAGS, ) + + +AGENT_AZURE_4o_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-mini"], + flags=FLAGS_GPT_4o, +) +AGENT_AZURE_4o = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o"], + flags=FLAGS_GPT_4o, +) +AGENT_AZURE_41 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1"], + flags=FLAGS_GPT_4o, +) +AGENT_AZURE_41_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-mini"], + flags=FLAGS_GPT_4o, +) + +AGENT_AZURE_4o_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_AZURE_4o_MINI_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-mini"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_AZURE_41_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_AZURE_41_MINI_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-mini"], + flags=FLAGS_GPT_4o_VISION, +) diff --git a/src/agentlab/agents/tool_use_agent/__init__.py b/src/agentlab/agents/tool_use_agent/__init__.py index 935fea14..2f4a0139 100644 --- a/src/agentlab/agents/tool_use_agent/__init__.py +++ b/src/agentlab/agents/tool_use_agent/__init__.py @@ -4,3 +4,12 @@ # for backward compatibility of unpickling sys.modules[__name__ + ".multi_tool_agent"] = sys.modules[__name__] + +__all__ = [ + "GPT_4_1", + "AZURE_GPT_4_1", + "GPT_4_1_MINI", + "AZURE_GPT_4_1_MINI", + "OPENAI_CHATAPI_MODEL_CONFIG", + "CLAUDE_MODEL_CONFIG", +] diff --git a/src/agentlab/agents/tool_use_agent/hint_db.csv b/src/agentlab/agents/tool_use_agent/hint_db.csv index f402c24a..76ee969d 100644 --- a/src/agentlab/agents/tool_use_agent/hint_db.csv +++ b/src/agentlab/agents/tool_use_agent/hint_db.csv @@ -21,3 +21,27 @@ July 13,workarena.servicenow.create-hardware-asset,385,gpt-4.1,ToolUse-gpt-4.1,W July 13,workarena.servicenow.create-hardware-asset,385,gpt-4.1,ToolUse-gpt-4.1,WorkArena-L1,WorkArena-L1,allac,Filling form in WorkArena,"Before clicking submit, make sure that all fields are filled properly. Then click submit." July 13,workarena.servicenow.create-hardware-asset,385,gpt-4.1,ToolUse-gpt-4.1,WorkArena-L1,WorkArena-L1,allac,Filling form in WorkArena,Avoid back and forth from tabs to tabs to reduce the number of actions July 14,workarena.servicenow.create-hardware-asset,385,gpt-4.1,ToolUse-gpt-4.1,WorkArena-L1,WorkArena-L1,allac,Filling form in WorkArena,When you see auto-complete make sure to select an element from that list +July 16,workarena.servicenow.sort-asset-list,406,gpt-4-1,ToolUseAgent-gpt-4-1,workarena,workarena,patricebechard,Sorting lists in ServiceNow,"1. **Navigate to Your Table/List** + + * For example, go to **Incident > All** or any other table you want to view. + +2. **Sort by One or Multiple Columns** + + * `click` on the ""show / hide filter"" button (funnel icon) at the top left of the page to open the filter row. + * Repeat the following steps for each column you want to sort by in this exact order: + * `click` on the ""Add Sort"" button to add a new sort filter. This will create a new ordering filter row with two comboboxes under the heading ""Order results by the following fields"". + * `fill` the first combobox with the appropriate field name you want to sort by. MAKE SURE to use the exact field name provided. + * `press` Enter after typing the field name to close the dropdown. It is VERY IMPORTANT that you do this before doing anything else otherwise the field will not be selected and the task will not be successful. DO NOT click on the run filter button before having confirmed your choice by explicitly pressing ENTER. + * `select_option` for the appropriate ordering between ascending (a to z) or descending (z to a) in the second combobox. + * Once all sort filters have been added, `click` the ""Run filter"" button to apply the sort. + +Notes: + * NEVER directly sort the columns using the table header. + * NEVER add columns via the Personalize List menu. + * ALWAYS sort the table using the EXACT NAMES of the provided fields. DO NOT use different but similar field names. For example, if the field you're asked to sort by is ""Opened by"", DO NOT filter by ""Created by"" even if they sound similar, but instead ALWAYS use the exact ""Opened by"" wording. + * Some columns might not appear by default in the visible view of the table. This does not mean they do not exist. ALWAYS use the EXACT names provided to sort by otherwise the task will not be successful. + +3. **Resetting or Clearing Sorting** + + * To reset sorting, click another column, or click again to toggle. + * In the filter bar, you may see a ""Sorted by..."" indicator—clear or change it as needed." diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index b7494693..d50a27e8 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -8,15 +8,6 @@ import bgym import pandas as pd -from bgym import Benchmark as BgymBenchmark -from browsergym.core.observation import extract_screenshot -from browsergym.utils.obs import ( - flatten_axtree_to_str, - flatten_dom_to_str, - overlay_som, - prune_html, -) - from agentlab.agents.agent_args import AgentArgs from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark from agentlab.benchmarks.osworld import OSWorldActionSet @@ -24,6 +15,7 @@ from agentlab.llm.llm_utils import image_to_png_base64_url from agentlab.llm.response_api import ( APIPayload, + AzureOpenAIResponseModelArgs, ClaudeResponseModelArgs, LLMOutput, MessageBuilder, @@ -33,6 +25,14 @@ ToolCalls, ) from agentlab.llm.tracking import cost_tracker_decorator +from bgym import Benchmark as BgymBenchmark +from browsergym.core.observation import extract_screenshot +from browsergym.utils.obs import ( + flatten_axtree_to_str, + flatten_dom_to_str, + overlay_som, + prune_html, +) @dataclass @@ -43,8 +43,8 @@ def _init(self): def make(self) -> "Block": """Returns a copy so the init can start adding some stuff to `self` without changing the - original datatclass that should only contain a config. - The aim is avoid having 2 calss definition for each block, e.g. Block and BlockArgs. + original dataclass that should only contain a config. + The aim is avoid having 2 class definitions for each block, e.g. Block and BlockArgs. Returns: Block: A copy of the current block instance with initialization applied. @@ -387,7 +387,6 @@ def __init__( self.config.action_subsets, multiaction=self.config.multiaction # type: ignore ) self.tools = self.action_set.to_tool_description(api=model_args.api) - self.call_ids = [] self.llm = model_args.make_model() @@ -508,6 +507,15 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) +AZURE_GPT_4_1 = AzureOpenAIResponseModelArgs( + model_name="gpt-4.1", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + GPT_4_1_MINI = OpenAIResponseModelArgs( model_name="gpt-4.1-mini", max_total_tokens=200_000, @@ -517,6 +525,15 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) +AZURE_GPT_4_1_MINI = AzureOpenAIResponseModelArgs( + model_name="gpt-4.1-mini", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs( model_name="gpt-4o-2024-08-06", max_total_tokens=200_000, @@ -576,9 +593,9 @@ def get_action(self, obs: Any) -> float: general_hints=GeneralHints(use_hints=False), task_hint=TaskHint(use_task_hint=True), keep_last_n_obs=None, - multiaction=True, # whether to use multi-action or not - # action_subsets=("bid",), - action_subsets=("coord"), + multiaction=False, # whether to use multi-action or not + action_subsets=("bid",), + # action_subsets=("coord"), # action_subsets=("coord", "bid"), ) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py new file mode 100644 index 00000000..8179a04b --- /dev/null +++ b/src/agentlab/analyze/agent_controller.py @@ -0,0 +1,1282 @@ +import base64 +import copy +import gzip +import importlib +import json +import logging +import os +import pickle +from collections import Counter +from datetime import datetime +from io import BytesIO +from pathlib import Path + +import numpy as np +import PIL.Image +import requests +import streamlit as st +from agentlab.agents.generic_agent import __all__ as ALL_GENERIC_AGENTS +from agentlab.agents.generic_agent.generic_agent import GenericAgent +from agentlab.agents.tool_use_agent import __all__ as ALL_TOOL_USE_AGENTS +from agentlab.agents.tool_use_agent.tool_use_agent import ( + DEFAULT_PROMPT_CONFIG, + ToolUseAgent, + ToolUseAgentArgs, +) +from agentlab.experiments.exp_utils import RESULTS_DIR +from agentlab.experiments.loop import ExpArgs, StepInfo, save_package_versions +from agentlab.llm.response_api import LLMOutput +from bgym import DEFAULT_BENCHMARKS +from dotenv import load_dotenv +from transformers import AutoTokenizer + +# used to display prompt. simple chat template from apache 2.0 model +# tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") +tokenizer = AutoTokenizer.from_pretrained( + "/Users/patrice.bechard/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/892b3d7a7b1cf10c7a701c60881cd93df615734c" +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +load_dotenv() + +DEFAULT_BENCHMARK = "workarena_l1" + +SERVER_URL = "http://127.0.0.1:8000" + +# region Sidebar Text +SIDEBAR_TEXT = """ +# AgentLab Controller + +AgentLab Controller is a tool used to help control and debug an agent deployed in an environment. + +AgentLab Controller works by connecting a Streamlit UI that handles the agent to a FastAPI backend server that handles the environment. + +--- + +## Instructions + +1. ⚙️ Setup the task + - Select an agent, benchmark, task, and subtask you want to work on. + - Select "🔄" to reset the environment. This includes resetting the environment server. + - Select "▶️" to start the environment. This will start the environment by opening a browser in the background. This step might take some time + +2. 🎮 Control the environment + - Look at the goal set for the task, the thought of the model, and the action taken. + - If the action looks right, select the "▶️ Next Step" button to step the environment. + + The action will then be executed and the environment will be updated. + - If the action is wrong and you want to re-prompt, select the "🔄 Regenerate Action". + + You can also expand the "Prompt Modifier" menu to change the prompt used to generate the thoughts / actions. + - If you want to backtrack and undo the previous actions, select the "⬅️ Previous Step" button. + + Note: This is a slow process as we need to reset the environment server and perform the previous actions one by one. + +3. 🔎 Investigate the environment + - Look at the screenshot of the current environment state + - Verify that the action selected by the model matches the AxTree + - Ensure that the prompt is properly build. If there are issues with the prompt yielding the wrong action, modify them using the "Prompt Modifier" above. +""" +# endregion + + +class Constants: + STATUS = "status" + STATUS_SUCCESS = "success" + STATUS_ERROR = "error" + MESSAGE = "message" + + OBS = "obs" + SCREENSHOT = "screenshot" + AXTREE_TXT = "axtree_txt" + + +class IgnoreMessageFilter(logging.Filter): + def filter(self, record): + return "but it does not exist!" not in record.getMessage() + + +streamlit_logger = st.watcher.local_sources_watcher._LOGGER +streamlit_logger.setLevel(logging.ERROR) + + +def make_hashable(obj): + if isinstance(obj, np.ndarray): + # Use shape, dtype, and bytes for uniqueness + return (obj.shape, obj.dtype.str, obj.tobytes()) + elif isinstance(obj, (tuple, list)): + return tuple(make_hashable(x) for x in obj) + elif isinstance(obj, dict): + # Sort keys to ensure consistent order + return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) + else: + return obj # Assume it's already hashable + + +def is_json_serializable(value): + try: + json.dumps(value) + return True + except (TypeError, OverflowError): + return False + + +def get_import_path(obj): + return f"{obj.__module__}.{obj.__qualname__}" + + +def deserialize_response(response_json): + if Constants.OBS in response_json: + if Constants.SCREENSHOT in response_json[Constants.OBS]: + screenshot_data = response_json[Constants.OBS][Constants.SCREENSHOT] + # convert base64 to numpy array + screenshot = np.frombuffer( + base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]) + ) + screenshot = screenshot.reshape(screenshot_data["shape"]) + response_json[Constants.OBS][Constants.SCREENSHOT] = screenshot + return response_json + + +def reset_env_history(): + logger.info("Resetting env history") + st.session_state.last_obs = None + st.session_state.obs_history = [] + st.session_state.screenshot_history = [] + st.session_state.axtree_history = [] + + # related to env info + st.session_state.reward_history = [] + st.session_state.terminated_history = [] + st.session_state.truncated_history = [] + st.session_state.env_info_history = [] + + +def reset_agent_history(): + logger.info("Resetting agent history") + st.session_state.action = None + st.session_state.action_info = None + st.session_state.action_history = [] + st.session_state.action_info_history = [] + st.session_state.thought_history = [] + st.session_state.prompt_history = [] + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history = [] + + +def reset_agent_state(): + logger.info("Resetting agent state") + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.agent.reset() + else: + st.session_state.agent.discussion.groups = [] + st.session_state.agent.last_response = LLMOutput() + st.session_state.agent._responses = [] + + +def step_env_history(obs, response_json): + logger.info("Stepping env history") + st.session_state.last_obs = copy.deepcopy(obs) + st.session_state.obs_history.append(obs) + st.session_state.screenshot_history.append(obs[Constants.SCREENSHOT]) + st.session_state.axtree_history.append(obs[Constants.AXTREE_TXT]) + + # other relevant info found in response_json + st.session_state.reward_history.append(response_json["reward"]) + st.session_state.terminated_history.append(response_json["terminated"]) + st.session_state.truncated_history.append(response_json["truncated"]) + st.session_state.env_info_history.append(response_json["info"]) + + +def step_agent_history(action, action_info): + logger.info("Stepping agent history") + st.session_state.action = copy.deepcopy(action) + st.session_state.action_info = copy.deepcopy(action_info) + st.session_state.action_history.append(action) + st.session_state.action_info_history.append(action_info) + st.session_state.thought_history.append(action_info.think) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.prompt_history.append(get_prompt(action_info)) + elif isinstance(st.session_state.agent, ToolUseAgent): + st.session_state.prompt_history.append( + "\n".join([elem.to_markdown() for elem in st.session_state.agent.discussion.flatten()]) + ) + + # HACK: memory history can only be obtained via the agent + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history.append(st.session_state.agent.memories[-1]) + + +def set_agent_state(): + logger.info("Setting agent state") + st.session_state.agent.obs_history = copy.deepcopy(st.session_state.obs_history) + st.session_state.agent.actions = copy.deepcopy(st.session_state.action_history) + st.session_state.agent.thoughts = copy.deepcopy(st.session_state.thought_history) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.agent.memories = copy.deepcopy(st.session_state.memory_history) + + +def revert_env_history(): + logger.info("Reverting env history") + st.session_state.obs_history.pop() + st.session_state.screenshot_history.pop() + st.session_state.axtree_history.pop() + + # related to env info + st.session_state.reward_history.pop() + st.session_state.terminated_history.pop() + st.session_state.truncated_history.pop() + st.session_state.env_info_history.pop() + + +def revert_agent_history(): + logger.info("Reverting agent history") + st.session_state.action_history.pop() + st.session_state.action_info_history.pop() + st.session_state.thought_history.pop() + st.session_state.prompt_history.pop() + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history.pop() + + +def revert_agent_state(): + logger.info("Reverting agent state") + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.agent.obs_history.pop() + st.session_state.agent.actions.pop() + st.session_state.agent.thoughts.pop() + st.session_state.agent.memories.pop() + elif isinstance(st.session_state.agent, ToolUseAgent): + num_groups = len(st.session_state.agent.discussion.groups) + if num_groups == 3: + # start from blank state + reset_agent_state() + elif num_groups > 3: + # get rid of the last group (last action), and remove everything from the other previous group except for the action + st.session_state.agent.discussion.groups.pop() + last_group = copy.deepcopy(st.session_state.agent.discussion.groups[-1]) + last_group.summary = None + last_group.messages = last_group.messages[:0] # remove everything from last group + st.session_state.agent.discussion.groups[-1] = last_group + st.session_state.agent._responses.pop() + st.session_state.agent.last_response = copy.deepcopy( + st.session_state.agent._responses[-1] + ) + else: + raise Exception("Invalid number of groups") + + +def restore_env_history(step: int): + logger.info(f"Restoring env history to step {step}") + st.session_state.obs_history = copy.deepcopy(st.session_state.obs_history[:step]) + st.session_state.screenshot_history = copy.deepcopy(st.session_state.screenshot_history[:step]) + st.session_state.axtree_history = copy.deepcopy(st.session_state.axtree_history[:step]) + + # related to env info + st.session_state.reward_history = copy.deepcopy(st.session_state.reward_history[:step]) + st.session_state.terminated_history = copy.deepcopy(st.session_state.terminated_history[:step]) + st.session_state.truncated_history = copy.deepcopy(st.session_state.truncated_history[:step]) + st.session_state.env_info_history = copy.deepcopy(st.session_state.env_info_history[:step]) + + +def restore_agent_history(step: int): + logger.info(f"Restoring agent history to step {step}") + st.session_state.action_history = copy.deepcopy(st.session_state.action_history[:step]) + st.session_state.action_info_history = copy.deepcopy( + st.session_state.action_info_history[:step] + ) + st.session_state.thought_history = copy.deepcopy(st.session_state.thought_history[:step]) + st.session_state.prompt_history = copy.deepcopy(st.session_state.prompt_history[:step]) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history = copy.deepcopy(st.session_state.memory_history[:step]) + + +def get_prompt(info): + if info is not None: + if hasattr(info, "chat_messages") and isinstance(info.chat_messages, Discussion): + chat_messages = info.chat_messages.messages + new_chat_messages = [] + for message in chat_messages: + if isinstance(message["content"], list): + # concatenate all text elements + new_chat_messages.append( + { + "role": message["role"], + "content": "\n\n".join( + [ + elem["text"] + for elem in message["content"] + if elem["type"] == "text" + ] + ), + } + ) + else: + new_chat_messages.append(message) + prompt = tokenizer.apply_chat_template( + new_chat_messages, add_special_tokens=True, tokenize=False + ) + return prompt + else: + prompt = "Not implemented yet for Response API" + return prompt + + +def setup_sidebar(): + with st.sidebar: + st.markdown(SIDEBAR_TEXT) + + +def set_session_state(): + + # args used to instantiate agent / environment + if "has_submitted_configs" not in st.session_state: + st.session_state.has_submitted_configs = False + if "agent_args" not in st.session_state: + st.session_state.agent_args = None + if "benchmark" not in st.session_state: + st.session_state.benchmark = None + if "task" not in st.session_state: + st.session_state.task = None + if "subtask" not in st.session_state: + st.session_state.subtask = None + if "env_args" not in st.session_state: + st.session_state.env_args = None + + # current state + if "agent" not in st.session_state: + st.session_state.agent = None + if "action" not in st.session_state: + st.session_state.action = None + if "action_info" not in st.session_state: + st.session_state.action_info = None + if "last_obs" not in st.session_state: + st.session_state.last_obs = None + + # track history + if "prompt_history" not in st.session_state: + st.session_state.prompt_history = [] + if "screenshot_history" not in st.session_state: + st.session_state.screenshot_history = [] + if "axtree_history" not in st.session_state: + st.session_state.axtree_history = [] + if "thought_history" not in st.session_state: + st.session_state.thought_history = [] + if "memory_history" not in st.session_state: + st.session_state.memory_history = [] + if "action_history" not in st.session_state: + st.session_state.action_history = [] + if "action_info_history" not in st.session_state: + st.session_state.action_info_history = [] + if "obs_history" not in st.session_state: + st.session_state.obs_history = [] + if "reward_history" not in st.session_state: + st.session_state.reward_history = [] + if "terminated_history" not in st.session_state: + st.session_state.terminated_history = [] + if "truncated_history" not in st.session_state: + st.session_state.truncated_history = [] + if "env_info_history" not in st.session_state: + st.session_state.env_info_history = [] + + if "task_to_benchmark_mapping" not in st.session_state: + st.session_state.task_to_benchmark_mapping = {} + for benchmark in list(DEFAULT_BENCHMARKS.keys()): + all_tasks = set( + [elem.task_name for elem in DEFAULT_BENCHMARKS[benchmark]().env_args_list] + ) + for task in all_tasks: + st.session_state.task_to_benchmark_mapping[task] = benchmark + + if "has_clicked_prev" not in st.session_state: + st.session_state.has_clicked_prev = False + if "has_clicked_next" not in st.session_state: + st.session_state.has_clicked_next = False + if "has_clicked_multiple_reprompt" not in st.session_state: + st.session_state.has_clicked_multiple_reprompt = False + + +def select_agent_type(): + """Dropdown to select an agent type.""" + agent_type = st.selectbox("Select Agent Type", ["GenericAgent", "ToolUseAgent"], index=0) + return agent_type + + +def select_agent(agent_type: str = "GenericAgent"): + """Dropdown to select an agent.""" + if agent_type == "GenericAgent": + agent_choices = ALL_GENERIC_AGENTS + default_agent = "AGENT_AZURE_4o" + agent_str = st.selectbox( + "Select Agent", agent_choices, index=agent_choices.index(default_agent) + ) + agents_module = importlib.import_module("agentlab.agents.generic_agent") + agent = getattr(agents_module, agent_str) + elif agent_type == "ToolUseAgent": + agent_choices = ALL_TOOL_USE_AGENTS + default_agent = "AZURE_GPT_4_1" + agent_str = st.selectbox( + "Select Agent", agent_choices, index=agent_choices.index(default_agent) + ) + agents_module = importlib.import_module("agentlab.agents.tool_use_agent.tool_use_agent") + model_args = getattr(agents_module, agent_str) + agent = ToolUseAgentArgs( + model_args=model_args, + config=copy.deepcopy(DEFAULT_PROMPT_CONFIG), + ) + else: + st.error("Invalid agent type") + return agent + + +def select_benchmark() -> str: + """Dropdown to select a benchmark.""" + all_benchmarks = list(DEFAULT_BENCHMARKS.keys()) + benchmark_str = st.selectbox( + "Select Benchmark", all_benchmarks, index=all_benchmarks.index(DEFAULT_BENCHMARK) + ) + return benchmark_str + + +def select_task(benchmark): + """Dropdown to select a task based on the benchmark.""" + all_tasks = sorted(list(set([elem.task_name for elem in benchmark.env_args_list]))) + task_str = st.selectbox("Select Task", all_tasks) + return task_str + + +def select_subtask(benchmark, task_str) -> str: + """Dropdown to select a subtask based on the task name.""" + all_subtasks = sorted( + [str(elem.task_seed) for elem in benchmark.env_args_list if elem.task_name == task_str] + ) + subtask_str = st.selectbox("Select Subtask", all_subtasks) + return subtask_str + + +def set_task_selector(): + """Create task selector form. Allows the user to select the agent, benchmark, task, and subtask to run.""" + with st.container(border=True): + st.markdown("##### ⚙️ Select") + with st.form("Task Selector"): + col1, col2, col3, col4, col5, col6, col7 = st.columns( + [2, 2, 2, 3, 1, 1, 1], vertical_alignment="bottom" + ) + with col1: + selected_agent_type = select_agent_type() + with col2: + selected_agent_args = select_agent(selected_agent_type) + with col3: + selected_benchmark_str = select_benchmark() + selected_benchmark = DEFAULT_BENCHMARKS[selected_benchmark_str]() + with col4: + selected_task_str = select_task(selected_benchmark) + with col5: + selected_subtask_str = select_subtask(selected_benchmark, selected_task_str) + with col6: + if st.form_submit_button("🔄", use_container_width=True): + clean_session() + with col7: + if st.form_submit_button("▶️", use_container_width=True): + + # saving configs related to agent and task + st.session_state.has_submitted_configs = True + st.session_state.agent_args = selected_agent_args + st.session_state.benchmark = selected_benchmark_str + st.session_state.task = selected_task_str + st.session_state.subtask = selected_subtask_str + + st.session_state.env_args = [ + elem + for elem in selected_benchmark.env_args_list + if elem.task_name == selected_task_str + and str(elem.task_seed) == str(selected_subtask_str) + ][0] + + reset_env_history() + reset_agent_history() + + prepare_agent() + set_environment_info() + prepare_benchmark() + reset_environment() + # alternatively, one can load a file from disk to load a previous session + with st.expander(label="Load a previous run", expanded=False): + with st.form("Load Previous Run"): + col1, col2 = st.columns( + (11, 1), + vertical_alignment="top", + border=False, + ) + with col1: + exp_files = st.file_uploader( + "Select all files from a previous run directory", + accept_multiple_files=True, + label_visibility="collapsed", + ) + with col2: + if st.form_submit_button( + "⬆️", + use_container_width=True, + ): + if exp_files: + with st.spinner("Loading session..."): + load_session(exp_files) + + +def load_session(exp_files): + logger.info(f"Loading session...") + start = datetime.now() + + # load env and agent args + exp_args_files = [file for file in exp_files if file.name == "exp_args.pkl"] + if len(exp_args_files) == 0: + st.error("No exp_args.pkl file found in the selected directory.") + return + exp_args = exp_args_files[0].getvalue() + exp_args = pickle.loads(exp_args) + st.session_state.agent_args = exp_args.agent_args + st.session_state.env_args = exp_args.env_args + st.session_state.benchmark = st.session_state.task_to_benchmark_mapping[ + exp_args.env_args.task_name + ] + st.session_state.task = exp_args.env_args.task_name + st.session_state.subtask = exp_args.env_args.task_seed + + # load state from step files + screenshot_file_names = [ + file.name for file in exp_files if file.name.startswith("screenshot_step_") + ] + step_files = [file for file in exp_files if file.name.startswith("step_")] + if len(step_files) == 0: + st.error("No step files found in the selected directory.") + return + # sort step files + step_files.sort(key=lambda x: int(x.name.split("_")[-1].split(".")[0])) + # only keep step files for which we have an associated `screenshot_step_n.png` + step_files = [ + file + for file in step_files + if f"screenshot_{file.name.split('.')[0]}.png" in screenshot_file_names + ] + for file in step_files: + with gzip.open(file, "rb") as f: + step_info = pickle.load(f) + st.session_state.action_history.append(step_info.action) + st.session_state.action_info_history.append(step_info.agent_info) + st.session_state.thought_history.append(step_info.agent_info.get("think", None)) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history.append(step_info.agent_info.get("memory", None)) + st.session_state.prompt_history.append(get_prompt(step_info.agent_info)) + elif isinstance(st.session_state.agent, ToolUseAgent): + st.session_state.prompt_history.append( + "\n".join( + [elem.to_markdown() for elem in st.session_state.agent.discussion.flatten()] + ) + ) + else: + raise ValueError(f"Unknown agent type: {type(st.session_state.agent)}") + st.session_state.obs_history.append(step_info.obs) + st.session_state.reward_history.append(step_info.reward) + st.session_state.terminated_history.append(step_info.terminated) + st.session_state.truncated_history.append(step_info.truncated) + st.session_state.env_info_history.append( + {"task_info": step_info.task_info, "RAW_REWARD_GLOBAL": step_info.raw_reward} + ) + st.session_state.last_obs = st.session_state.obs_history[-1] + + # set environment in right state + prepare_agent() + reset_env_history() + set_environment_info() + prepare_benchmark() + reset_environment() + restore_environment() + end = datetime.now() + logger.info(f"Done in {end - start}") + st.rerun() + + +def clean_session(): + logger.info("Cleaning session...") + start = datetime.now() + requests.post(f"{SERVER_URL}/unset_info") + requests.post(f"{SERVER_URL}/close") + for key in list(st.session_state.keys()): + del st.session_state[key] + end = datetime.now() + logger.info(f"Done in {end - start}") + st.rerun() + + +def prepare_agent(): + st.session_state.agent_args.prepare() + st.session_state.agent = st.session_state.agent_args.make_agent() + if isinstance(st.session_state.agent, ToolUseAgent): + st.session_state.agent.set_task_name(st.session_state.task) + + +def set_environment_info(): + action_mapping_fn = get_import_path(st.session_state.agent.action_set.to_python_code) + payload = { + "benchmark_name": st.session_state.benchmark, + "task_name": st.session_state.task, + "seed": st.session_state.subtask, + "action_mapping_fn": action_mapping_fn, + "exp_dir": str(RESULTS_DIR), + } + resp = requests.post(f"{SERVER_URL}/set_info", json=payload) + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: + st.error(resp.json()) + + +def prepare_benchmark(): + logger.info("Preparing benchmark...") + start = datetime.now() + resp = requests.post(f"{SERVER_URL}/prepare_benchmark") + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: + st.error(resp.json()) + end = datetime.now() + logger.info(f"Done in {end - start}") + + +def reset_environment(): + logger.info("Restarting environment...") + start = datetime.now() + resp = requests.post(f"{SERVER_URL}/reset") + end = datetime.now() + logger.info(f"Done request in {end - start}") + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: + logger.error(resp.status_code) + logger.error(resp.json()[Constants.STATUS]) + logger.error(resp.json()[Constants.MESSAGE]) + response_json = resp.json() + response_json = deserialize_response(response_json) + obs = response_json[Constants.OBS] + if st.session_state.agent.obs_preprocessor: + obs = st.session_state.agent.obs_preprocessor(obs) + step_env_history(obs, response_json) + st.session_state.action = None + st.session_state.action_info = None + + +def reload_task(): + logger.info("Reloading task...") + start = datetime.now() + resp = requests.post(f"{SERVER_URL}/reload_task") + end = datetime.now() + logger.info(f"Done request in {end - start}") + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: + logger.error(resp.status_code) + logger.error(resp.json()[Constants.STATUS]) + logger.error(resp.json()[Constants.MESSAGE]) + response_json = resp.json() + response_json = deserialize_response(response_json) + obs = response_json[Constants.OBS] + if st.session_state.agent.obs_preprocessor: + obs = st.session_state.agent.obs_preprocessor(obs) + step_env_history(obs, response_json) + st.session_state.action = None + st.session_state.action_info = None + + +def step_environment(action): + logger.info("Stepping environment...") + start = datetime.now() + payload = {"action": action} + resp = requests.post(f"{SERVER_URL}/step", json=payload) + end = datetime.now() + logger.info(f"Done request in {end - start}") + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: + logger.error(resp.status_code) + logger.error(resp.json()[Constants.STATUS]) + logger.error(resp.json()[Constants.MESSAGE]) + response_json = resp.json() + response_json = deserialize_response(response_json) + obs = response_json[Constants.OBS] + if st.session_state.agent.obs_preprocessor: + obs = st.session_state.agent.obs_preprocessor(obs) + step_env_history(obs, response_json) + st.session_state.action = None + st.session_state.action_info = None + + +def restore_environment(): + reload_task() + for action in st.session_state.action_history[:-1]: + step_environment(action) + st.session_state.action = st.session_state.action_history[-1] + st.session_state.action_info = st.session_state.action_info_history[-1] + set_agent_state() + + +def get_action(): + logger.info("Getting action...") + start = datetime.now() + action, info = st.session_state.agent.get_action(copy.deepcopy(st.session_state.last_obs)) + step_agent_history(action, info) + end = datetime.now() + logger.info(f"Done in {end - start}") + + +def set_agent_state_box(): + + # Custom CSS to set textarea style same as code block + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + + # set agent state and goal box + with st.container(): + col1, col2, col3 = st.columns([1, 1, 1]) + with col1: + with st.container(border=True, height=250): + st.markdown("**Goal**") + st.code( + st.session_state.last_obs["goal"], + wrap_lines=True, + language=None, + height=175, + ) + with col2: + with st.container(border=True, height=250): + st.markdown("**Think**") + initial_think = copy.deepcopy(st.session_state.action_info.think) + st.session_state.action_info.think = st.text_area( + "Think", + st.session_state.action_info.think, + height=172, + label_visibility="collapsed", + ) + if st.session_state.action_info.think != initial_think: + # if thought has been updated, update thought history + st.session_state.thought_history[-1] = copy.deepcopy( + st.session_state.action_info.think + ) + st.session_state.agent.thoughts[-1] = copy.deepcopy( + st.session_state.action_info.think + ) + with col3: + with st.container(border=True, height=250): + st.markdown("**Action**") + initial_action = copy.deepcopy(st.session_state.action) + st.session_state.action = st.text_area( + "Action", st.session_state.action, height=172, label_visibility="collapsed" + ) + if st.session_state.action != initial_action: + # if action has been updated, update action history + st.session_state.action_history[-1] = copy.deepcopy(st.session_state.action) + st.session_state.agent.actions[-1] = copy.deepcopy(st.session_state.action) + + +def set_prompt_modifier(): + with st.expander("**Prompt Modifier**", expanded=False): + if isinstance(st.session_state.agent, GenericAgent): + st.markdown("**Observation Flags**") + col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) + with col1: + st.session_state.agent.flags.obs.use_html = st.checkbox( + "use_html", value=st.session_state.agent.flags.obs.use_html + ) + st.session_state.agent.flags.obs.use_action_history = st.checkbox( + "use_action_history", value=st.session_state.agent.flags.obs.use_action_history + ) + with col2: + st.session_state.agent.flags.obs.use_ax_tree = st.checkbox( + "use_ax_tree", value=st.session_state.agent.flags.obs.use_ax_tree + ) + st.session_state.agent.flags.obs.use_think_history = st.checkbox( + "use_think_history", value=st.session_state.agent.flags.obs.use_think_history + ) + with col3: + st.session_state.agent.flags.obs.use_focused_element = st.checkbox( + "use_focused_element", + value=st.session_state.agent.flags.obs.use_focused_element, + ) + st.session_state.agent.flags.obs.use_diff = st.checkbox( + "use_diff", value=st.session_state.agent.flags.obs.use_diff + ) + with col4: + st.session_state.agent.flags.obs.use_error_logs = st.checkbox( + "use_error_logs", value=st.session_state.agent.flags.obs.use_error_logs + ) + st.session_state.agent.flags.obs.use_screenshot = st.checkbox( + "use_screenshot", value=st.session_state.agent.flags.obs.use_screenshot + ) + with col5: + st.session_state.agent.flags.obs.use_history = st.checkbox( + "use_history", value=st.session_state.agent.flags.obs.use_history + ) + st.session_state.agent.flags.obs.use_som = st.checkbox( + "use_som", value=st.session_state.agent.flags.obs.use_som + ) + with col6: + st.session_state.agent.flags.obs.use_past_error_logs = st.checkbox( + "use_past_error_logs", + value=st.session_state.agent.flags.obs.use_past_error_logs, + ) + st.session_state.agent.flags.obs.use_tabs = st.checkbox( + "use_tabs", value=st.session_state.agent.flags.obs.use_tabs + ) + st.markdown("**Other Flags**") + col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) + with col1: + st.session_state.agent.flags.use_plan = st.checkbox( + "use_plan", value=st.session_state.agent.flags.use_plan + ) + st.session_state.agent.flags.use_hints = st.checkbox( + "use_hints", value=st.session_state.agent.flags.use_hints + ) + with col2: + st.session_state.agent.flags.use_criticise = st.checkbox( + "use_criticise", value=st.session_state.agent.flags.use_criticise + ) + st.session_state.agent.flags.be_cautious = st.checkbox( + "be_cautious", value=st.session_state.agent.flags.be_cautious + ) + with col3: + st.session_state.agent.flags.use_thinking = st.checkbox( + "use_thinking", value=st.session_state.agent.flags.use_thinking + ) + st.session_state.agent.flags.enable_chat = st.checkbox( + "enable_chat", value=st.session_state.agent.flags.enable_chat + ) + with col4: + st.session_state.agent.flags.use_memory = st.checkbox( + "use_memory", value=st.session_state.agent.flags.use_memory + ) + with col5: + st.session_state.agent.flags.use_abstract_example = st.checkbox( + "use_abstract_example", value=st.session_state.agent.flags.use_abstract_example + ) + with col6: + st.session_state.agent.flags.use_concrete_example = st.checkbox( + "use_concrete_example", value=st.session_state.agent.flags.use_concrete_example + ) + extra_instructions = st.text_area( + "extra_instructions", value=st.session_state.agent.flags.extra_instructions + ) + if extra_instructions == "": + extra_instructions = None + st.session_state.agent.flags.extra_instructions = extra_instructions + elif isinstance(st.session_state.agent, ToolUseAgent): + + st.session_state.agent.config.tag_screenshot = st.checkbox( + "Tag screenshot", value=st.session_state.agent.config.tag_screenshot + ) + + # Goal flags + st.session_state.agent.config.goal.goal_as_system_msg = st.checkbox( + "Goal as system message", + value=st.session_state.agent.config.goal.goal_as_system_msg, + ) + + # Obs flags + st.session_state.agent.config.obs.use_last_error = st.checkbox( + "Use last error", value=st.session_state.agent.config.obs.use_last_error + ) + st.session_state.agent.config.obs.use_screenshot = st.checkbox( + "Use screenshot", value=st.session_state.agent.config.obs.use_screenshot + ) + st.session_state.agent.config.obs.use_axtree = st.checkbox( + "Use axtree", value=st.session_state.agent.config.obs.use_axtree + ) + st.session_state.agent.config.obs.use_dom = st.checkbox( + "Use dom", value=st.session_state.agent.config.obs.use_dom + ) + st.session_state.agent.config.obs.use_som = st.checkbox( + "Use som", value=st.session_state.agent.config.obs.use_som + ) + st.session_state.agent.config.obs.use_tabs = st.checkbox( + "Use tabs", value=st.session_state.agent.config.obs.use_tabs + ) + # st.session_state.agent.config.obs.add_mouse_pointer = st.checkbox( + # "Add mouse pointer", value=st.session_state.agent.config.obs.add_mouse_pointer + # ) + st.session_state.agent.config.obs.use_zoomed_webpage = st.checkbox( + "Use zoomed webpage", value=st.session_state.agent.config.obs.use_zoomed_webpage + ) + + # Summarizer flags + st.session_state.agent.config.summarizer.do_summary = st.checkbox( + "Do summary", value=st.session_state.agent.config.summarizer.do_summary + ) + st.session_state.agent.config.summarizer.high_details = st.checkbox( + "Summarize with high details", + value=st.session_state.agent.config.summarizer.high_details, + ) + + # General Hints flags + st.session_state.agent.config.general_hints.use_hints = st.checkbox( + "Use general hints", value=st.session_state.agent.config.general_hints.use_hints + ) + + # Task Hint flags + st.session_state.agent.config.task_hint.use_task_hint = st.checkbox( + "Use task hint", value=st.session_state.agent.config.task_hint.use_task_hint + ) + + # general + st.session_state.agent.config.keep_last_n_obs = st.number_input( + "Keep last n obs", value=st.session_state.agent.config.keep_last_n_obs + ) + st.session_state.agent.config.multiaction = st.checkbox( + "Multiaction", value=st.session_state.agent.config.multiaction + ) + # st.session_state.agent.config.action_subsets = st.text_area( + # "Action subsets", value=st.session_state.agent.config.action_subsets + # ) + + +def set_go_back_to_step_n_section(): + with st.container(border=True): + st.markdown("**Go Back to Step N**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + is_go_back_to_step_n_disabled = len(st.session_state.action_history) <= 1 + with col1: + step = st.number_input( + "Step", + value=1, + min_value=1, + max_value=len(st.session_state.action_history), + disabled=is_go_back_to_step_n_disabled, + ) + with col2: + if st.button( + "⬅️ Go Back", + help="Go back to step N", + use_container_width=True, + disabled=is_go_back_to_step_n_disabled, + ): + logger.info(f"Going back to step {step}") + reset_agent_state() + restore_agent_history(step=step) + reset_env_history() + restore_environment() + st.rerun() + + +def set_regenerate_action_n_times_section(): + with st.container(border=True): + st.markdown("**Regenerate Action N Times**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + with col1: + n = st.number_input( + "Number of Actions to Generate", + value=5, + min_value=1, + max_value=25, + ) + with col2: + st.session_state.has_clicked_multiple_reprompt = st.button( + "🔄 Regenerate", + help="Reprompt the agent K times to get a distribution of actions to take", + use_container_width=True, + ) + if st.session_state.has_clicked_multiple_reprompt: + logger.info(f"Regenerating action {n} times...") + reprompt_actions = [] + action_to_info_mapping = {} + action_to_memory_mapping = {} + progress_bar = st.progress(0, text=f"Regenerating action {n} times...") + for i in range(n): + progress_bar.progress((i + 1) / n, text=f"Regenerating action {i + 1} of {n}...") + revert_agent_history() + revert_agent_state() + get_action() + reprompt_actions.append(st.session_state.action) + action_to_info_mapping[st.session_state.action] = copy.deepcopy( + st.session_state.action_info + ) + action_to_memory_mapping[st.session_state.action] = copy.deepcopy( + st.session_state.agent.memories[-1] + ) + progress_bar.progress(1, text=f"Regenerating action {n} times...") + progress_bar.empty() + # show all unique actions found in reprompt actions along with their probability + unique_actions_counter = Counter(reprompt_actions) + unique_actions = sorted( + unique_actions_counter.items(), key=lambda x: x[1], reverse=True + ) + st.markdown("**Unique Actions**") + for action, count in unique_actions: + has_clicked_reprompted_action = st.button(f"`{action}` ({count / n * 100:.2f}%)") + if has_clicked_reprompted_action: + logger.info(f"Selected action: {action} -- stepping") + st + revert_agent_history() + revert_agent_state() + + # manually step agent state + st.session_state.agent.obs_history.append( + copy.deepcopy(st.session_state.last_obs) + ) + st.session_state.agent.actions.append(action) + st.session_state.agent.thoughts.append(action_to_info_mapping[action].think) + st.session_state.agent.memories.append(action_to_memory_mapping[action]) + + step_agent_history(action, action_to_info_mapping[action]) + # step_environment(action) + st.session_state.has_clicked_multiple_reprompt = False + st.rerun() + + +def set_act_k_times_section(): + with st.container(border=True): + st.markdown("**Go Forward N Steps**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + with col1: + n = st.number_input("Number of Steps", value=5, min_value=1, max_value=10) + with col2: + has_clicked_act = st.button( + "➡️ Go Forward", + help="Let the agent autonomously perform actions for N steps", + use_container_width=True, + ) + if has_clicked_act: + logger.info(f"Going forward {n} steps...") + progress_bar = st.progress(0, text=f"Going forward {n} steps...") + for i in range(n): + if st.session_state.action is None: # so that we don't do it for first step + get_action() + step_environment(st.session_state.action) + progress_bar.progress((i + 1) / n, text=f"Going forward {i + 1} of {n}...") + progress_bar.empty() + st.rerun() + + +def set_advanced_controller(): + with st.expander("**Advanced**", expanded=False): + col_go_back_to, col_reprompt_k, col_act_k = st.columns([1, 1, 1]) + with col_go_back_to: + set_go_back_to_step_n_section() + with col_reprompt_k: + set_regenerate_action_n_times_section() + with col_act_k: + set_act_k_times_section() + + +def set_previous_step_section(): + prev_disabled = len(st.session_state.action_history) <= 1 + if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True): + if not prev_disabled: + logger.info("Clicked previous step") + st.session_state.action = ( + None + if len(st.session_state.action_history) == 0 + else st.session_state.action_history[-1] + ) + reset_agent_state() + revert_agent_history() + reset_env_history() + restore_environment() + st.rerun() + + +def set_regenerate_action_section(): + if st.button("🔄 Regenerate Action", use_container_width=True): + logger.info("Clicked regenerate action") + revert_agent_history() + revert_agent_state() + get_action() + st.rerun() + + +def set_next_step_section(): + if st.button("➡️ Next Step", use_container_width=True): + logger.info("Clicked next step") + step_environment(st.session_state.action) + st.rerun() + + +def set_controller(): + with st.container(border=True): + st.markdown("##### 🕹️ Control") + set_agent_state_box() + set_prompt_modifier() + col_prev, col_redo, col_next = st.columns([1, 1, 1]) + with col_prev: + set_previous_step_section() + with col_redo: + set_regenerate_action_section() + with col_next: + set_next_step_section() + set_advanced_controller() + + +def get_base64_serialized_image(img_arr): + if isinstance(img_arr, list): + img_arr = np.array(img_arr) + if isinstance(img_arr, np.ndarray): + im = PIL.Image.fromarray(img_arr) + buffered = BytesIO() + im.save(buffered, format="PNG") + img_b64 = base64.b64encode(buffered.getvalue()).decode() + return img_b64 + return None + + +def display_image(img_arr): + img_b64 = get_base64_serialized_image(img_arr) + if img_b64: + st.markdown( + f'
', + unsafe_allow_html=True, + ) + + +def set_screenshot_tab(): + display_image(st.session_state.screenshot_history[-1]) + + +def set_axtree_tab(): + st.code(st.session_state.axtree_history[-1], language=None, wrap_lines=True) + + +def set_prompt_tab(): + if isinstance(st.session_state.agent, GenericAgent): + st.code(st.session_state.prompt_history[-1], language=None, wrap_lines=True) + elif isinstance(st.session_state.agent, ToolUseAgent): + st.markdown(st.session_state.prompt_history[-1]) + + st.markdown(f"## Last summary:\n{st.session_state.agent.discussion.get_last_summary()}") + else: + raise ValueError(f"Unknown agent type: {type(st.session_state.agent)}") + + +def set_previous_steps_tab(): + for i in range(len(st.session_state.action_history) - 1): + with st.expander(f"### Step {i + 1}", expanded=False): + if st.button(f"Go back to step {i + 1}"): + logger.info(f"Go back to step {i + 1}") + reset_agent_state() + restore_agent_history(step=i + 1) + reset_env_history() + restore_environment() + st.rerun() + screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) + with screenshot_tab: + display_image(st.session_state.screenshot_history[i]) + with axtree_tab: + st.code(st.session_state.axtree_history[i], language=None, wrap_lines=True) + with prompt_tab: + st.code(st.session_state.prompt_history[i], language=None, wrap_lines=True) + st.markdown("**Thought**") + st.code(st.session_state.thought_history[i], language=None, wrap_lines=True) + st.markdown("**Action**") + st.code(st.session_state.action_history[i], language=None, wrap_lines=True) + + +def set_save_tab(): + # dump full session_state to json + save_dir = st.text_input("Save Directory", value="~/Downloads") + save_dir = os.path.expanduser(save_dir) + if st.button("Save Session State for Current Run"): + # save everything from the session in a way that is consistent + # with how experiments are saved with AgentLab + + # dir name has this format: 2025-07-14_16-46-47_tooluse-gpt-4-1-on-workarena-l1-task-name-sort + exp_dir = ( + Path(save_dir) + / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_genericagent_{st.session_state.agent_args.agent_name}_on_{st.session_state.benchmark}_{st.session_state.env_args.task_name}_{st.session_state.env_args.task_name}_{st.session_state.env_args.task_seed}" + ) + exp_dir.mkdir(parents=True, exist_ok=True) + + # save package versions + save_package_versions(exp_dir) + + # create ExpArgs object + exp_args = ExpArgs( + agent_args=st.session_state.agent_args, env_args=st.session_state.env_args + ) + with open(exp_dir / "exp_args.pkl", "wb") as f: + pickle.dump(exp_args, f) + + # create StepInfo object for each step + for i in range(len(st.session_state.action_history)): + step_info = StepInfo() + step_info.step = i + step_info.obs = st.session_state.obs_history[i] + step_info.reward = st.session_state.reward_history[i] + step_info.terminated = st.session_state.terminated_history[i] + step_info.truncated = st.session_state.truncated_history[i] + step_info.action = st.session_state.action_history[i] + step_info.agent_info = st.session_state.action_info_history[i] + step_info.make_stats() + # TODO: set profiling stats + step_info.task_info = st.session_state.env_info_history[i].get("task_info", None) + step_info.raw_reward = st.session_state.env_info_history[i].get( + "RAW_REWARD_GLOBAL", None + ) + step_info.save_step_info(exp_dir, save_screenshot=True, save_som=True) + + st.success(f"Saved session state at {exp_dir}") + + +def set_info_tabs(): + with st.container(border=True): + st.markdown("##### 🔎 Analyze") + # Display only if everything is now ready + if len(st.session_state.action_history) > 1: + screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab, save_tab = st.tabs( + ["Screenshot", "AxTree", "Prompt", "Previous Steps", "Save"] + ) + else: + screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) + + with screenshot_tab: + set_screenshot_tab() + with axtree_tab: + set_axtree_tab() + with prompt_tab: + set_prompt_tab() + if len(st.session_state.action_history) > 1: + with previous_steps_tab: + set_previous_steps_tab() + with save_tab: + set_save_tab() + + +def run_streamlit(): + + # config page + st.set_page_config( + page_title="AgentLab Controller", + page_icon="🕹️", + layout="wide", + initial_sidebar_state="collapsed", + ) + st.markdown( + '

🕹️ AgentLab Controller 🕹️

', unsafe_allow_html=True + ) + + setup_sidebar() + + set_session_state() + set_task_selector() + + if st.session_state.agent is not None: + if st.session_state.action is None: + get_action() + + set_controller() + set_info_tabs() + + +def main(): + run_streamlit() + + +if __name__ == "__main__": + main() diff --git a/src/agentlab/analyze/run_agentlab_controller.py b/src/agentlab/analyze/run_agentlab_controller.py new file mode 100644 index 00000000..5b472789 --- /dev/null +++ b/src/agentlab/analyze/run_agentlab_controller.py @@ -0,0 +1,14 @@ +from streamlit.web import cli + +from pathlib import Path + +CURR_DIR = Path(__file__).parent +agent_controller_path = CURR_DIR / "agent_controller.py" + + +def main(): + cli.main_run([str(agent_controller_path), "--server.port", "8501"]) + + +if __name__ == "__main__": + main() diff --git a/src/agentlab/analyze/server.py b/src/agentlab/analyze/server.py new file mode 100644 index 00000000..4f82f4c4 --- /dev/null +++ b/src/agentlab/analyze/server.py @@ -0,0 +1,588 @@ +# server.py +import base64 +import copy +import importlib +from typing import Any + +import dotenv +import numpy as np +import uvicorn + +# Import your BrowserEnv and any task setup you need +from bgym import DEFAULT_BENCHMARKS +from fastapi import FastAPI +from pydantic import BaseModel + +dotenv.load_dotenv() + +app = FastAPI() + + +def import_from_path(path: str) -> callable: + """ + Util function to import and instantiate a class, then return a specific method. + + Args: + path (str): Path to the method, e.g., 'browsergym.core.action.highlevel.HighLevelActionSet.to_python_code'. + + Raises: + ModuleNotFoundError: If the module cannot be imported. + + Returns: + callable: The method. + """ + + parts = path.split(".") + # Find the module (the longest prefix that can be imported) + for i in range(len(parts), 0, -1): + module_name = ".".join(parts[:i]) + try: + module = importlib.import_module(module_name) + break + except ModuleNotFoundError: + continue + else: + raise ModuleNotFoundError(f"Could not import module from path: {path}") + + obj = module + for attr in parts[i:]: + obj = getattr(obj, attr) + + # If the final object is a method, and its __qualname__ contains a class, instantiate the class + if callable(obj) and hasattr(obj, "__qualname__") and "." in obj.__qualname__: + class_name = obj.__qualname__.split(".")[0] + cls = getattr(module, class_name) + instance = cls() + method = getattr(instance, obj.__name__) + return method + + return obj + + +def make_json_safe(obj: Any) -> Any: + """ + Util function to convert numpy arrays and other non-JSON-serializable objects to JSON-serializable objects. + Specifically, we convert numpy arrays to base64 encoded strings so that payloads are of reasonable size. + + Args: + obj (Any): Object to convert + + Returns: + Any: JSON-serializable object + """ + if isinstance(obj, np.ndarray): + # convert to base64 + return { + "data": base64.b64encode(obj.tobytes()).decode("utf-8"), + "shape": obj.shape, + "dtype": str(obj.dtype), + } + elif isinstance(obj, dict): + return {k: make_json_safe(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [make_json_safe(v) for v in obj] + elif hasattr(obj, "__dict__"): + return make_json_safe(vars(obj)) + else: + return obj + + +# --- Models for requests --- +class SetInfoRequest(BaseModel): + benchmark_name: str + task_name: str + seed: int + action_mapping_fn: str + exp_dir: str + + +class StepRequest(BaseModel): + action: str + + +# --- Persistent Environment State --- +class EnvWrapper: + def __init__(self): + + # env info + self.benchmark_name = None + self.task_name = None + self.seed = None + self.action_mapping_fn = None + self.exp_dir = None + self.info_set = False + + # env state + self.env = None + self.last_obs = None + self.last_info = None + + # used to reload task + self.start_info = None + self.start_url = None + + def set_info( + self, + benchmark_name: str, + task_name: str, + seed: int, + action_mapping_fn: str, + exp_dir: str, + ) -> dict: + """ + Set the environment info. + + Args: + benchmark_name (str): Name of the benchmark + task_name (str): Name of the task + seed (int): Seed of the task + action_mapping_fn (str): Action mapping function + exp_dir (str): Directory for experiment + + Returns: + dict: Dictionary with status + """ + if self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info already set. Please unset the environment info first.", + } + ) + if self.env is not None: + return make_json_safe( + { + "status": "error", + "message": "Environment already created. Close the current environment proceeding.", + } + ) + self.benchmark_name = benchmark_name + self.task_name = task_name + self.seed = seed + self.action_mapping_fn = action_mapping_fn + self.exp_dir = exp_dir + self.info_set = True + + return make_json_safe( + { + "status": "success", + "message": "Environment info set successfully.", + } + ) + + def get_info(self) -> dict: + """ + Get the environment info. + + Returns: + dict: Dictionary with info + """ + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + return make_json_safe( + { + "status": "success", + "message": "Environment info retrieved successfully.", + "benchmark_name": self.benchmark_name, + "task_name": self.task_name, + "seed": self.seed, + "action_mapping_fn": self.action_mapping_fn, + "exp_dir": self.exp_dir, + } + ) + + def unset_info(self) -> dict: + """ + Unset the environment info. + + Returns: + dict: Dictionary with status + """ + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + self.info_set = False + self.benchmark_name = None + self.task_name = None + self.seed = None + self.action_mapping_fn = None + self.exp_dir = None + return make_json_safe( + { + "status": "success", + "message": "Environment info unset successfully.", + } + ) + + def status(self) -> dict: + """ + Get the environment status. + + Returns: + dict: Dictionary with status + """ + return make_json_safe( + { + "status": "success", + "message": "Environment status retrieved successfully.", + "obs": self.last_obs, + "reward": 0, + "terminated": False, + "truncated": False, + "info_set": self.info_set, + "env_created": self.env is not None, + } + ) + + def prepare_benchmark(self) -> dict: + """ + Prepare the benchmark environment. + + Returns: + dict: Dictionary with status + """ + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + + if self.env is not None: + # close the current environment first + self.env.close() + self.env = None + + # prepare backends + benchmark = DEFAULT_BENCHMARKS[self.benchmark_name]() + benchmark.env_args_list = [ + elem + for elem in benchmark.env_args_list + if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed) + ] + benchmark.prepare_backends() + + env_args = benchmark.env_args_list[0] + self.action_mapping = import_from_path(self.action_mapping_fn) + + # create environment + self.env = env_args.make_env(self.action_mapping, self.exp_dir) + return make_json_safe( + { + "status": "success", + "message": "Environment prepared successfully.", + } + ) + + def reload_task(self) -> dict: + """ + Reload the task. + + Returns: + dict: Dictionary with status + """ + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + elif not self.env: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + + # instead of resetting the whole environment, we go back to the original webpage and clear localStorage and sessionStorage + # NOTE: this is not guaranteed to result in the exact same state, but we find that it works most of the time, is much + # faster than resetting the whole environment, and ensures the seed of the environment remains the same + self.env.unwrapped.page.goto(self.start_url, wait_until="load") + self.env.unwrapped.page.evaluate( + "window.localStorage.clear(); window.sessionStorage.clear();" + ) + obs = self.env.unwrapped._get_obs() + + self.last_obs = copy.deepcopy(obs) + self.last_info = copy.deepcopy(self.start_info) + return make_json_safe( + { + "status": "success", + "message": "Task reloaded successfully.", + "obs": self.last_obs, + "reward": 0, + "terminated": False, + "truncated": False, + "info": self.last_info, + } + ) + + def reset(self) -> dict: + """ + Reset the environment. + + Returns: + dict: Dictionary with obs and info + """ + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + elif not self.env: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + + # reset the environment + obs, info = self.env.reset(seed=self.seed) + + self.last_obs = copy.deepcopy(obs) + self.last_info = copy.deepcopy(info) + self.start_info = copy.deepcopy(info) + self.start_url = copy.deepcopy(self.env.unwrapped.page.url) + return make_json_safe( + { + "status": "success", + "message": "Environment reset successfully", + "obs": self.last_obs, + "reward": 0, + "terminated": False, + "truncated": False, + "info": self.last_info, + } + ) + + def step(self, action: str) -> dict: + """ + Step the environment. + + Args: + action (str): Action to take + + Returns: + dict: Dictionary with obs, reward, terminated, truncated and info + """ + if self.env is None: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + # step the environment + obs, reward, terminated, truncated, info = self.env.step(action) + + self.last_obs = copy.deepcopy(obs) + self.last_info = copy.deepcopy(info) + return make_json_safe( + { + "status": "success", + "message": "Environment stepped successfully.", + "obs": obs, + "reward": reward, + "terminated": terminated, + "truncated": truncated, + "info": info, + } + ) + + def get_obs(self) -> dict: + """ + Get the last observation. + + Returns: + dict: Dictionary with obs and info + """ + if self.env is None: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + return make_json_safe( + { + "status": "success", + "message": "Observation retrieved successfully.", + "obs": self.last_obs, + "reward": 0, + "terminated": False, + "truncated": False, + "info": self.last_info, + } + ) + + def close(self) -> dict: + """ + Close the environment. + + Returns: + dict: Dictionary with status + """ + if self.env is None: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + self.env.close() + self.env = None + return make_json_safe( + { + "status": "success", + "message": "Environment closed successfully.", + } + ) + + +env = EnvWrapper() + + +# --- FastAPI endpoints --- +@app.post("/set_info") +def set_info(req: SetInfoRequest) -> dict: + """ + Set the environment info. + + Args: + req (SetInfoRequest): Request containing environment info + + Returns: + dict: Dictionary with status + """ + return env.set_info( + benchmark_name=req.benchmark_name, + task_name=req.task_name, + seed=req.seed, + action_mapping_fn=req.action_mapping_fn, + exp_dir=req.exp_dir, + ) + + +@app.get("/get_info") +def get_info() -> dict: + """ + Get the environment info. + + Returns: + dict: Dictionary with info + """ + return env.get_info() + + +@app.post("/unset_info") +def unset_info() -> dict: + """ + Unset the environment info. + + Returns: + dict: Dictionary with status + """ + return env.unset_info() + + +@app.get("/status") +def status() -> dict: + """ + Get the status of the environment. + + Returns: + dict: Dictionary with status + """ + return env.status() + + +@app.post("/prepare_benchmark") +def prepare_benchmark() -> dict: + """ + Prepare the benchmark. + + Returns: + dict: Dictionary with status + """ + return env.prepare_benchmark() + + +@app.post("/reload_task") +def reload_task() -> dict: + """ + Reload the task. + + Returns: + dict: Dictionary with status + """ + return env.reload_task() + + +@app.post("/reset") +def reset() -> dict: + """ + Reset the environment. + + Returns: + dict: Dictionary with status + """ + return env.reset() + + +@app.post("/step") +def step(req: StepRequest) -> dict: + """ + Step the environment. + + Args: + req (StepRequest): Request containing action + + Returns: + dict: Dictionary with obs, reward, terminated, truncated and info + """ + return env.step(action=req.action) + + +@app.get("/get_obs") +def get_obs() -> dict: + """ + Get the last observation. + + Returns: + dict: Dictionary with obs and info + """ + return env.get_obs() + + +@app.post("/close") +def close() -> dict: + """ + Close the environment. + + Returns: + dict: Dictionary with status + """ + return env.close() + + +def main(): + uvicorn.run("agentlab.analyze.server:app", host="127.0.0.1", port=8000, reload=True) + + +if __name__ == "__main__": + main() diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 12f1dd27..52ecbbe3 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -242,4 +242,33 @@ max_new_tokens=64_000, temperature=1e-1, ), + ### Azure + "azure/gpt-4o-mini": AzureModelArgs( + model_name="gpt-4o-mini", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), + "azure/gpt-4o": AzureModelArgs( + model_name="gpt-4o", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), + "azure/gpt-4.1": AzureModelArgs( + model_name="gpt-4.1", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), + "azure/gpt-4.1-mini": AzureModelArgs( + model_name="gpt-4.1-mini", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), } diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py index 1bbeeebc..0d998d7c 100644 --- a/src/agentlab/llm/response_api.py +++ b/src/agentlab/llm/response_api.py @@ -4,20 +4,17 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Dict, List, Literal, Optional, Union +from urllib.parse import urljoin import openai +from agentlab.llm.llm_utils import image_to_png_base64_url from anthropic import Anthropic from anthropic.types import Completion from anthropic.types import Message as AnthrophicMessage from openai import OpenAI -from agentlab.llm.llm_utils import image_to_png_base64_url - from .base_api import BaseModelArgs -from .llm_utils import ( - call_anthropic_api_with_retries, - call_openai_api_with_retries, -) +from .llm_utils import call_anthropic_api_with_retries, call_openai_api_with_retries from .tracking import TrackAPIPricingMixin """This module contains utlity classes for building input messages and interacting with LLM APIs. @@ -588,6 +585,32 @@ def _extract_env_actions_from_text_response( pass +class AzureOpenAIResponseModel(OpenAIResponseModel): + def __init__( + self, + model_name: str, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + temperature: float | None = None, + max_tokens: int | None = 100, + ): + api_key = os.getenv("AZURE_OPENAI_API_KEY") + base_url = urljoin(os.getenv("AZURE_OPENAI_ENDPOINT"), "openai/v1") + self.action_space_as_tools = True # this should be a config + super().__init__( # This is passed to BaseModel + model_name=model_name, api_key=api_key, temperature=temperature, max_tokens=max_tokens + ) + client_args = {} + if base_url is not None: + client_args["base_url"] = base_url + if api_key is not None: + client_args["api_key"] = api_key + client_args["default_query"] = {"api-version": "preview"} + self.client = OpenAI(**client_args) + # Init pricing tracker after super() so that all attributes have been set. + self.init_pricing_tracker(pricing_api="openai") # Use the PricingMixin + + class OpenAIChatCompletionModel(BaseModelWithPricing): def __init__( self, @@ -920,6 +943,21 @@ def get_message_builder(self) -> MessageBuilder: return OpenAIResponseAPIMessageBuilder +@dataclass +class AzureOpenAIResponseModelArgs(OpenAIResponseModelArgs): + """Serializable object for instantiating a generic chat model with an Azure OpenAI + model.""" + + api = "openai" + + def make_model(self, extra_kwargs=None, **kwargs): + return AzureOpenAIResponseModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + ) + + @dataclass class ClaudeResponseModelArgs(BaseModelArgs): """Serializable object for instantiating a generic chat model with an OpenAI