Skip to content

Commit de8b956

Browse files
highker authored and facebook-github-bot committed
link launch and sync conda/workspace locations
Summary: Make sure the conda/workspace locations during launch map with the locations when we sync. Differential Revision: D79516268
1 parent ce36dbd commit de8b956

File tree

3 files changed

+65
-14
lines changed

3 files changed

+65
-14
lines changed

python/monarch/_src/actor/proc_mesh.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import threading
1414
import warnings
1515
from contextlib import AbstractContextManager
16+
from pathlib import Path
1617

1718
from typing import (
1819
Any,
@@ -69,6 +70,7 @@
6970
from monarch._src.actor.endpoint import endpoint
7071
from monarch._src.actor.future import DeprecatedNotAFuture, Future
7172
from monarch._src.actor.shape import MeshTrait
73+
from monarch.tools.config import Config
7274

7375
HAS_TENSOR_ENGINE = False
7476
try:
@@ -356,7 +358,11 @@ def rank_tensor(self, dim: str | Sequence[str]) -> "Tensor":
356358
def rank_tensors(self) -> Dict[str, "Tensor"]:
357359
return self._device_mesh.ranks
358360

359-
async def sync_workspace(self, conda: bool = False, auto_reload: bool = False) -> None:
361+
# TODO(kiuk): once we have the Workspace dataclass in config,
362+
# we should pass in Workspace here instead of the full Config
363+
async def sync_workspace(
364+
self, config: Config, conda: bool = False, auto_reload: bool = False
365+
) -> None:
360366
if self._code_sync_client is None:
361367
self._code_sync_client = CodeSyncMeshClient.spawn_blocking(
362368
proc_mesh=await self._proc_mesh_for_asyncio_fixme,
@@ -368,27 +374,27 @@ async def sync_workspace(self, conda: bool = False, auto_reload: bool = False) -
368374
# The workspace shape (i.e. only perform one rsync per host).
369375
assert set(self._shape.labels).issubset({"gpus", "hosts"})
370376

371-
# TODO(agallagher): Is there a better way to infer/set the local
372-
# workspace dir, rather than use PWD?
373-
workspaces = [
374-
WorkspaceConfig(
375-
local=os.getcwd(),
376-
remote=RemoteWorkspace(
377-
location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
378-
shape=WorkspaceShape.shared("gpus"),
377+
workspaces = []
378+
if config.workspace is not None:
379+
workspaces.append(
380+
WorkspaceConfig(
381+
local=Path(config.workspace),
382+
remote=RemoteWorkspace(
383+
location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
384+
shape=WorkspaceShape.shared("gpus"),
385+
),
386+
method=CodeSyncMethod.Rsync,
379387
),
380-
method=CodeSyncMethod.Rsync,
381-
),
382-
]
388+
)
383389

384390
# If `conda` is set, also sync the currently activated conda env.
385391
conda_prefix = os.environ.get("CONDA_PREFIX")
386392
if conda and conda_prefix is not None:
387393
workspaces.append(
388394
WorkspaceConfig(
389-
local=conda_prefix,
395+
local=Path(conda_prefix),
390396
remote=RemoteWorkspace(
391-
location=WorkspaceLocation.FromEnvVar("CONDA_PREFIX"),
397+
location=WorkspaceLocation.FromEnvVar("CONDA_DIR"),
392398
shape=WorkspaceShape.shared("gpus"),
393399
),
394400
method=CodeSyncMethod.CondaSync,

python/monarch/tools/config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class Config:
3232

3333
scheduler: str = NOT_SET
3434
scheduler_args: dict[str, Any] = field(default_factory=dict)
35+
# TODO: make workspace a list
3536
workspace: Optional[str] = None
3637
dryrun: bool = False
3738
appdef: UnnamedAppDef = field(default_factory=UnnamedAppDef)

python/tests/test_python_actors.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
local_proc_mesh,
3737
proc_mesh,
3838
)
39+
from monarch.tools.config import defaults
3940
from typing_extensions import assert_type
4041

4142

@@ -769,6 +770,49 @@ async def test_same_actor_twice() -> None:
769770
await pm.stop()
770771

771772

773+
class LsActor(Actor):
774+
def __init__(self, workspace: str):
775+
self.workspace = workspace
776+
777+
@endpoint
778+
async def ls(self) -> list[str]:
779+
return os.listdir(self.workspace)
780+
781+
782+
async def test_sync_workspace() -> None:
783+
pm = await proc_mesh(gpus=1)
784+
785+
# create two workspaces: one for local and one for remote
786+
with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst:
787+
try:
788+
os.environ["WORKSPACE_DIR"] = workspace_dst
789+
config = defaults.config("slurm", workspace_src)
790+
await pm.sync_workspace(config, conda=False, auto_reload=True)
791+
792+
# no file in remote workspace initially
793+
am = await pm.spawn("ls", LsActor, workspace_dst)
794+
for item in list(am.ls.call().get()):
795+
assert len(item[1]) == 0
796+
797+
# write a file to local workspace
798+
file_path = os.path.join(workspace_src, "new_file")
799+
with open(file_path, "w") as f:
800+
f.write("hello world")
801+
f.flush()
802+
803+
# force a sync and it should populate on the dst workspace
804+
await pm.sync_workspace(config, conda=False, auto_reload=True)
805+
for item in list(am.ls.call().get()):
806+
assert len(item[1]) == 1
807+
assert item[1][0] == "new_file"
808+
file_path = os.path.join(workspace_dst, item[1][0])
809+
with open(file_path, "r") as f:
810+
assert f.readline() == "hello world"
811+
812+
finally:
813+
os.environ.pop("WORKSPACE_DIR", None)  # Remove if it didn't exist before
814+
815+
772816
class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
773817
async def test_actor_mesh_stop(self) -> None:
774818
pm = proc_mesh(gpus=2)

0 commit comments

Comments
 (0)