Skip to content

Commit 974b1b0

Browse files
committed
Change from on display data to on result
1 parent bc0bec9 commit 974b1b0

File tree

9 files changed

+146
-49
lines changed

9 files changed

+146
-49
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ print("world")
149149
"""
150150

151151
with CodeInterpreter() as sandbox:
152-
sandbox.notebook.exec_cell(code, on_stdout=print, on_stderr=print, on_display_data=(lambda data: print(data.text)))
152+
sandbox.notebook.exec_cell(code, on_stdout=print, on_stderr=print, on_result=(lambda result: print(result.text)))
153153

154154
```
155155

@@ -175,7 +175,7 @@ const sandbox = await CodeInterpreter.create()
175175
await sandbox.notebook.execCell(code, {
176176
onStdout: (out) => console.log(out),
177177
onStderr: (outErr) => console.error(outErr),
178-
onDisplayData: (outData) => console.log(outData.text)
178+
onResult: (result) => console.log(result.text)
179179
})
180180

181181
await sandbox.close()

js/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ const sandbox = await CodeInterpreter.create()
8080
await sandbox.notebook.execCell(code, {
8181
onStdout: (out) => console.log(out),
8282
onStderr: (outErr) => console.error(outErr),
83-
onDisplayData: (outData) => console.log(outData.text)
83+
onResult: (result) => console.log(result.text)
8484
})
8585

8686
await sandbox.close()

js/src/code-interpreter.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ export class JupyterExtension {
6363
* @param kernelID The ID of the kernel to execute the code on. If not provided, the default kernel is used.
6464
* @param onStdout A callback function to handle standard output messages from the code execution.
6565
* @param onStderr A callback function to handle standard error messages from the code execution.
66-
* @param onDisplayData A callback function to handle display data messages from the code execution.
66+
* @param onResult A callback function to handle display data messages from the code execution.
6767
* @param timeout The maximum time to wait for the code execution to complete, in milliseconds.
6868
* @returns A promise that resolves with the result of the code execution.
6969
*/
@@ -73,13 +73,13 @@ export class JupyterExtension {
7373
kernelID,
7474
onStdout,
7575
onStderr,
76-
onDisplayData,
76+
onResult,
7777
timeout
7878
}: {
7979
kernelID?: string
80-
onStdout?: (msg: ProcessMessage) => Promise<void> | void
81-
onStderr?: (msg: ProcessMessage) => Promise<void> | void
82-
onDisplayData?: (data: Result) => Promise<void> | void
80+
onStdout?: (msg: ProcessMessage) => any
81+
onStderr?: (msg: ProcessMessage) => any
82+
onResult?: (data: Result) => any
8383
timeout?: number
8484
} = {}
8585
): Promise<Execution> {
@@ -92,7 +92,7 @@ export class JupyterExtension {
9292
code,
9393
onStdout,
9494
onStderr,
95-
onDisplayData,
95+
onResult,
9696
timeout
9797
)
9898
}

js/src/messaging.ts

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -186,20 +186,20 @@ export class Execution {
186186
*/
187187
class CellExecution {
188188
execution: Execution
189-
onStdout?: (out: ProcessMessage) => Promise<void> | void
190-
onStderr?: (out: ProcessMessage) => Promise<void> | void
191-
onDisplayData?: (data: Result) => Promise<void> | void
189+
onStdout?: (out: ProcessMessage) => any
190+
onStderr?: (out: ProcessMessage) => any
191+
onResult?: (data: Result) => any
192192
inputAccepted: boolean = false
193193

194194
constructor(
195-
onStdout?: (out: ProcessMessage) => Promise<void> | void,
196-
onStderr?: (out: ProcessMessage) => Promise<void> | void,
197-
onDisplayData?: (data: Result) => Promise<void> | void
195+
onStdout?: (out: ProcessMessage) => any,
196+
onStderr?: (out: ProcessMessage) => any,
197+
onResult?: (data: Result) => any
198198
) {
199199
this.execution = new Execution([], { stdout: [], stderr: [] })
200200
this.onStdout = onStdout
201201
this.onStderr = onStderr
202-
this.onDisplayData = onDisplayData
202+
this.onResult = onResult
203203
}
204204
}
205205

@@ -295,11 +295,15 @@ export class JupyterKernelWebSocket {
295295
} else if (message.msg_type == 'display_data') {
296296
const result = new Result(message.content.data, false)
297297
execution.results.push(result)
298-
if (cell.onDisplayData) {
299-
cell.onDisplayData(result)
298+
if (cell.onResult) {
299+
cell.onResult(result)
300300
}
301301
} else if (message.msg_type == 'execute_result') {
302-
execution.results.push(new Result(message.content.data, true))
302+
const result = new Result(message.content.data, true)
303+
execution.results.push(result)
304+
if (cell.onResult) {
305+
cell.onResult(result)
306+
}
303307
} else if (message.msg_type == 'status') {
304308
if (message.content.execution_state == 'idle') {
305309
if (cell.inputAccepted) {
@@ -337,15 +341,15 @@ export class JupyterKernelWebSocket {
337341
* @param code Code to be executed.
338342
* @param onStdout Callback for stdout messages.
339343
* @param onStderr Callback for stderr messages.
340-
* @param onDisplayData Callback for display data messages.
344+
* @param onResult Callback function to handle the result and display calls of the code execution.
341345
* @param timeout Time in milliseconds to wait for response.
342346
* @returns Promise with execution result.
343347
*/
344348
public sendExecutionMessage(
345349
code: string,
346-
onStdout?: (out: ProcessMessage) => Promise<void> | void,
347-
onStderr?: (out: ProcessMessage) => Promise<void> | void,
348-
onDisplayData?: (data: Result) => Promise<void> | void,
350+
onStdout?: (out: ProcessMessage) => any,
351+
onStderr?: (out: ProcessMessage) => any,
352+
onResult?: (data: Result) => any,
349353
timeout?: number
350354
) {
351355
return new Promise<Execution>((resolve, reject) => {
@@ -367,7 +371,7 @@ export class JupyterKernelWebSocket {
367371
}
368372

369373
// expect response
370-
this.cells[msg_id] = new CellExecution(onStdout, onStderr, onDisplayData)
374+
this.cells[msg_id] = new CellExecution(onStdout, onStderr, onResult)
371375
this.idAwaiter[msg_id] = (responseData: Execution) => {
372376
// stop timeout
373377
clearInterval(timeoutSet as number)

js/tests/streaming.test.ts

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import { ProcessMessage } from 'e2b'
2+
import { CodeInterpreter, Result } from '../src'
3+
4+
import { expect, test } from 'vitest'
5+
6+
test('streaming output', async () => {
7+
const out: ProcessMessage[] = []
8+
const sandbox = await CodeInterpreter.create()
9+
await sandbox.notebook.execCell('print(1)', {
10+
onStdout: (msg) => out.push(msg)
11+
})
12+
13+
expect(out.length).toEqual(1)
14+
expect(out[0].line).toEqual('1\n')
15+
await sandbox.close()
16+
})
17+
18+
test('streaming error', async () => {
19+
const out: ProcessMessage[] = []
20+
const sandbox = await CodeInterpreter.create()
21+
await sandbox.notebook.execCell('import sys;print(1, file=sys.stderr)', {
22+
onStderr: (msg) => out.push(msg)
23+
})
24+
25+
expect(out.length).toEqual(1)
26+
expect(out[0].line).toEqual('1\n')
27+
await sandbox.close()
28+
})
29+
30+
test('streaming result', async () => {
31+
const out: Result[] = []
32+
const sandbox = await CodeInterpreter.create()
33+
const code = `
34+
import matplotlib.pyplot as plt
35+
import numpy as np
36+
37+
x = np.linspace(0, 20, 100)
38+
y = np.sin(x)
39+
40+
plt.plot(x, y)
41+
plt.show()
42+
43+
x
44+
`
45+
await sandbox.notebook.execCell(code, {
46+
onResult: (result) => out.push(result)
47+
})
48+
49+
expect(out.length).toEqual(2)
50+
await sandbox.close()
51+
})

python/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ print("world")
8686
"""
8787

