|
16 | 16 |
|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
| 19 | +import json |
19 | 20 | import multiprocessing |
| 21 | +import os |
20 | 22 | import signal |
| 23 | +import sys |
21 | 24 | import time |
22 | | -from multiprocessing.context import ForkProcess |
| 25 | +from multiprocessing.context import SpawnProcess |
23 | 26 | from typing import AsyncIterator, Awaitable, Callable, cast |
24 | 27 |
|
25 | 28 | import anyio |
|
33 | 36 | from .._transcript.types import TranscriptInfo |
34 | 37 | from . import _mp_common |
35 | 38 | from ._mp_common import ( |
| 39 | + DillCallable, |
36 | 40 | IPCContext, |
37 | 41 | LoggingItem, |
38 | 42 | MetricsItem, |
@@ -140,9 +144,9 @@ async def the_func( |
140 | 144 | remainder_tasks = task_count % max_processes |
141 | 145 | # Initialize shared IPC context; with the "spawn" start method it is not inherited, so its callables are dill-wrapped and transferred to the spawned workers |
142 | 146 | _mp_common.ipc_context = IPCContext( |
143 | | - parse_function=parse_function, |
144 | | - scan_function=scan_function, |
145 | | - scan_completed=scan_completed, |
| 147 | + parse_function=DillCallable(parse_function), # type: ignore[arg-type] |
| 148 | + scan_function=DillCallable(scan_function), # type: ignore[arg-type] |
| 149 | + scan_completed=DillCallable(scan_completed), # type: ignore[arg-type] |
146 | 150 | prefetch_multiple=prefetch_multiple, |
147 | 151 | diagnostics=diagnostics, |
148 | 152 | overall_start_time=time.time(), |
@@ -267,9 +271,13 @@ async def _upstream_collector() -> None: |
267 | 271 |
|
268 | 272 | print_diagnostics("MP Collector", "Finished collecting all items") |
269 | 273 |
|
| 274 | + # Set environment variables for subprocess to configure sys.path before unpickling |
| 275 | + os.environ["INSPECT_SCOUT_SYS_PATH"] = json.dumps(sys.path) |
| 276 | + os.environ["INSPECT_SCOUT_WORKING_DIR"] = os.getcwd() |
| 277 | + |
270 | 278 | # Start worker processes directly |
271 | | - ctx = multiprocessing.get_context("fork") |
272 | | - processes: list[ForkProcess] = [] |
| 279 | + ctx = multiprocessing.get_context("spawn") |
| 280 | + processes: list[SpawnProcess] = [] |
273 | 281 | for worker_id in range(max_processes): |
274 | 282 | task_count_for_worker = base_tasks + ( |
275 | 283 | 1 if worker_id < remainder_tasks else 0 |
|
0 commit comments