Skip to content

Commit 35bc15b

Browse files
committed
Rework websocket
1 parent 93e6244 commit 35bc15b

File tree

5 files changed

+286
-130
lines changed

5 files changed

+286
-130
lines changed

python/e2b_code_interpreter/main.py

Lines changed: 45 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1-
import json
21
import threading
3-
import time
4-
import uuid
52
from concurrent.futures import Future
6-
from typing import Any, Callable, List, Optional
3+
from typing import Any, Callable, List, Optional, Dict
74

85
import requests
96
from e2b import EnvVars, ProcessMessage, Sandbox
107
from e2b.constants import TIMEOUT
11-
from websocket import create_connection
128

13-
from e2b_code_interpreter.models import Error, KernelException, Result
9+
from e2b_code_interpreter.messaging import JupyterKernelWebSocket
10+
from e2b_code_interpreter.models import KernelException, Result
1411

1512

1613
class CodeInterpreter(Sandbox):
@@ -40,15 +37,18 @@ def __init__(
4037
**kwargs,
4138
)
4239
self.notebook = JupyterExtension(self)
40+
# Close all the websocket connections when the interpreter is closed
41+
self._process_cleanup.append(self.notebook.close)
4342

4443

4544
class JupyterExtension:
4645
_default_kernel_id: Optional[str] = None
46+
_connected_kernels: Dict[str, JupyterKernelWebSocket] = {}
4747

4848
def __init__(self, sandbox: CodeInterpreter):
4949
self._sandbox = sandbox
5050
self._kernel_id_set = Future()
51-
self._set_default_kernel_id()
51+
self._start_connectiong_to_default_kernel()
5252

