diff --git a/owl/utils/enhanced_role_playing.py b/owl/utils/enhanced_role_playing.py
index 0c4754cee..2c7202d57 100644
--- a/owl/utils/enhanced_role_playing.py
+++ b/owl/utils/enhanced_role_playing.py
@@ -13,6 +13,7 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Dict, List, Optional, Tuple
+import threading
from camel.agents import ChatAgent
@@ -38,6 +39,8 @@ def __init__(self, **kwargs):
self.assistant_agent_kwargs: dict = kwargs.get("assistant_agent_kwargs", {})
self.output_language = kwargs.get("output_language", None)
+
+ self.stop_event = kwargs.get("stop_event", None)
super().__init__(**kwargs)
@@ -62,6 +65,7 @@ def __init__(self, **kwargs):
user_agent_kwargs=self.user_agent_kwargs,
output_language=self.output_language,
# is_reasoning_task=self.is_reasoning_task
+ stop_event=self.stop_event
)
def _init_agents(
@@ -72,6 +76,7 @@ def _init_agents(
user_agent_kwargs: Optional[Dict] = None,
output_language: Optional[str] = None,
is_reasoning_task: bool = False,
+ stop_event: Optional[threading.Event] = None,
) -> None:
r"""Initialize assistant and user agents with their system messages.
@@ -86,6 +91,9 @@ def _init_agents(
pass to the user agent. (default: :obj:`None`)
output_language (str, optional): The language to be output by the
agents. (default: :obj:`None`)
+ stop_event (Optional[threading.Event], optional): Event to signal
+ termination of the agent's operation. When set, the agent will
+ terminate its execution. (default: :obj:`None`)
"""
if self.model is not None:
if assistant_agent_kwargs is None:
@@ -107,6 +115,7 @@ def _init_agents(
self.assistant_agent = ChatAgent(
init_assistant_sys_msg,
output_language=output_language,
+ stop_event=stop_event,
**(assistant_agent_kwargs or {}),
)
self.assistant_sys_msg = self.assistant_agent.system_message
@@ -114,6 +123,7 @@ def _init_agents(
self.user_agent = ChatAgent(
init_user_sys_msg,
output_language=output_language,
+ stop_event=stop_event,
**(user_agent_kwargs or {}),
)
self.user_sys_msg = self.user_agent.system_message
@@ -217,12 +227,8 @@ def step(
user_response = self.user_agent.step(assistant_msg)
if user_response.terminated or user_response.msgs is None:
return (
- ChatAgentResponse(msgs=[], terminated=False, info={}),
- ChatAgentResponse(
- msgs=[],
- terminated=user_response.terminated,
- info=user_response.info,
- ),
+ ChatAgentResponse(msgs=[assistant_msg], terminated=False, info={}),
+ user_response
)
user_msg = self._reduce_message_options(user_response.msgs)
@@ -247,13 +253,9 @@ def step(
assistant_response = self.assistant_agent.step(modified_user_msg)
if assistant_response.terminated or assistant_response.msgs is None:
return (
+ assistant_response,
ChatAgentResponse(
- msgs=[],
- terminated=assistant_response.terminated,
- info=assistant_response.info,
- ),
- ChatAgentResponse(
- msgs=[user_msg], terminated=False, info=user_response.info
+ msgs=[modified_user_msg], terminated=False, info=user_response.info
),
)
assistant_msg = self._reduce_message_options(assistant_response.msgs)
@@ -436,10 +438,10 @@ def step(
),
)
-
def run_society(
society: OwlRolePlaying,
round_limit: int = 15,
+ stop_event: Optional[threading.Event] = None
) -> Tuple[str, List[dict], dict]:
overall_completion_token_count = 0
overall_prompt_token_count = 0
@@ -448,58 +450,83 @@ def run_society(
init_prompt = """
Now please give me instructions to solve over overall task step by step. If the task requires some specific knowledge, please instruct me to use tools to complete the task.
"""
- input_msg = society.init_chat(init_prompt)
- for _round in range(round_limit):
- assistant_response, user_response = society.step(input_msg)
- # Check if usage info is available before accessing it
- if assistant_response.info.get("usage") and user_response.info.get("usage"):
- overall_completion_token_count += assistant_response.info["usage"].get(
- "completion_tokens", 0
- ) + user_response.info["usage"].get("completion_tokens", 0)
- overall_prompt_token_count += assistant_response.info["usage"].get(
- "prompt_tokens", 0
- ) + user_response.info["usage"].get("prompt_tokens", 0)
-
- # convert tool call to dict
- tool_call_records: List[dict] = []
- if assistant_response.info.get("tool_calls"):
- for tool_call in assistant_response.info["tool_calls"]:
- tool_call_records.append(tool_call.as_dict())
-
- _data = {
- "user": user_response.msg.content
- if hasattr(user_response, "msg") and user_response.msg
- else "",
- "assistant": assistant_response.msg.content
- if hasattr(assistant_response, "msg") and assistant_response.msg
- else "",
- "tool_calls": tool_call_records,
- }
+ society.stop_event = stop_event
+
+ try:
+ input_msg = society.init_chat(init_prompt)
+ for _round in range(round_limit):
+ assistant_response, user_response = society.step(input_msg)
+ # Check if usage info is available before accessing it
+ if assistant_response.info.get("usage") and user_response.info.get("usage"):
+ overall_completion_token_count += assistant_response.info["usage"].get(
+ "completion_tokens", 0
+ ) + user_response.info["usage"].get("completion_tokens", 0)
+ overall_prompt_token_count += assistant_response.info["usage"].get(
+ "prompt_tokens", 0
+ ) + user_response.info["usage"].get("prompt_tokens", 0)
+
+ # convert tool call to dict
+ tool_call_records: List[dict] = []
+ if assistant_response.info.get("tool_calls"):
+ for tool_call in assistant_response.info["tool_calls"]:
+ tool_call_records.append(tool_call.as_dict())
+
+ _data = {
+ "user": user_response.msg.content
+ if hasattr(user_response, "msg") and user_response.msg
+ else "",
+ "assistant": assistant_response.msg.content
+ if hasattr(assistant_response, "msg") and assistant_response.msg
+ else "",
+ "tool_calls": tool_call_records,
+ }
+
+ chat_history.append(_data)
+ logger.info(
+ f"Round #{_round} user_response:\n {user_response.msgs[0].content if user_response.msgs and len(user_response.msgs) > 0 else ''}"
+ )
+ logger.info(
+ f"Round #{_round} assistant_response:\n {assistant_response.msgs[0].content if assistant_response.msgs and len(assistant_response.msgs) > 0 else ''}"
+ )
- chat_history.append(_data)
- logger.info(
- f"Round #{_round} user_response:\n {user_response.msgs[0].content if user_response.msgs and len(user_response.msgs) > 0 else ''}"
- )
- logger.info(
- f"Round #{_round} assistant_response:\n {assistant_response.msgs[0].content if assistant_response.msgs and len(assistant_response.msgs) > 0 else ''}"
- )
+ if (
+ assistant_response.terminated
+ or user_response.terminated
+ or "TASK_DONE" in user_response.msg.content
+ or (stop_event and stop_event.is_set())
+ ):
+ break
- if (
- assistant_response.terminated
- or user_response.terminated
- or "TASK_DONE" in user_response.msg.content
- ):
- break
-
- input_msg = assistant_response.msg
+ input_msg = assistant_response.msg
- answer = chat_history[-1]["assistant"]
- token_info = {
- "completion_token_count": overall_completion_token_count,
- "prompt_token_count": overall_prompt_token_count,
- }
+ answer = chat_history[-1]["assistant"] if chat_history else ""
+ token_info = {
+ "completion_token_count": overall_completion_token_count,
+ "prompt_token_count": overall_prompt_token_count,
+ }
- return answer, chat_history, token_info
+ return answer, chat_history, token_info
+
+ except Exception as e:
+ logger.error(f"Exception in run_society: {e}")
+        # Compute values for completeness/logging; they are unused because the exception is re-raised below
+ answer = f"Error: {str(e)}"
+ token_info = {
+ "completion_token_count": overall_completion_token_count,
+ "prompt_token_count": overall_prompt_token_count,
+ }
+ # Re-raise after cleanup
+ raise
+
+ finally:
+ # Always attempt to terminate browser, regardless of how we exit the function
+ if hasattr(society, 'assistant_agent') and hasattr(society.assistant_agent, 'tool_dict') and society.assistant_agent.tool_dict and 'terminate_browser' in society.assistant_agent.tool_dict:
+ try:
+ flag, msg = society.assistant_agent.tool_dict['terminate_browser']()
+ logger.info(f"Browser termination result: success={flag}, message='{msg}'")
+ except Exception as term_error:
+ logger.error(f"Failed to terminate browser: {term_error}")
+ # We don't re-raise browser termination errors to ensure the original error (if any) is preserved
async def arun_society(
diff --git a/owl/webapp.py b/owl/webapp.py
index 33f06b679..6171ced53 100644
--- a/owl/webapp.py
+++ b/owl/webapp.py
@@ -78,7 +78,12 @@ def setup_logging():
STOP_LOG_THREAD = threading.Event()
CURRENT_PROCESS = None # Used to track the currently running process
STOP_REQUESTED = threading.Event() # Used to mark if stop was requested
-
+STATE = {
+ "token_count": "0",
+ "status": (f" Ready"),
+ "logs": "No conversation records yet.",
+ "running": False
+}
# Log reading and updating functions
def log_reader_thread(log_file):
@@ -323,7 +328,7 @@ def run_owl(question: str, example_module: str) -> Tuple[str, str, str]:
Returns:
Tuple[...]: Answer, token count, status
"""
- global CURRENT_PROCESS
+ global CURRENT_PROCESS, STOP_REQUESTED
# Validate input
if not validate_input(question):
@@ -395,11 +400,22 @@ def run_owl(question: str, example_module: str) -> Tuple[str, str, str]:
"0",
f"❌ Error: Build failed - {str(e)}",
)
+
+    # Check if STOP_REQUESTED. Early preemption when triggered early
+ if STOP_REQUESTED and STOP_REQUESTED.is_set():
+ return (
+ f"Thread Returned Early due to termination",
+ "0",
+ "☑️ Success - OWL Stopped",
+ )
# Run society simulation
try:
logging.info("Running society simulation...")
- answer, chat_history, token_info = run_society(society)
+ answer, chat_history, token_info = run_society(
+ society=society,
+ stop_event=STOP_REQUESTED
+ )
logging.info("Society simulation completed")
except Exception as e:
logging.error(f"Error occurred while running society simulation: {str(e)}")
@@ -433,6 +449,25 @@ def run_owl(question: str, example_module: str) -> Tuple[str, str, str]:
)
return (f"Error occurred: {str(e)}", "0", f"❌ Error: {str(e)}")
+def stop_owl() -> None:
+ r"""
+ Trigger the STOP_REQUESTED Event to Stop OWL and update the app state
+
+ Returns:
+ None
+ """
+ global CURRENT_PROCESS, STOP_REQUESTED, STATE
+ msg_template = lambda msg: (f" {msg}")
+
+    if STOP_REQUESTED.is_set() and CURRENT_PROCESS and CURRENT_PROCESS.is_alive():
+ STATE["status"] = msg_template("Termination in the process...")
+
+ if CURRENT_PROCESS and CURRENT_PROCESS.is_alive():
+ STOP_REQUESTED.set() # Signal the thread to stop
+ logging.info("STOP_REQUESTED Event is Set")
+ STATE["status"] = msg_template("Stopping the society...")
+ else:
+ STATE["status"] = msg_template("Process already completed.")
def update_module_description(module_name: str) -> str:
"""Return the description of the selected module"""
@@ -798,9 +833,14 @@ def clear_log_file():
return ""
# Create a real-time log update function
- def process_with_live_logs(question, module_name):
- """Process questions and update logs in real-time"""
- global CURRENT_PROCESS
+ import asyncio
+ async def process_with_live_logs(question, module_name) -> Tuple[gr.Button, gr.Button]:
+ r"""Start Owl in Thread and update logs in realtime
+
+ Returns:
+ Tuple[...]: Optimistically toggle the state of the button
+ """
+ global CURRENT_PROCESS, STATE
# Clear log file
clear_log_file()
@@ -822,47 +862,76 @@ def process_in_background():
CURRENT_PROCESS = bg_thread # Record current process
bg_thread.start()
- # While waiting for processing to complete, update logs once per second
- while bg_thread.is_alive():
- # Update conversation record display
- logs2 = get_latest_logs(100, LOG_QUEUE)
-
- # Always update status
- yield (
- "0",
- " Processing...",
- logs2,
- )
-
- time.sleep(1)
-
- # Processing complete, get results
- if not result_queue.empty():
- result = result_queue.get()
- answer, token_count, status = result
-
- # Final update of conversation record
- logs2 = get_latest_logs(100, LOG_QUEUE)
+ async def update_logs_async(result_queue, bg_thread, STATE) -> None:
+ r"""Updates the realtime logs in async with a new asyncio task
+
+ Args:
+ result_queue: The Queue updated by run_owl(). Contains answer, token_count & Status
+ bg_thread: The Background thread the run_owl() is running at
+ STATE: The current app state which is a global dictionary of data
+ """
+ while bg_thread.is_alive():
+ STATE["logs"] = get_latest_logs(100, LOG_QUEUE)
+                STATE["token_count"] = "0"  # token count not known until the run completes
+ STATE["status"] = (f" Processing...")
+ STATE["running"] = True
+
+ await asyncio.sleep(1) # Allow UI updates
+ # Processing complete, get results
+ if not result_queue.empty():
+ logging.info("Real time logs finished ✅")
+ result = result_queue.get()
+ answer, token_count, status = result
+ # Final update of conversation record
+ logs2 = get_latest_logs(100, LOG_QUEUE)
+ # Set different indicators based on status
+ if "Error" in status:
+ status_with_indicator = (
+ f" {status}"
+ )
+ else:
+ status_with_indicator = (
+ f" {status}"
+ )
- # Set different indicators based on status
- if "Error" in status:
- status_with_indicator = (
- f" {status}"
- )
+ STATE["logs"] = logs2
+ STATE["status"] = status_with_indicator
+                STATE["token_count"] = token_count  # final token count from the result queue
+ STATE["running"] = False
+
+ #Revert the Task Flag
+ STOP_REQUESTED.clear()
else:
- status_with_indicator = (
- f" {status}"
- )
-
- yield token_count, status_with_indicator, logs2
- else:
- logs2 = get_latest_logs(100, LOG_QUEUE)
- yield (
- "0",
- " Terminated",
- logs2,
- )
-
+ logs2 = get_latest_logs(100, LOG_QUEUE)
+                # No result produced: the run was terminated before completing
+
+                STATE["logs"] = logs2
+                STATE["status"] = " Terminated"
+                STATE["token_count"] = "0"
+ STATE["running"] = False
+
+ # Start a separate async task for updating logs
+ asyncio.create_task(update_logs_async(result_queue, bg_thread, STATE))
+
+ # Optimistic Toggle of Start Button
+ return (gr.Button(visible=False), gr.Button(visible=True))
+
+ def update_interface() -> Tuple[str,str,str,gr.Button,gr.Button]:
+ r"""Update the latest state values.
+
+ Returns:
+ Tuple[...]: Links output to token_count_output, status_output, log_display2, run_button, stop_button
+ """
+ global STATE
+
+ return (
+ STATE["token_count"],
+ STATE["status"],
+ STATE["logs"],
+ gr.Button(visible=not STATE["running"]), # run_button
+ gr.Button(visible=STATE["running"]) # stop_button
+ )
+
with gr.Blocks(title="OWL", theme=gr.themes.Soft(primary_hue="blue")) as app:
gr.Markdown(
"""
@@ -1113,6 +1182,8 @@ def process_in_background():
run_button = gr.Button(
"Run", variant="primary", elem_classes="primary"
)
+ # Stop button (hidden initially)
+ stop_button = gr.Button("Stop", variant="secondary", visible=False)
status_output = gr.HTML(
value=" Ready",
@@ -1242,10 +1313,17 @@ def process_in_background():
refresh_button.click(fn=update_env_table, outputs=[env_table])
# Set up event handling
- run_button.click(
+ start_event = run_button.click(
fn=process_with_live_logs,
inputs=[question_input, module_dropdown],
- outputs=[token_count_output, status_output, log_display2],
+ outputs=[run_button, stop_button],
+ queue=True
+ )
+ # When clicking the stop button, stop the background thread and show start button
+ stop_button.click(
+ fn=stop_owl,
+ queue=True,
+ cancels=start_event
)
# Module selection updates description
@@ -1275,7 +1353,7 @@ def toggle_auto_refresh(enabled):
outputs=[log_display2],
)
- # No longer automatically refresh logs by default
+ app.load(update_interface, outputs=[token_count_output, status_output, log_display2, run_button, stop_button], every=1)
return app