diff --git a/browsergym/core/src/browsergym/core/action/base.py b/browsergym/core/src/browsergym/core/action/base.py index 6f06303b..06a866f8 100644 --- a/browsergym/core/src/browsergym/core/action/base.py +++ b/browsergym/core/src/browsergym/core/action/base.py @@ -38,6 +38,7 @@ def execute_python_code( code: str, page: playwright.sync_api.Page, send_message_to_user: callable, + add_observation: callable, report_infeasible_instructions: callable, ): """ @@ -56,6 +57,7 @@ def execute_python_code( globals = { "page": page, "send_message_to_user": send_message_to_user, + "add_observation": add_observation, "report_infeasible_instructions": report_infeasible_instructions, "DEMO_MODE": get_global_demo_mode(), } diff --git a/browsergym/core/src/browsergym/core/action/functions.py b/browsergym/core/src/browsergym/core/action/functions.py index bb31db9a..0107ed30 100644 --- a/browsergym/core/src/browsergym/core/action/functions.py +++ b/browsergym/core/src/browsergym/core/action/functions.py @@ -14,6 +14,7 @@ page: playwright.sync_api.Page = None send_message_to_user: callable = None +add_observation: callable = None report_infeasible_instructions: callable = None demo_mode: Literal["off", "default", "all_blue", "only_visible_elements"] = None retry_with_force: bool = False @@ -33,6 +34,8 @@ def send_msg_to_user(text: str): """ send_message_to_user(text) +def add_observation(obs: dict): + pass def report_infeasible(reason: str): """ @@ -622,3 +625,26 @@ def mouse_upload_file(x: float, y: float, file: str | list[str]): file_chooser = fc_info.value file_chooser.set_files(file) + +def get_element_html(bid: str): + """ + Returns the HTML of an element identified by its bid. + + Examples: + get_element_html('123') + """ + elem = get_elem_by_bid(page, bid, demo_mode != "off") + if elem: + outer_html_content = elem.evaluate('elem => elem.outerHTML') + #send_msg_to_user(f"The HTML of the element with bid {bid} is:\n--- START ---\n" + outer_html_content + "\n--- END ---\n") + add_observation({ + "type": "generic", + "html": outer_html_content, + "bid": bid, + }) + else: + #send_msg_to_user("The element with bid " + bid + " does not exist") + add_observation({ + "type": "generic", + "error": f"The element with bid {bid} doesn't exist" + }) diff --git a/browsergym/core/src/browsergym/core/action/highlevel.py b/browsergym/core/src/browsergym/core/action/highlevel.py index da2c539c..b1eed5ea 100644 --- a/browsergym/core/src/browsergym/core/action/highlevel.py +++ b/browsergym/core/src/browsergym/core/action/highlevel.py @@ -38,6 +38,7 @@ tab_close, tab_focus, upload_file, + get_element_html, ) from .parsers import action_docstring_parser, highlevel_action_parser @@ -59,6 +60,7 @@ clear, drag_and_drop, upload_file, + get_element_html, ], "coord": [ scroll, diff --git a/browsergym/core/src/browsergym/core/env.py b/browsergym/core/src/browsergym/core/env.py index 30b565ba..dde9ac3b 100644 --- a/browsergym/core/src/browsergym/core/env.py +++ b/browsergym/core/src/browsergym/core/env.py @@ -76,6 +76,7 @@ def __init__( pw_context_kwargs: dict = {}, # agent-related arguments action_mapping: Optional[callable] = HighLevelActionSet().to_python_code, + obs: dict = None, ): """ Instantiate a ready to use BrowserEnv gym environment. @@ -384,6 +385,11 @@ def send_message_to_user(text: str): raise ValueError(f"Forbidden value: {text} is not a string") self.chat.add_message(role="assistant", msg=text) + def add_observation(obs: dict): + if not isinstance(obs, dict): + raise ValueError(f"Forbidden value: {obj} is not a dict") + self.obs = obs + def report_infeasible_instructions(reason: str): if not isinstance(reason, str): raise ValueError(f"Forbidden value: {reason} is not a string") @@ -392,6 +398,7 @@ def report_infeasible_instructions(reason: str): # try to execute the action logger.debug(f"Executing action") + self.obs = None try: if self.action_mapping: code = self.action_mapping(action) @@ -401,6 +408,7 @@ def report_infeasible_instructions(reason: str): code, self.page, send_message_to_user=send_message_to_user, + add_observation=add_observation, report_infeasible_instructions=report_infeasible_instructions, ) self.last_action_error = "" @@ -438,8 +446,11 @@ def report_infeasible_instructions(reason: str): if user_message: self.chat.add_message(role="user", msg=user_message) - # extract observation (generic) - obs = self._get_obs() + if self.obs: + obs = self.obs + else: + # extract observation (generic) + obs = self._get_obs() logger.debug(f"Observation extracted") # new step API wants a 5-tuple (gymnasium)