Skip to content

Commit 66e01e6

Browse files
committed
code exec interface + preload
1 parent da8651e commit 66e01e6

File tree

8 files changed

+62
-19
lines changed

8 files changed

+62
-19
lines changed

preload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ async def preload_kokoro():
4040
# async tasks to preload
4141
tasks = [
4242
preload_embedding(),
43-
preload_whisper(),
44-
preload_kokoro()
43+
# preload_whisper(),
44+
# preload_kokoro()
4545
]
4646

4747
await asyncio.gather(*tasks, return_exceptions=True)

python/api/synthesize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
class Synthesize(ApiHandler):
88
async def process(self, input: dict, request: Request) -> dict | Response:
99
text = input.get("text", "")
10-
ctxid = input.get("ctxid", "")
10+
# ctxid = input.get("ctxid", "")
1111

12-
context = self.get_context(ctxid)
13-
if not await kokoro_tts.is_downloaded():
14-
context.log.log(type="info", content="Kokoro TTS model is currently being initialized, please wait...")
12+
# context = self.get_context(ctxid)
13+
# if not await kokoro_tts.is_downloaded():
14+
# context.log.log(type="info", content="Kokoro TTS model is currently being initialized, please wait...")
1515

1616
try:
1717
# # Clean and chunk text for long responses

python/api/transcribe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
class Transcribe(ApiHandler):
66
async def process(self, input: dict, request: Request) -> dict | Response:
77
audio = input.get("audio")
8-
ctxid = input.get("ctxid", "")
8+
# ctxid = input.get("ctxid", "")
99

10-
context = self.get_context(ctxid)
11-
if not await whisper.is_downloaded():
12-
context.log.log(type="info", content="Whisper STT model is currently being initialized, please wait...")
10+
# context = self.get_context(ctxid)
11+
# if not await whisper.is_downloaded():
12+
# context.log.log(type="info", content="Whisper STT model is currently being initialized, please wait...")
1313

1414
set = settings.get_settings()
1515
result = await whisper.transcribe(set["stt_model_size"], audio) # type: ignore

python/helpers/fasta2a_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def _configure(self):
258258
# Atomic update of the app
259259
self.app = new_app
260260

261-
_PRINTER.print("[A2A] FastA2A server configured successfully")
261+
# _PRINTER.print("[A2A] FastA2A server configured successfully")
262262

263263
except Exception as e:
264264
_PRINTER.print(f"[A2A] Failed to configure FastA2A server: {e}")

python/helpers/kokoro_tts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import soundfile as sf
88
from python.helpers import runtime
99
from python.helpers.print_style import PrintStyle
10+
from python.helpers.notification import NotificationManager, NotificationType, NotificationPriority
1011

1112
warnings.filterwarnings("ignore", category=FutureWarning)
1213
warnings.filterwarnings("ignore", category=UserWarning)
@@ -38,9 +39,21 @@ async def _preload():
3839
try:
3940
is_updating_model = True
4041
if not _pipeline:
42+
NotificationManager.send_notification(
43+
NotificationType.INFO,
44+
NotificationPriority.NORMAL,
45+
"Loading Kokoro TTS model...",
46+
display_time=99,
47+
group="kokoro-preload")
4148
PrintStyle.standard("Loading Kokoro TTS model...")
4249
from kokoro import KPipeline
4350
_pipeline = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M")
51+
NotificationManager.send_notification(
52+
NotificationType.INFO,
53+
NotificationPriority.NORMAL,
54+
"Kokoro TTS model loaded.",
55+
display_time=2,
56+
group="kokoro-preload")
4457
finally:
4558
is_updating_model = False
4659

python/helpers/settings.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ class Settings(TypedDict):
8484
rfc_port_http: int
8585
rfc_port_ssh: int
8686

87+
shell_interface: Literal['local','ssh']
88+
8789
stt_model_size: str
8890
stt_language: str
8991
stt_silence_threshold: float
@@ -793,6 +795,17 @@ def convert_out(settings: Settings) -> SettingsOutput:
793795

794796
dev_fields: list[SettingsField] = []
795797

