Skip to content

Commit 0f29797

Browse files
committed
Use code execution status rather than chat message status.
1 parent 14a6f3d commit 0f29797

File tree

7 files changed

+303
-142
lines changed

7 files changed

+303
-142
lines changed

open-webui/functions/run_code.py

Lines changed: 94 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ async def action(
165165
valves = self.valves
166166
debug = valves.DEBUG
167167
emitter = EventEmitter(__event_emitter__, debug=debug)
168+
execution_tracker: typing.Optional[CodeExecutionTracker] = None
168169

169170
update_check_error = None
170171
update_check_notice = ""
@@ -196,6 +197,9 @@ async def action(
196197
)
197198

198199
async def _fail(error_message, status="SANDBOX_ERROR"):
200+
if execution_tracker is not None:
201+
execution_tracker.set_error(error_message)
202+
await emitter.code_execution(execution_tracker)
199203
if debug:
200204
await emitter.fail(
201205
f"[DEBUG MODE] {error_message}; body={body}; valves=[{valves}]"
@@ -206,7 +210,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
206210
await emitter.fail(error_message)
207211
return json.dumps({"status": status, "output": error_message})
208212

209-
await emitter.status("Checking messages for code blocks...")
210213
if len(body.get("messages", ())) == 0:
211214
return await _fail("No messages in conversation.", status="INVALID_INPUT")
212215
last_message = body["messages"][-1]
@@ -273,7 +276,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
273276
if self.valves.MAX_RAM_MEGABYTES != 0:
274277
max_ram_bytes = self.valves.MAX_RAM_MEGABYTES * 1024 * 1024
275278

276-
await emitter.status("Checking if environment supports sandboxing...")
277279
Sandbox.check_setup(
278280
language=language,
279281
auto_install_allowed=self.valves.AUTO_INSTALL,
@@ -284,7 +286,6 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
284286
await emitter.status("Auto-installing gVisor...")
285287
Sandbox.install_runsc()
286288

287-
await emitter.status("Initializing sandbox configuration...")
288289
status = "UNKNOWN"
289290
output = None
290291
generated_files = []
@@ -300,6 +301,12 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
300301
code = code.removeprefix("bash")
301302
code = code.removeprefix("sh")
302303
code = code.strip()
304+
language_title = language.title()
305+
execution_tracker = CodeExecutionTracker(
306+
name=f"{language_title} code block", code=code, language=language
307+
)
308+
await emitter.clear_status()
309+
await emitter.code_execution(execution_tracker)
303310

304311
with tempfile.TemporaryDirectory(prefix="sandbox_") as tmp_dir:
305312
sandbox_storage_path = os.path.join(tmp_dir, "storage")
@@ -316,23 +323,25 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
316323
persistent_home_dir=sandbox_storage_path,
317324
)
318325

319-
await emitter.status(
320-
f"Running {language_title} code in gVisor sandbox..."
321-
)
322326
try:
323327
result = sandbox.run()
324328
except Sandbox.ExecutionTimeoutError as e:
325329
await emitter.fail(
326330
f"Code timed out after {valves.MAX_RUNTIME_SECONDS} seconds"
327331
)
332+
execution_tracker.set_error(
333+
f"Code timed out after {valves.MAX_RUNTIME_SECONDS} seconds"
334+
)
328335
status = "TIMEOUT"
329336
output = e.stderr
330337
except Sandbox.InterruptedExecutionError as e:
331338
await emitter.fail("Code used too many resources")
339+
execution_tracker.set_error("Code used too many resources")
332340
status = "INTERRUPTED"
333341
output = e.stderr
334342
except Sandbox.CodeExecutionError as e:
335343
await emitter.fail(f"{language_title}: {e}")
344+
execution_tracker.set_error(f"{language_title}: {e}")
336345
status = "ERROR"
337346
output = e.stderr
338347
else:
@@ -356,14 +365,14 @@ async def _fail(error_message, status="SANDBOX_ERROR"):
356365
status = "STORAGE_ERROR"
357366
output = f"Storage quota exceeded: {e}"
358367
await emitter.fail(output)
359-
if status == "OK":
360-
await emitter.status(
361-
status="complete",
362-
done=True,
363-
description=f"{language_title} code executed successfully.",
364-
)
368+
for generated_file in generated_files:
369+
execution_tracker.add_file(
370+
name=generated_file.name, url=generated_file.url
371+
)
365372
if output:
366373
output = output.strip()
374+
execution_tracker.set_output(output)
375+
await emitter.code_execution(execution_tracker)
367376
if debug:
368377
per_file_logs = {}
369378

@@ -954,6 +963,9 @@ async def action(
954963
)
955964

956965

966+
# fmt: off
967+
968+
957969
class EventEmitter:
958970
"""
959971
Helper wrapper for OpenWebUI event emissions.
@@ -967,27 +979,32 @@ def __init__(
967979
self.event_emitter = event_emitter
968980
self._debug = debug
969981
self._status_prefix = None
982+
self._emitted_status = False
970983

971984
def set_status_prefix(self, status_prefix):
972985
self._status_prefix = status_prefix
973986

974-
async def _emit(self, typ, data):
987+
async def _emit(self, typ, data, twice):
975988
if self._debug:
976989
print(f"Emitting {typ} event: {data}", file=sys.stderr)
977990
if not self.event_emitter:
978991
return None
979-
maybe_future = self.event_emitter(
980-
{
981-
"type": typ,
982-
"data": data,
983-
}
984-
)
985-
if asyncio.isfuture(maybe_future) or inspect.isawaitable(maybe_future):
986-
return await maybe_future
992+
result = None
993+
for i in range(2 if twice else 1):
994+
maybe_future = self.event_emitter(
995+
{
996+
"type": typ,
997+
"data": data,
998+
}
999+
)
1000+
if asyncio.isfuture(maybe_future) or inspect.isawaitable(maybe_future):
1001+
result = await maybe_future
1002+
return result
9871003

9881004
async def status(
9891005
self, description="Unknown state", status="in_progress", done=False
9901006
):
1007+
self._emitted_status = True
9911008
if self._status_prefix is not None:
9921009
description = f"{self._status_prefix}{description}"
9931010
await self._emit(
@@ -997,29 +1014,33 @@ async def status(
9971014
"description": description,
9981015
"done": done,
9991016
},
1017+
twice=not done and len(description) <= 1024,
10001018
)
1001-
if not done and len(description) <= 1024:
1002-
# Emit it again; Open WebUI does not seem to flush this reliably.
1003-
# Only do it for relatively small statuses; when debug mode is enabled,
1004-
# this can take up a lot of space.
1005-
await self._emit(
1006-
"status",
1007-
{
1008-
"status": status,
1009-
"description": description,
1010-
"done": done,
1011-
},
1012-
)
10131019

10141020
async def fail(self, description="Unknown error"):
10151021
await self.status(description=description, status="error", done=True)
10161022

1023+
async def clear_status(self):
1024+
if not self._emitted_status:
1025+
return
1026+
self._emitted_status = False
1027+
await self._emit(
1028+
"status",
1029+
{
1030+
"status": "complete",
1031+
"description": "",
1032+
"done": True,
1033+
},
1034+
twice=True,
1035+
)
1036+
10171037
async def message(self, content):
10181038
await self._emit(
10191039
"message",
10201040
{
10211041
"content": content,
10221042
},
1043+
twice=False,
10231044
)
10241045

10251046
async def citation(self, document, metadata, source):
@@ -1030,16 +1051,51 @@ async def citation(self, document, metadata, source):
10301051
"metadata": metadata,
10311052
"source": source,
10321053
},
1054+
twice=False,
10331055
)
10341056

1035-
async def code_execution_result(self, output):
1057+
async def code_execution(self, code_execution_tracker):
10361058
await self._emit(
1037-
"code_execution_result",
1059+
"citation", code_execution_tracker._citation_data(), twice=True
1060+
)
1061+
1062+
1063+
class CodeExecutionTracker:
1064+
def __init__(self, name, code, language):
1065+
self._uuid = str(uuid.uuid4())
1066+
self.name = name
1067+
self.code = code
1068+
self.language = language
1069+
self._result = {}
1070+
1071+
def set_error(self, error):
1072+
self._result["error"] = error
1073+
1074+
def set_output(self, output):
1075+
self._result["output"] = output
1076+
1077+
def add_file(self, name, url):
1078+
if "files" not in self._result:
1079+
self._result["files"] = []
1080+
self._result["files"].append(
10381081
{
1039-
"output": output,
1040-
},
1082+
"name": name,
1083+
"url": url,
1084+
}
10411085
)
10421086

1087+
def _citation_data(self):
1088+
data = {
1089+
"type": "code_execution",
1090+
"id": self._uuid,
1091+
"name": self.name,
1092+
"code": self.code,
1093+
"language": self.language,
1094+
}
1095+
if "output" in self._result or "error" in self._result:
1096+
data["result"] = self._result
1097+
return data
1098+
10431099

10441100
class Sandbox:
10451101
"""
@@ -3661,6 +3717,7 @@ def get_newer_version(cls) -> typing.Optional[str]:
36613717

36623718

36633719
UpdateCheck.init_from_frontmatter(os.path.abspath(__file__))
3720+
# fmt: on
36643721

36653722

36663723
_SAMPLE_BASH_INSTRUCTIONS = (

0 commit comments

Comments
 (0)