1
- import json
2
1
import threading
3
- import time
4
- import uuid
5
2
from concurrent .futures import Future
6
- from typing import Any , Callable , List , Optional
3
+ from typing import Any , Callable , List , Optional , Dict
7
4
8
5
import requests
9
6
from e2b import EnvVars , ProcessMessage , Sandbox
10
7
from e2b .constants import TIMEOUT
11
- from websocket import create_connection
12
8
13
- from e2b_code_interpreter .models import Error , KernelException , Result
9
+ from e2b_code_interpreter .messaging import JupyterKernelWebSocket
10
+ from e2b_code_interpreter .models import KernelException , Result
14
11
15
12
16
13
class CodeInterpreter (Sandbox ):
@@ -40,15 +37,18 @@ def __init__(
40
37
** kwargs ,
41
38
)
42
39
self .notebook = JupyterExtension (self )
40
+ # Close all the websocket connections when the interpreter is closed
41
+ self ._process_cleanup .append (self .notebook .close )
43
42
44
43
45
44
class JupyterExtension :
46
45
_default_kernel_id : Optional [str ] = None
46
+ _connected_kernels : Dict [str , JupyterKernelWebSocket ] = {}
47
47
48
48
def __init__ (self , sandbox : CodeInterpreter ):
49
49
self ._sandbox = sandbox
50
50
self ._kernel_id_set = Future ()
51
- self ._set_default_kernel_id ()
51
+ self ._start_connectiong_to_default_kernel ()
52
52
53
53
def exec_cell (
54
54
self ,
@@ -58,12 +58,18 @@ def exec_cell(
58
58
on_stderr : Optional [Callable [[ProcessMessage ], Any ]] = None ,
59
59
) -> Result :
60
60
kernel_id = kernel_id or self .default_kernel_id
61
- ws = self ._connect_kernel (kernel_id )
62
- ws .send (json .dumps (self ._send_execute_request (code )))
63
- result = self ._wait_for_result (ws , on_stdout , on_stderr )
61
+ ws = self ._connected_kernels .get (kernel_id )
64
62
65
- ws .close ()
63
+ if not ws :
64
+ ws = JupyterKernelWebSocket (
65
+ url = f"{ self ._sandbox .get_protocol ('ws' )} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } /channels" ,
66
+ )
67
+ self ._connected_kernels [kernel_id ] = ws
68
+ ws .connect ()
69
+
70
+ session_id = ws .send_execution_message (code , on_stdout , on_stderr )
66
71
72
+ result = ws .get_result (session_id )
67
73
return result
68
74
69
75
@property
@@ -73,31 +79,42 @@ def default_kernel_id(self) -> str:
73
79
74
80
return self ._default_kernel_id
75
81
76
- def create_kernel (self , timeout : Optional [float ] = TIMEOUT ) -> str :
82
+ def create_kernel (self , cwd : Optional [str ] = None ,timeout : Optional [float ] = TIMEOUT ) -> str :
83
+ data = {"cwd" : cwd } if cwd else None
77
84
response = requests .post (
78
85
f"{ self ._sandbox .get_protocol ()} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels" ,
86
+ json = data ,
79
87
timeout = timeout ,
80
88
)
81
89
if not response .ok :
82
90
raise KernelException (f"Failed to create kernel: { response .text } " )
83
- return response .json ()["id" ]
91
+
92
+ kernel_id = response .json ()["id" ]
93
+
94
+ threading .Thread (target = self ._connect_to_kernel_ws , args = kernel_id ).start ()
95
+ return kernel_id
84
96
85
97
def restart_kernel (
86
98
self , kernel_id : Optional [str ] = None , timeout : Optional [float ] = TIMEOUT
87
99
) -> None :
88
100
kernel_id = kernel_id or self .default_kernel_id
101
+
102
+ self ._connected_kernels [kernel_id ].close ()
89
103
response = requests .post (
90
104
f"{ self ._sandbox .get_protocol ()} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } /restart" ,
91
105
timeout = timeout ,
92
106
)
93
107
if not response .ok :
94
108
raise KernelException (f"Failed to restart kernel { kernel_id } " )
95
109
110
+ threading .Thread (target = self ._connect_to_kernel_ws , args = kernel_id ).start ()
111
+
96
112
def shutdown_kernel (
97
113
self , kernel_id : Optional [str ] = None , timeout : Optional [float ] = TIMEOUT
98
114
) -> None :
99
115
kernel_id = kernel_id or self .default_kernel_id
100
116
117
+ self ._connected_kernels [kernel_id ].close ()
101
118
response = requests .delete (
102
119
f"{ self ._sandbox .get_protocol ()} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } " ,
103
120
timeout = timeout ,
@@ -114,114 +131,21 @@ def list_kernels(self, timeout: Optional[float] = TIMEOUT) -> List[str]:
114
131
raise KernelException (f"Failed to list kernels: { response .text } " )
115
132
return [kernel ["id" ] for kernel in response .json ()]
116
133
117
- def _set_default_kernel_id (self , timeout : Optional [float ] = TIMEOUT ) -> None :
118
- def set_kernel_id ():
119
- self ._kernel_id_set .set_result (
120
- self ._sandbox .filesystem .read ("/root/.jupyter/kernel_id" , timeout = timeout ).strip ()
121
- )
122
-
123
- threading .Thread (target = set_kernel_id ).start ()
134
+ def close (self ):
135
+ for ws in self ._connected_kernels .values ():
136
+ ws .close ()
124
137
125
- def _connect_kernel (self , kernel_id : str , timeout : Optional [float ] = TIMEOUT ):
126
- return create_connection (
127
- f"{ self ._sandbox .get_protocol ('ws' )} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } /channels" ,
128
- timeout = timeout ,
138
+ def _connect_to_kernel_ws (self , kernel_id : str ) -> None :
139
+ ws = JupyterKernelWebSocket (
140
+ url = f"{ self ._sandbox .get_protocol ('ws' )} ://{ self ._sandbox .get_hostname (8888 )} /api/kernels/{ kernel_id } /channels" ,
129
141
)
142
+ ws .connect ()
143
+ self ._connected_kernels [kernel_id ] = ws
130
144
131
- @staticmethod
132
- def _send_execute_request (code : str ) -> dict :
133
- msg_id = str (uuid .uuid4 ())
134
- session = str (uuid .uuid4 ())
135
-
136
- return {
137
- "header" : {
138
- "msg_id" : msg_id ,
139
- "username" : "e2b" ,
140
- "session" : session ,
141
- "msg_type" : "execute_request" ,
142
- "version" : "5.3" ,
143
- },
144
- "parent_header" : {},
145
- "metadata" : {},
146
- "content" : {
147
- "code" : code ,
148
- "silent" : False ,
149
- "store_history" : False ,
150
- "user_expressions" : {},
151
- "allow_stdin" : False ,
152
- },
153
- }
154
-
155
- @staticmethod
156
- def _wait_for_result (
157
- ws ,
158
- on_stdout : Optional [Callable [[ProcessMessage ], Any ]],
159
- on_stderr : Optional [Callable [[ProcessMessage ], Any ]],
160
- ) -> Result :
161
- result = Result ()
162
- input_accepted = False
163
-
164
- while True :
165
- response = json .loads (ws .recv ())
166
- if response ["msg_type" ] == "error" :
167
- result .error = Error (
168
- name = response ["content" ]["ename" ],
169
- value = response ["content" ]["evalue" ],
170
- traceback = response ["content" ]["traceback" ],
171
- )
172
-
173
- elif response ["msg_type" ] == "stream" :
174
- if response ["content" ]["name" ] == "stdout" :
175
- result .stdout .append (response ["content" ]["text" ])
176
- if on_stdout :
177
- on_stdout (
178
- ProcessMessage (
179
- line = response ["content" ]["text" ],
180
- timestamp = time .time_ns (),
181
- )
182
- )
183
-
184
- elif response ["content" ]["name" ] == "stderr" :
185
- result .stderr .append (response ["content" ]["text" ])
186
- if on_stderr :
187
- on_stderr (
188
- ProcessMessage (
189
- line = response ["content" ]["text" ],
190
- error = True ,
191
- timestamp = time .time_ns (),
192
- )
193
- )
194
-
195
- elif response ["msg_type" ] == "display_data" :
196
- result .display_data .append (response ["content" ]["data" ])
197
-
198
- elif response ["msg_type" ] == "execute_result" :
199
- result .output = response ["content" ]["data" ]["text/plain" ]
200
-
201
- elif response ["msg_type" ] == "status" :
202
- if response ["content" ]["execution_state" ] == "idle" :
203
- if input_accepted :
204
- break
205
- elif response ["content" ]["execution_state" ] == "error" :
206
- result .error = Error (
207
- name = response ["content" ]["ename" ],
208
- value = response ["content" ]["evalue" ],
209
- traceback = response ["content" ]["traceback" ],
210
- )
211
- break
212
-
213
- elif response ["msg_type" ] == "execute_reply" :
214
- if response ["content" ]["status" ] == "error" :
215
- result .error = Error (
216
- name = response ["content" ]["ename" ],
217
- value = response ["content" ]["evalue" ],
218
- traceback = response ["content" ]["traceback" ],
219
- )
220
- elif response ["content" ]["status" ] == "ok" :
221
- pass
222
-
223
- elif response ["msg_type" ] == "execute_input" :
224
- input_accepted = True
225
- else :
226
- print ("[UNHANDLED MESSAGE TYPE]:" , response ["msg_type" ])
227
- return result
145
+ def _start_connectiong_to_default_kernel (self , timeout : Optional [float ] = TIMEOUT ) -> None :
146
+ def setup_default_kernel ():
147
+ kernel_id = self ._sandbox .filesystem .read ("/root/.jupyter/kernel_id" , timeout = timeout ).strip ()
148
+ self ._connect_to_kernel_ws (kernel_id )
149
+ self ._kernel_id_set .set_result (kernel_id )
150
+
151
+ threading .Thread (target = setup_default_kernel ).start ()
0 commit comments