8888
with CodeInterpreter() as sandbox:
89-
sandbox.notebook.exec_cell(code, on_stdout=print, on_stderr=print, on_display_data=(lambda data: print(data.text)))
89+
sandbox.notebook.exec_cell(code, on_stdout=print, on_stderr=print, on_result=(lambda result: print(result.text)))
9090
```
9191

9292
### Pre-installed Python packages inside the sandbox

python/e2b_code_interpreter/main.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from e2b.constants import TIMEOUT
1111

1212
from e2b_code_interpreter.messaging import JupyterKernelWebSocket
13-
from e2b_code_interpreter.models import KernelException, Execution
14-
13+
from e2b_code_interpreter.models import KernelException, Execution, Result
1514

1615
logger = logging.getLogger(__name__)
1716

@@ -66,7 +65,7 @@ def exec_cell(
6665
kernel_id: Optional[str] = None,
6766
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None,
6867
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None,
69-
on_display_data: Optional[Callable[[Dict[str, Any]], Any]] = None,
68+
on_result: Optional[Callable[[Result], Any]] = None,
7069
timeout: Optional[float] = TIMEOUT,
7170
) -> Execution:
7271
"""
@@ -76,7 +75,7 @@ def exec_cell(
7675
:param kernel_id: The ID of the kernel to execute the code on. If not provided, the default kernel is used.
7776
:param on_stdout: A callback function to handle standard output messages from the code execution.
7877
:param on_stderr: A callback function to handle standard error messages from the code execution.
79-
:param on_display_data: A callback function to handle display data messages from the code execution.
78+
:param on_result: A callback function to handle the result and display calls of the code execution.
8079
:param timeout: Timeout for the call
8180
8281
:return: Result of the execution
@@ -93,9 +92,7 @@ def exec_cell(
9392
logger.debug(f"Creating new websocket connection to kernel {kernel_id}")
9493
ws = self._connect_to_kernel_ws(kernel_id, timeout=timeout)
9594

96-
session_id = ws.send_execution_message(
97-
code, on_stdout, on_stderr, on_display_data
98-
)
95+
session_id = ws.send_execution_message(code, on_stdout, on_stderr, on_result)
9996
logger.debug(
10097
f"Sent execution message to kernel {kernel_id}, session_id: {session_id}"
10198
)

python/e2b_code_interpreter/messaging.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from e2b.utils.future import DeferredFuture
1515
from pydantic import ConfigDict, PrivateAttr, BaseModel
1616

17-
from e2b_code_interpreter.models import Execution, Result, Error, MIMEType
17+
from e2b_code_interpreter.models import Execution, Result, Error
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -26,21 +26,21 @@ class CellExecution:
2626
"""
2727

2828
input_accepted: bool = False
29-
on_stdout: Optional[Callable[[ProcessMessage], None]] = None
30-
on_stderr: Optional[Callable[[ProcessMessage], None]] = None
31-
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None
29+
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None
30+
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None
31+
on_result: Optional[Callable[[Result], Any]] = None
3232

3333
def __init__(
3434
self,
35-
on_stdout: Optional[Callable[[ProcessMessage], None]] = None,
36-
on_stderr: Optional[Callable[[ProcessMessage], None]] = None,
37-
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None,
35+
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None,
36+
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None,
37+
on_result: Optional[Callable[[Result], Any]] = None,
3838
):
3939
self.partial_result = Execution()
4040
self.execution = Future()
4141
self.on_stdout = on_stdout
4242
self.on_stderr = on_stderr
43-
self.on_display_data = on_display_data
43+
self.on_result = on_result
4444

4545

4646
class JupyterKernelWebSocket(BaseModel):
@@ -129,17 +129,17 @@ def _get_execute_request(msg_id: str, code: str) -> str:
129129
def send_execution_message(
130130
self,
131131
code: str,
132-
on_stdout: Optional[Callable[[ProcessMessage], None]] = None,
133-
on_stderr: Optional[Callable[[ProcessMessage], None]] = None,
134-
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None,
132+
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None,
133+
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None,
134+
on_result: Optional[Callable[[Result], Any]] = None,
135135
) -> str:
136136
message_id = str(uuid.uuid4())
137137
logger.debug(f"Sending execution message: {message_id}")
138138

139139
self._cells[message_id] = CellExecution(
140140
on_stdout=on_stdout,
141141
on_stderr=on_stderr,
142-
on_display_data=on_display_data,
142+
on_result=on_result,
143143
)
144144
request = self._get_execute_request(message_id, code)
145145
self._queue_in.put(request)
@@ -204,12 +204,13 @@ def _receive_message(self, data: dict):
204204
elif data["msg_type"] in "display_data":
205205
result = Result(is_main_result=False, data=data["content"]["data"])
206206
execution.results.append(result)
207-
if cell.on_display_data:
208-
cell.on_display_data(result)
207+
if cell.on_result:
208+
cell.on_result(result)
209209
elif data["msg_type"] == "execute_result":
210-
execution.results.append(
211-
Result(is_main_result=True, data=data["content"]["data"])
212-
)
210+
result = Result(is_main_result=True, data=data["content"]["data"])
211+
execution.results.append(result)
212+
if cell.on_result:
213+
cell.on_result(result)
213214
elif data["msg_type"] == "status":
214215
if data["content"]["execution_state"] == "idle":
215216
if cell.input_accepted:

python/tests/test_streaming.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from e2b_code_interpreter.main import CodeInterpreter
2+
3+
4+
def test_streaming_output():
5+
out = []
6+
with CodeInterpreter() as sandbox:
7+
def test(line) -> int:
8+
out.append(line)
9+
return 1
10+
sandbox.notebook.exec_cell("print(1)", on_stdout=test)
11+
12+
assert len(out) == 1
13+
assert out[0].line == "1\n"
14+
15+
16+
def test_streaming_error():
17+
out = []
18+
with CodeInterpreter() as sandbox:
19+
sandbox.notebook.exec_cell("import sys;print(1, file=sys.stderr)", on_stderr=out.append)
20+
21+
assert len(out) == 1
22+
assert out[0].line == "1\n"
23+
24+
25+
def test_streaming_result():
26+
code = """
27+
import matplotlib.pyplot as plt
28+
import numpy as np
29+
30+
x = np.linspace(0, 20, 100)
31+
y = np.sin(x)
32+
33+
plt.plot(x, y)
34+
plt.show()
35+
36+
x
37+
"""
38+
39+
out = []
40+
with CodeInterpreter() as sandbox:
41+
sandbox.notebook.exec_cell(code, on_result=out.append)
42+
43+
assert len(out) == 2
44+

0 commit comments

Comments
 (0)