5353
def exec_cell(
5454
self,
@@ -58,12 +58,18 @@ def exec_cell(
5858
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None,
5959
) -> Result:
6060
kernel_id = kernel_id or self.default_kernel_id
61-
ws = self._connect_kernel(kernel_id)
62-
ws.send(json.dumps(self._send_execute_request(code)))
63-
result = self._wait_for_result(ws, on_stdout, on_stderr)
61+
ws = self._connected_kernels.get(kernel_id)
6462

65-
ws.close()
63+
if not ws:
64+
ws = JupyterKernelWebSocket(
65+
url=f"{self._sandbox.get_protocol('ws')}://{self._sandbox.get_hostname(8888)}/api/kernels/{kernel_id}/channels",
66+
)
67+
self._connected_kernels[kernel_id] = ws
68+
ws.connect()
69+
70+
session_id = ws.send_execution_message(code, on_stdout, on_stderr)
6671

72+
result = ws.get_result(session_id)
6773
return result
6874

6975
@property
@@ -73,31 +79,42 @@ def default_kernel_id(self) -> str:
7379

7480
return self._default_kernel_id
7581

76-
def create_kernel(self, timeout: Optional[float] = TIMEOUT) -> str:
82+
def create_kernel(self, cwd: Optional[str] = None,timeout: Optional[float] = TIMEOUT) -> str:
83+
data = {"cwd": cwd} if cwd else None
7784
response = requests.post(
7885
f"{self._sandbox.get_protocol()}://{self._sandbox.get_hostname(8888)}/api/kernels",
86+
json=data,
7987
timeout=timeout,
8088
)
8189
if not response.ok:
8290
raise KernelException(f"Failed to create kernel: {response.text}")
83-
return response.json()["id"]
91+
92+
kernel_id = response.json()["id"]
93+
94+
threading.Thread(target=self._connect_to_kernel_ws, args=kernel_id).start()
95+
return kernel_id
8496

8597
def restart_kernel(
8698
self, kernel_id: Optional[str] = None, timeout: Optional[float] = TIMEOUT
8799
) -> None:
88100
kernel_id = kernel_id or self.default_kernel_id
101+
102+
self._connected_kernels[kernel_id].close()
89103
response = requests.post(
90104
f"{self._sandbox.get_protocol()}://{self._sandbox.get_hostname(8888)}/api/kernels/{kernel_id}/restart",
91105
timeout=timeout,
92106
)
93107
if not response.ok:
94108
raise KernelException(f"Failed to restart kernel {kernel_id}")
95109

110+
threading.Thread(target=self._connect_to_kernel_ws, args=kernel_id).start()
111+
96112
def shutdown_kernel(
97113
self, kernel_id: Optional[str] = None, timeout: Optional[float] = TIMEOUT
98114
) -> None:
99115
kernel_id = kernel_id or self.default_kernel_id
100116

117+
self._connected_kernels[kernel_id].close()
101118
response = requests.delete(
102119
f"{self._sandbox.get_protocol()}://{self._sandbox.get_hostname(8888)}/api/kernels/{kernel_id}",
103120
timeout=timeout,
@@ -114,114 +131,21 @@ def list_kernels(self, timeout: Optional[float] = TIMEOUT) -> List[str]:
114131
raise KernelException(f"Failed to list kernels: {response.text}")
115132
return [kernel["id"] for kernel in response.json()]
116133

117-
def _set_default_kernel_id(self, timeout: Optional[float] = TIMEOUT) -> None:
118-
def set_kernel_id():
119-
self._kernel_id_set.set_result(
120-
self._sandbox.filesystem.read("/root/.jupyter/kernel_id", timeout=timeout).strip()
121-
)
122-
123-
threading.Thread(target=set_kernel_id).start()
134+
def close(self):
135+
for ws in self._connected_kernels.values():
136+
ws.close()
124137

125-
def _connect_kernel(self, kernel_id: str, timeout: Optional[float] = TIMEOUT):
126-
return create_connection(
127-
f"{self._sandbox.get_protocol('ws')}://{self._sandbox.get_hostname(8888)}/api/kernels/{kernel_id}/channels",
128-
timeout=timeout,
138+
def _connect_to_kernel_ws(self, kernel_id: str) -> None:
139+
ws = JupyterKernelWebSocket(
140+
url=f"{self._sandbox.get_protocol('ws')}://{self._sandbox.get_hostname(8888)}/api/kernels/{kernel_id}/channels",
129141
)
142+
ws.connect()
143+
self._connected_kernels[kernel_id] = ws
130144

131-
@staticmethod
132-
def _send_execute_request(code: str) -> dict:
133-
msg_id = str(uuid.uuid4())
134-
session = str(uuid.uuid4())
135-
136-
return {
137-
"header": {
138-
"msg_id": msg_id,
139-
"username": "e2b",
140-
"session": session,
141-
"msg_type": "execute_request",
142-
"version": "5.3",
143-
},
144-
"parent_header": {},
145-
"metadata": {},
146-
"content": {
147-
"code": code,
148-
"silent": False,
149-
"store_history": False,
150-
"user_expressions": {},
151-
"allow_stdin": False,
152-
},
153-
}
154-
155-
@staticmethod
156-
def _wait_for_result(
157-
ws,
158-
on_stdout: Optional[Callable[[ProcessMessage], Any]],
159-
on_stderr: Optional[Callable[[ProcessMessage], Any]],
160-
) -> Result:
161-
result = Result()
162-
input_accepted = False
163-
164-
while True:
165-
response = json.loads(ws.recv())
166-
if response["msg_type"] == "error":
167-
result.error = Error(
168-
name=response["content"]["ename"],
169-
value=response["content"]["evalue"],
170-
traceback=response["content"]["traceback"],
171-
)
172-
173-
elif response["msg_type"] == "stream":
174-
if response["content"]["name"] == "stdout":
175-
result.stdout.append(response["content"]["text"])
176-
if on_stdout:
177-
on_stdout(
178-
ProcessMessage(
179-
line=response["content"]["text"],
180-
timestamp=time.time_ns(),
181-
)
182-
)
183-
184-
elif response["content"]["name"] == "stderr":
185-
result.stderr.append(response["content"]["text"])
186-
if on_stderr:
187-
on_stderr(
188-
ProcessMessage(
189-
line=response["content"]["text"],
190-
error=True,
191-
timestamp=time.time_ns(),
192-
)
193-
)
194-
195-
elif response["msg_type"] == "display_data":
196-
result.display_data.append(response["content"]["data"])
197-
198-
elif response["msg_type"] == "execute_result":
199-
result.output = response["content"]["data"]["text/plain"]
200-
201-
elif response["msg_type"] == "status":
202-
if response["content"]["execution_state"] == "idle":
203-
if input_accepted:
204-
break
205-
elif response["content"]["execution_state"] == "error":
206-
result.error = Error(
207-
name=response["content"]["ename"],
208-
value=response["content"]["evalue"],
209-
traceback=response["content"]["traceback"],
210-
)
211-
break
212-
213-
elif response["msg_type"] == "execute_reply":
214-
if response["content"]["status"] == "error":
215-
result.error = Error(
216-
name=response["content"]["ename"],
217-
value=response["content"]["evalue"],
218-
traceback=response["content"]["traceback"],
219-
)
220-
elif response["content"]["status"] == "ok":
221-
pass
222-
223-
elif response["msg_type"] == "execute_input":
224-
input_accepted = True
225-
else:
226-
print("[UNHANDLED MESSAGE TYPE]:", response["msg_type"])
227-
return result
145+
def _start_connectiong_to_default_kernel(self, timeout: Optional[float] = TIMEOUT) -> None:
146+
def setup_default_kernel():
147+
kernel_id = self._sandbox.filesystem.read("/root/.jupyter/kernel_id", timeout=timeout).strip()
148+
self._connect_to_kernel_ws(kernel_id)
149+
self._kernel_id_set.set_result(kernel_id)
150+
151+
threading.Thread(target=setup_default_kernel).start()

0 commit comments

Comments
 (0)