Skip to content

Commit c0937d0

Browse files
committed
Use code execution status rather than chat message status.
1 parent 0d13de2 commit c0937d0

File tree

7 files changed

+303
-141
lines changed

7 files changed

+303
-141
lines changed

open-webui/functions/run_code.py

Lines changed: 94 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ async def action(
163163
valves = self.valves
164164
debug = valves.DEBUG
165165
emitter = EventEmitter(__event_emitter__, debug=debug)
166+
execution_tracker: typing.Optional[CodeExecutionTracker] = None
166167

167168
update_check_error = None
168169
update_check_notice = ""
@@ -194,6 +195,9 @@ async def action(
194195
)
195196

196197
async def _fail(error_message, status="SANDBOX_ERROR"):
198+
if execution_tracker is not None:
199+
execution_tracker.set_error(error_message)
200+
await emitter.code_execution(execution_tracker)
197201
if debug:
198202
await emitter.fail(
199203
f"[DEBUG MODE] {error_message}; body={body}; valves=[{valves}]"
@@ -204,7 +208,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
204208
await emitter.fail(error_message)
205209
return json.dumps({"status": status, "output": error_message})
206210

207-
await emitter.status("Checking messages for code blocks...")
208211
if len(body.get("messages", ())) == 0:
209212
return await _fail("No messages in conversation.", status="INVALID_INPUT")
210213
last_message = body["messages"][-1]
@@ -271,7 +274,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
271274
if self.valves.MAX_RAM_MEGABYTES != 0:
272275
max_ram_bytes = self.valves.MAX_RAM_MEGABYTES * 1024 * 1024
273276

274-
await emitter.status("Checking if environment supports sandboxing...")
275277
Sandbox.check_setup(
276278
language=language,
277279
auto_install_allowed=self.valves.AUTO_INSTALL,
@@ -282,7 +284,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
282284
await emitter.status("Auto-installing gVisor...")
283285
Sandbox.install_runsc()
284286

285-
await emitter.status("Initializing sandbox configuration...")
286287
status = "UNKNOWN"
287288
output = None
288289
generated_files = []
@@ -298,6 +299,12 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
298299
code = code.removeprefix("bash")
299300
code = code.removeprefix("sh")
300301
code = code.strip()
302+
language_title = language.title()
303+
execution_tracker = CodeExecutionTracker(
304+
name=f"{language_title} code block", code=code, language=language
305+
)
306+
await emitter.clear_status()
307+
await emitter.code_execution(execution_tracker)
301308

302309
with tempfile.TemporaryDirectory(prefix="sandbox_") as tmp_dir:
303310
sandbox_storage_path = os.path.join(tmp_dir, "storage")
@@ -314,23 +321,25 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
314321
persistent_home_dir=sandbox_storage_path,
315322
)
316323

317-
await emitter.status(
318-
f"Running {language_title} code in gVisor sandbox..."
319-
)
320324
try:
321325
result = sandbox.run()
322326
except Sandbox.ExecutionTimeoutError as e:
323327
await emitter.fail(
324328
f"Code timed out after {valves.MAX_RUNTIME_SECONDS} seconds"
325329
)
330+
execution_tracker.set_error(
331+
f"Code timed out after {valves.MAX_RUNTIME_SECONDS} seconds"
332+
)
326333
status = "TIMEOUT"
327334
output = e.stderr
328335
except Sandbox.InterruptedExecutionError as e:
329336
await emitter.fail("Code used too many resources")
337+
execution_tracker.set_error("Code used too many resources")
330338
status = "INTERRUPTED"
331339
output = e.stderr
332340
except Sandbox.CodeExecutionError as e:
333341
await emitter.fail(f"{language_title}: {e}")
342+
execution_tracker.set_error(f"{language_title}: {e}")
334343
status = "ERROR"
335344
output = e.stderr
336345
else:
@@ -354,14 +363,14 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
354363
status = "STORAGE_ERROR"
355364
output = f"Storage quota exceeded: {e}"
356365
await emitter.fail(output)
357-
if status == "OK":
358-
await emitter.status(
359-
status="complete",
360-
done=True,
361-
description=f"{language_title} code executed successfully.",
362-
)
366+
for generated_file in generated_files:
367+
execution_tracker.add_file(
368+
name=generated_file.name, url=generated_file.url
369+
)
363370
if output:
364371
output = output.strip()
372+
execution_tracker.set_output(output)
373+
await emitter.code_execution(execution_tracker)
365374
if debug:
366375
per_file_logs = {}
367376

@@ -935,6 +944,9 @@ async def action(
935944
)
936945

937946

947+
# fmt: off
948+
949+
938950
class EventEmitter:
939951
"""
940952
Helper wrapper for OpenWebUI event emissions.
@@ -948,27 +960,32 @@ def __init__(
948960
self.event_emitter = event_emitter
949961
self._debug = debug
950962
self._status_prefix = None
963+
self._emitted_status = False
951964

952965
def set_status_prefix(self, status_prefix):
953966
self._status_prefix = status_prefix
954967

955-
async def _emit(self, typ, data):
968+
async def _emit(self, typ, data, twice):
956969
if self._debug:
957970
print(f"Emitting {typ} event: {data}", file=sys.stderr)
958971
if not self.event_emitter:
959972
return None
960-
maybe_future = self.event_emitter(
961-
{
962-
"type": typ,
963-
"data": data,
964-
}
965-
)
966-
if asyncio.isfuture(maybe_future) or inspect.isawaitable(maybe_future):
967-
return await maybe_future
973+
result = None
974+
for i in range(2 if twice else 1):
975+
maybe_future = self.event_emitter(
976+
{
977+
"type": typ,
978+
"data": data,
979+
}
980+
)
981+
if asyncio.isfuture(maybe_future) or inspect.isawaitable(maybe_future):
982+
result = await maybe_future
983+
return result
968984

969985
async def status(
970986
self, description="Unknown state", status="in_progress", done=False
971987
):
988+
self._emitted_status = True
972989
if self._status_prefix is not None:
973990
description = f"{self._status_prefix}{description}"
974991
await self._emit(
@@ -978,29 +995,33 @@ async def status(
978995
"description": description,
979996
"done": done,
980997
},
998+
twice=not done and len(description) <= 1024,
981999
)
982-
if not done and len(description) <= 1024:
983-
# Emit it again; Open WebUI does not seem to flush this reliably.
984-
# Only do it for relatively small statuses; when debug mode is enabled,
985-
# this can take up a lot of space.
986-
await self._emit(
987-
"status",
988-
{
989-
"status": status,
990-
"description": description,
991-
"done": done,
992-
},
993-
)
9941000

9951001
async def fail(self, description="Unknown error"):
9961002
await self.status(description=description, status="error", done=True)
9971003

1004+
async def clear_status(self):
1005+
if not self._emitted_status:
1006+
return
1007+
self._emitted_status = False
1008+
await self._emit(
1009+
"status",
1010+
{
1011+
"status": "complete",
1012+
"description": "",
1013+
"done": True,
1014+
},
1015+
twice=True,
1016+
)
1017+
9981018
async def message(self, content):
9991019
await self._emit(
10001020
"message",
10011021
{
10021022
"content": content,
10031023
},
1024+
twice=False,
10041025
)
10051026

10061027
async def citation(self, document, metadata, source):
@@ -1011,16 +1032,51 @@ async def citation(self, document, metadata, source):
10111032
"metadata": metadata,
10121033
"source": source,
10131034
},
1035+
twice=False,
10141036
)
10151037

1016-
async def code_execution_result(self, output):
1038+
async def code_execution(self, code_execution_tracker):
10171039
await self._emit(
1018-
"code_execution_result",
1040+
"citation", code_execution_tracker._citation_data(), twice=True
1041+
)
1042+
1043+
1044+
class CodeExecutionTracker:
1045+
def __init__(self, name, code, language):
1046+
self._uuid = str(uuid.uuid4())
1047+
self.name = name
1048+
self.code = code
1049+
self.language = language
1050+
self._result = {}
1051+
1052+
def set_error(self, error):
1053+
self._result["error"] = error
1054+
1055+
def set_output(self, output):
1056+
self._result["output"] = output
1057+
1058+
def add_file(self, name, url):
1059+
if "files" not in self._result:
1060+
self._result["files"] = []
1061+
self._result["files"].append(
10191062
{
1020-
"output": output,
1021-
},
1063+
"name": name,
1064+
"url": url,
1065+
}
10221066
)
10231067

1068+
def _citation_data(self):
1069+
data = {
1070+
"type": "code_execution",
1071+
"uuid": self._uuid,
1072+
"name": self.name,
1073+
"code": self.code,
1074+
"language": self.language,
1075+
}
1076+
if "output" in self._result or "error" in self._result:
1077+
data["result"] = self._result
1078+
return data
1079+
10241080

10251081
class Sandbox:
10261082
"""
@@ -3255,6 +3311,7 @@ def get_newer_version(cls) -> typing.Optional[str]:
32553311

32563312

32573313
UpdateCheck.init_from_frontmatter(os.path.abspath(__file__))
3314+
# fmt: on
32583315

32593316

32603317
_SAMPLE_BASH_INSTRUCTIONS = (

0 commit comments

Comments
 (0)