798+
dev_fields.append(
799+
{
800+
"id": "shell_interface",
801+
"title": "Shell Interface",
802+
"description": "Terminal interface used for Code Execution Tool. Local Python TTY works locally in both dockerized and development environments. SSH always connects to dockerized environment (automatically at localhost or RFC host address).",
803+
"type": "select",
804+
"value": settings["shell_interface"],
805+
"options": [{"value": "local", "label": "Local Python TTY"}, {"value": "ssh", "label": "SSH"}],
806+
}
807+
)
808+
796809
if runtime.is_development():
797810
# dev_fields.append(
798811
# {
@@ -1378,6 +1391,7 @@ def get_default_settings() -> Settings:
13781391
rfc_password="",
13791392
rfc_port_http=55080,
13801393
rfc_port_ssh=55022,
1394+
shell_interface="local" if runtime.is_dockerized() else "ssh",
13811395
stt_model_size="base",
13821396
stt_language="en",
13831397
stt_silence_threshold=0.3,
@@ -1539,7 +1553,7 @@ def set_root_password(password: str):
15391553
def get_runtime_config(set: Settings):
15401554
if runtime.is_dockerized():
15411555
return {
1542-
"code_exec_ssh_enabled": False,
1556+
"code_exec_ssh_enabled": set["shell_interface"] == "ssh",
15431557
"code_exec_ssh_addr": "localhost",
15441558
"code_exec_ssh_port": 22,
15451559
"code_exec_ssh_user": "root",
@@ -1553,7 +1567,7 @@ def get_runtime_config(set: Settings):
15531567
if host.endswith("/"):
15541568
host = host[:-1]
15551569
return {
1556-
"code_exec_ssh_enabled": True,
1570+
"code_exec_ssh_enabled": set["shell_interface"] == "ssh",
15571571
"code_exec_ssh_addr": host,
15581572
"code_exec_ssh_port": set["rfc_port_ssh"],
15591573
"code_exec_ssh_user": "root",

python/helpers/whisper.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
from python.helpers import runtime, rfc, settings, files
77
from python.helpers.print_style import PrintStyle
8+
from python.helpers.notification import NotificationManager, NotificationType, NotificationPriority
89

910
# Suppress FutureWarning from torch.load
1011
warnings.filterwarnings("ignore", category=FutureWarning)
@@ -30,9 +31,21 @@ async def _preload(model_name:str):
3031
try:
3132
is_updating_model = True
3233
if not _model or _model_name != model_name:
33-
PrintStyle.standard(f"Loading Whisper model: {model_name}")
34-
_model = whisper.load_model(name=model_name, download_root=files.get_abs_path("/tmp/models/whisper")) # type: ignore
35-
_model_name = model_name
34+
NotificationManager.send_notification(
35+
NotificationType.INFO,
36+
NotificationPriority.NORMAL,
37+
"Loading Whisper model...",
38+
display_time=99,
39+
group="whisper-preload")
40+
PrintStyle.standard(f"Loading Whisper model: {model_name}")
41+
_model = whisper.load_model(name=model_name, download_root=files.get_abs_path("/tmp/models/whisper")) # type: ignore
42+
_model_name = model_name
43+
NotificationManager.send_notification(
44+
NotificationType.INFO,
45+
NotificationPriority.NORMAL,
46+
"Whisper model loaded.",
47+
display_time=2,
48+
group="whisper-preload")
3649
finally:
3750
is_updating_model = False
3851

python/tools/code_execution_tool.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
@dataclass
1717
class State:
18+
ssh_enabled: bool
1819
shells: dict[int, LocalInteractiveSession | SSHInteractiveSession]
1920

2021

@@ -77,7 +78,8 @@ async def after_execution(self, response, **kwargs):
7778

7879
async def prepare_state(self, reset=False, session: int | None = None):
7980
self.state: State | None = self.agent.get_data("_cet_state")
80-
if not self.state:
81+
# always reset state when ssh_enabled changes
82+
if not self.state or self.state.ssh_enabled != self.agent.config.code_exec_ssh_enabled:
8183
# initialize shells dictionary if not exists
8284
shells: dict[int, LocalInteractiveSession | SSHInteractiveSession] = {}
8385
else:
@@ -114,7 +116,7 @@ async def prepare_state(self, reset=False, session: int | None = None):
114116
shells[session] = shell
115117
await shell.connect()
116118

117-
self.state = State(shells=shells)
119+
self.state = State(shells=shells, ssh_enabled=self.agent.config.code_exec_ssh_enabled)
118120
self.agent.set_data("_cet_state", self.state)
119121
return self.state
120122

@@ -201,9 +203,10 @@ async def get_terminal_output(
201203

202204
# Common shell prompt regex patterns (add more as needed)
203205
prompt_patterns = [
204-
re.compile(r"\\(venv\\).+[$#] ?$"), # (venv) ...$ or (venv) ...#
206+
re.compile(r"\(venv\).+[$#] ?$"), # (venv) ...$ or (venv) ...#
205207
re.compile(r"root@[^:]+:[^#]+# ?$"), # root@container:~#
206208
re.compile(r"[a-zA-Z0-9_.-]+@[^:]+:[^$#]+[$#] ?$"), # user@host:~$
209+
re.compile(r"bash-\d+\.\d+\$ ?$"), # bash-3.2$ (version can vary)
207210
]
208211

209212
# potential dialog detection

0 commit comments

Comments
 (0)