@@ -13,6 +13,7 @@
 import threading
 import warnings
 from contextlib import AbstractContextManager
+from pathlib import Path

 from typing import (
     Any,
@@ -69,6 +70,7 @@
 from monarch._src.actor.endpoint import endpoint
 from monarch._src.actor.future import DeprecatedNotAFuture, Future
 from monarch._src.actor.shape import MeshTrait
+from monarch.tools.config import Config

 HAS_TENSOR_ENGINE = False
 try:
@@ -356,7 +358,11 @@ def rank_tensor(self, dim: str | Sequence[str]) -> "Tensor":
     def rank_tensors(self) -> Dict[str, "Tensor"]:
         return self._device_mesh.ranks

-    async def sync_workspace(self, conda: bool = False, auto_reload: bool = False) -> None:
+    # TODO(kiuk): once we have the Workspace dataclass in config,
+    # we should pass in Workspace here instead of the full Config
+    async def sync_workspace(
+        self, config: Config, conda: bool = False, auto_reload: bool = False
+    ) -> None:
         if self._code_sync_client is None:
             self._code_sync_client = CodeSyncMeshClient.spawn_blocking(
                 proc_mesh=await self._proc_mesh_for_asyncio_fixme,
@@ -368,27 +374,27 @@ async def sync_workspace(self, conda: bool = False, auto_reload: bool = False) -
         # The workspace shape (i.e. only perform one rsync per host).
         assert set(self._shape.labels).issubset({"gpus", "hosts"})

-        # TODO(agallagher): Is there a better way to infer/set the local
-        # workspace dir, rather than use PWD?
-        workspaces = [
-            WorkspaceConfig(
-                local=os.getcwd(),
-                remote=RemoteWorkspace(
-                    location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
-                    shape=WorkspaceShape.shared("gpus"),
+        workspaces = []
+        if config.workspace is not None:
+            workspaces.append(
+                WorkspaceConfig(
+                    local=Path(config.workspace),
+                    remote=RemoteWorkspace(
+                        location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
+                        shape=WorkspaceShape.shared("gpus"),
+                    ),
+                    method=CodeSyncMethod.Rsync,
                 ),
-                method=CodeSyncMethod.Rsync,
-            ),
-        ]
+            )

         # If `conda` is set, also sync the currently activated conda env.
         conda_prefix = os.environ.get("CONDA_PREFIX")
         if conda and conda_prefix is not None:
             workspaces.append(
                 WorkspaceConfig(
-                    local=conda_prefix,
+                    local=Path(conda_prefix),
                     remote=RemoteWorkspace(
-                        location=WorkspaceLocation.FromEnvVar("CONDA_PREFIX"),
+                        location=WorkspaceLocation.FromEnvVar("CONDA_DIR"),
                         shape=WorkspaceShape.shared("gpus"),
                     ),
                     method=CodeSyncMethod.CondaSync,
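
With this change, sync_workspace no longer infers the local workspace from PWD; the caller passes a Config whose workspace field names the local directory to rsync (the workspace entry is skipped when it is None). A minimal call sketch, assuming a spawned proc mesh bound to a variable named procs and that Config accepts workspace as a constructor argument (both are assumptions for illustration, not shown in this diff):

    from monarch.tools.config import Config

    # Assumed kwarg: the local dir synced to the remote $WORKSPACE_DIR location.
    config = Config(workspace="/path/to/local/project")
    # conda=True also syncs the currently activated conda env (per the diff above);
    # auto_reload is passed through unchanged by this diff.
    await procs.sync_workspace(config, conda=True, auto_reload=True)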