Skip to content

Commit f485449

Browse files
jayasifacebook-github-bot
authored andcommitted
Change the setup lambda to take no params (#641)
Summary: Pull Request resolved: #641 To maintain consistency, not exposing MonarchContext to the public API. Users can call current_rank().rank to get the current rank in their setup method. Reviewed By: suo Differential Revision: D78929218 fbshipit-source-id: df2c5c19a8efb9b7d6e69b902a01edda57c0ae7a
1 parent e579c25 commit f485449

File tree

3 files changed

+30
-27
lines changed

3 files changed

+30
-27
lines changed

python/monarch/_src/actor/proc_mesh.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
Actor,
4444
ActorMeshRef,
4545
fake_sync_state,
46-
MonarchContext,
4746
)
4847

4948
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator, SimAllocator
@@ -89,7 +88,7 @@ class SetupActor(Actor):
8988
Typically used to setup the environment variables.
9089
"""
9190

92-
def __init__(self, env: Callable[[MonarchContext], None]) -> None:
91+
def __init__(self, env: Callable[[], None]) -> None:
9392
"""
9493
Initialize the setup actor with the user defined setup method.
9594
"""
@@ -100,8 +99,7 @@ async def setup(self) -> None:
10099
"""
101100
Call the user defined setup method with the monarch context.
102101
"""
103-
ctx = MonarchContext.get()
104-
self._setup_method(ctx)
102+
self._setup_method()
105103

106104

107105
T = TypeVar("T")
@@ -114,7 +112,7 @@ async def setup(self) -> None:
114112

115113

116114
async def _allocate_nonblocking(
117-
alloc: Alloc, setup: Callable[[MonarchContext], None] | None = None
115+
alloc: Alloc, setup: Callable[[], None] | None = None
118116
) -> "ProcMesh":
119117
_proc_mesh = await HyProcMesh.allocate_nonblocking(alloc)
120118
if setup is None:
@@ -211,15 +209,25 @@ async def monitor_loop(monitor):
211209

212210
@classmethod
213211
def from_alloc(
214-
self, alloc: Alloc, setup: Callable[[MonarchContext], None] | None = None
212+
self, alloc: Alloc, setup: Callable[[], None] | None = None
215213
) -> Future["ProcMesh"]:
216214
"""
217215
Allocate a process mesh according to the provided alloc.
218216
Returns when the mesh is fully allocated.
219217
220218
Arguments:
221219
- `alloc`: The alloc to allocate according to.
222-
- `setup`: A lambda taking MonarchContext as param, can be used to setup env vars on the allocated mesh
220+
- `setup`: An optional lambda function to configure environment variables on the allocated mesh.
221+
Use the `current_rank()` method within the lambda to obtain the rank.
222+
223+
Example of a setup method to initialize torch distributed environment variables:
224+
```
225+
def setup():
226+
rank = current_rank()
227+
os.environ["RANK"] = str(rank)
228+
os.environ["WORLD_SIZE"] = str(len(rank.shape))
229+
os.environ["LOCAL_RANK"] = str(rank["gpus"])
230+
```
223231
"""
224232
return Future(
225233
impl=lambda: _allocate_nonblocking(alloc, setup),
@@ -428,7 +436,7 @@ async def proc_mesh_nonblocking(
428436
gpus: Optional[int] = None,
429437
hosts: int = 1,
430438
env: dict[str, str] | None = None,
431-
setup: Callable[[MonarchContext], None] | None = None,
439+
setup: Callable[[], None] | None = None,
432440
) -> ProcMesh:
433441
if gpus is None:
434442
gpus = _local_device_count()
@@ -457,7 +465,7 @@ def proc_mesh(
457465
gpus: Optional[int] = None,
458466
hosts: int = 1,
459467
env: dict[str, str] | None = None,
460-
setup: Callable[[MonarchContext], None] | None = None,
468+
setup: Callable[[], None] | None = None,
461469
) -> Future[ProcMesh]:
462470
return Future(
463471
impl=lambda: proc_mesh_nonblocking(

python/tests/test_allocator.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
ChannelTransport,
3434
)
3535

36-
from monarch._src.actor.actor_mesh import MonarchContext
3736
from monarch._src.actor.allocator import (
3837
ALLOC_LABEL_PROC_MESH_NAME,
3938
LocalAllocator,
@@ -160,7 +159,7 @@ async def test_setup_lambda_with_multiple_env_vars(self) -> None:
160159
"TEST_ENV_VAR_3": "value_3",
161160
}
162161

163-
def setup_multiple_env_vars(ctx: MonarchContext) -> None:
162+
def setup_multiple_env_vars() -> None:
164163
for name, value in env_vars.items():
165164
os.environ[name] = value
166165

@@ -184,36 +183,33 @@ def setup_multiple_env_vars(ctx: MonarchContext) -> None:
184183
await proc_mesh.stop()
185184

186185
async def test_setup_lambda_with_context_info(self) -> None:
187-
"""Test that the setup lambda can access context information"""
188-
context_var_name: str = "PROC_MESH_CONTEXT_INFO"
186+
"""Test that the setup lambda can access rank information"""
187+
context_var_name: str = "PROC_MESH_RANK_INFO"
189188

190-
def setup_with_context(ctx: MonarchContext) -> None:
191-
context_info = f"proc_id:{ctx.proc_id},point_rank:{ctx.point.rank}"
189+
def setup_with_rank() -> None:
190+
context_info = f"point_rank:{current_rank().rank}"
192191
os.environ[context_var_name] = context_info
193192

194193
spec = AllocSpec(AllocConstraints(), gpus=1, hosts=1)
195194
allocator = LocalAllocator()
196195
alloc = await allocator.allocate(spec)
197196

198-
proc_mesh = await ProcMesh.from_alloc(alloc, setup=setup_with_context)
197+
proc_mesh = await ProcMesh.from_alloc(alloc, setup=setup_with_rank)
199198

200199
try:
201200
actor = await proc_mesh.spawn("env_check", EnvCheckActor)
202201

203-
context_info = await actor.get_env_var.call_one(context_var_name)
202+
rank_info = await actor.get_env_var.call_one(context_var_name)
204203

205204
self.assertNotEqual(
206-
context_info,
205+
rank_info,
207206
"NOT_SET",
208207
"Context information was not stored in the environment variable",
209208
)
210-
self.assertIn(
211-
"proc_id:", context_info, "Context information does not contain proc_id"
212-
)
213209
self.assertIn(
214210
"point_rank:0",
215-
context_info,
216-
f"Context information {context_info} does not contain point_rank",
211+
rank_info,
212+
f"Context information {rank_info} does not contain point_rank",
217213
)
218214
finally:
219215
await proc_mesh.stop()
@@ -435,7 +431,7 @@ async def test_setup_lambda_sets_env_vars(self) -> None:
435431
test_var_name: str = "TEST_ENV_VAR_FOR_PROC_MESH"
436432
test_var_value: str = "test_value_123"
437433

438-
def setup_env_vars(ctx: MonarchContext) -> None:
434+
def setup_env_vars() -> None:
439435
os.environ[test_var_name] = test_var_value
440436

441437
hosts = 2

python/tests/test_env_before_cuda.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import torch
1717
from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
18-
from monarch._src.actor.actor_mesh import MonarchContext
1918
from monarch._src.actor.allocator import LocalAllocator
2019
from monarch._src.actor.proc_mesh import proc_mesh
2120
from monarch.actor import Actor, endpoint, ProcMesh
@@ -70,7 +69,7 @@ async def test_lambda_sets_env_vars_before_cuda_init(self) -> None:
7069
"CUDA_LAUNCH_BLOCKING": "1",
7170
}
7271

73-
def setup_cuda_env(_: MonarchContext) -> None:
72+
def setup_cuda_env() -> None:
7473
for name, value in cuda_env_vars.items():
7574
os.environ[name] = value
7675

@@ -107,7 +106,7 @@ async def test_proc_mesh_with_lambda_env(self) -> None:
107106
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
108107
}
109108

110-
def setup_cuda_env(_: MonarchContext) -> None:
109+
def setup_cuda_env() -> None:
111110
for name, value in cuda_env_vars.items():
112111
os.environ[name] = value
113112

0 commit comments

Comments
 (0)