
Commit bd60904

gsethi523 authored and facebook-github-bot committed
mocked propagation (#748)
Summary:
Pull Request resolved: #748

Add support for enabling the `mocked_propagation` remote propagator path. Leverages `mock_cuda` functionality to perform the local tensor propagation computation without needing to use `fake tensor mode` for user function execution.

Differential Revision: D79424676
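For orientation, a rough usage sketch of the new path (not part of this commit; the decorator form and exact signature of `remote` are assumptions here, mirroring the internal `_remote(func, propagate="mocked")` call added in device_mesh.py):

    import torch
    from monarch.common.remote import remote

    # Hypothetical: ask monarch to derive output metadata for this remote
    # function by executing it once locally against mocked CUDA allocations,
    # instead of tracing it under fake tensor mode.
    @remote(propagate="mocked")
    def scale(x: torch.Tensor) -> torch.Tensor:
        return x * 2.0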
1 parent 65b30b1 commit bd60904

File tree: 6 files changed, +140 -3 lines changed

python/monarch/__init__.py
python/monarch/_src/actor/endpoint.py
python/monarch/_src/actor/tensor_engine_shim.py
python/monarch/common/device_mesh.py
python/monarch/common/mock_cuda.py
python/monarch/common/remote.py

python/monarch/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -50,6 +50,7 @@
 from monarch.common.future import Future

 from monarch.common.invocation import RemoteException
+from monarch.common.mock_cuda import mock_cuda, mock_cuda_guard, unmock_cuda
 from monarch.common.opaque_ref import OpaqueRef
 from monarch.common.pipe import create_pipe, Pipe, remote_generator
 from monarch.common.remote import remote
@@ -84,6 +85,9 @@
     "Shape": ("monarch._src.actor.shape", "Shape"),
     "NDSlice": ("monarch._src.actor.shape", "NDSlice"),
     "Selection": ("monarch.common.selection", "Selection"),
+    "mock_cuda": ("monarch.common.mock_cuda", "mock_cuda"),
+    "mock_cuda_guard": ("monarch.common.mock_cuda", "mock_cuda_guard"),
+    "unmock_cuda": ("monarch.common.mock_cuda", "unmock_cuda"),
     "OpaqueRef": ("monarch.common.opaque_ref", "OpaqueRef"),
     "create_pipe": ("monarch.common.pipe", "create_pipe"),
     "Pipe": ("monarch.common.pipe", "Pipe"),
@@ -153,6 +157,9 @@ def __getattr__(name):
     "RemoteException",
     "Shape",
     "Selection",
+    "mock_cuda",
+    "mock_cuda_guard",
+    "unmock_cuda",
     "NDSlice",
     "OpaqueRef",
     "create_pipe",

python/monarch/_src/actor/endpoint.py

Lines changed: 11 additions & 2 deletions

@@ -32,7 +32,11 @@
 )

 from monarch._src.actor.future import Future
-from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
+from monarch._src.actor.tensor_engine_shim import (
+    _cached_propagation,
+    _mocked_propagation,
+    fake_call,
+)

 if TYPE_CHECKING:
     from monarch._src.actor.actor_mesh import (
@@ -213,7 +217,12 @@ def _propagate(self, args, kwargs, fake_args, fake_kwargs):
         elif self._propagator_arg == "inspect":
             return None
         elif self._propagator_arg == "mocked":
-            raise NotImplementedError("mocked propagation")
+            resolvable = getattr(self, "_resolvable", None)
+            if resolvable is None:
+                raise NotImplementedError(
+                    "Mocked propagation is not implemented for actor endpoints."
+                )
+            return _mocked_propagation(self._resolvable, args, kwargs)
         else:
             return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)

python/monarch/_src/actor/tensor_engine_shim.py

Lines changed: 4 additions & 0 deletions

@@ -55,5 +55,9 @@ def actor_rref(endpoint, args_kwargs_tuple: bytes, refs: Sequence[Any]): ...
 def _cached_propagation(_cache, rfunction, args, kwargs) -> Any: ...


+@shim(module="monarch.common.remote")
+def _mocked_propagation(rfunction, args, kwargs) -> Any: ...
+
+
 @shim(module="monarch.common.fake")
 def fake_call(fn, *args, **kwargs): ...
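For readers unfamiliar with this file, a simplified illustration of the lazy-forwarding idea behind a `@shim(...)` declaration (an assumption-laden re-implementation for exposition; the real decorator in this module may differ):

    import importlib
    from functools import wraps

    def shim(*, module: str):
        def decorator(stub):
            @wraps(stub)
            def forward(*args, **kwargs):
                # resolve the real implementation only at call time, so the
                # actor package does not import the tensor engine eagerly
                real = getattr(importlib.import_module(module), stub.__name__)
                return real(*args, **kwargs)
            return forward
        return decorator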

python/monarch/common/device_mesh.py

Lines changed: 20 additions & 0 deletions

@@ -29,6 +29,7 @@
 import monarch.common.messages as messages
 import torch
 from monarch._src.actor.shape import MeshTrait, NDSlice, Shape
+from monarch.common.tree import flatten

 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
@@ -321,6 +322,7 @@ def get_info(self) -> DeviceMeshInfo:

 _active: Optional[DeviceMesh] = None
 _dispatch_enabled = False
+_mock_dispatch = False


 def get_active_mesh():
@@ -341,6 +343,14 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
             return func(*args, **kwargs)
         if fnstr in self.allowed_local_accessors and not isinstance(args[0], Tensor):
             return func(*args, **kwargs)
+        global _mock_dispatch
+        input_tensors, _ = flatten(
+            (args, kwargs), lambda x: isinstance(x, torch.Tensor)
+        )
+        if _mock_dispatch and (
+            len(input_tensors) == 0 or all([[t.is_cuda for t in input_tensors]])
+        ):
+            return _remote(func, propagate="mocked")(*args, **kwargs)
         return _remote(func, propagate=func)(*args, **kwargs)
@@ -362,6 +372,16 @@ def _dispatch():
        _dispatch_enabled = False


+def enable_mocked_dispatch() -> None:
+    global _mock_dispatch
+    _mock_dispatch = True
+
+
+def disable_mocked_dispatch() -> None:
+    global _mock_dispatch
+    _mock_dispatch = False
+
+
 _on_change: List[Callable] = []
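A short usage sketch for the new toggle (only enable_mocked_dispatch/disable_mocked_dispatch come from this diff; the surrounding mesh setup is elided and assumed):

    from monarch.common.device_mesh import (
        disable_mocked_dispatch,
        enable_mocked_dispatch,
    )

    enable_mocked_dispatch()
    try:
        # While enabled, __torch_dispatch__ routes eligible ops (no tensor
        # inputs, or all-CUDA tensor inputs) through propagate="mocked".
        ...
    finally:
        disable_mocked_dispatch()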

python/monarch/common/mock_cuda.py

Lines changed: 22 additions & 0 deletions

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-strict
+import logging
 from contextlib import contextmanager
 from typing import Generator, Optional

@@ -15,11 +16,32 @@

 _mock_cuda_stream: Optional[torch.cuda.Stream] = None

+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def _mock_init_test() -> None:
+    global _mock_cuda_stream
+    base_mock_address = 1 << 48
+    with torch.cuda.stream(_mock_cuda_stream):
+        monarch.common._C.mock_cuda()
+        x = torch.rand(4, dtype=torch.float32, device="cuda")
+        monarch.common._C.unmock_cuda()
+    # x will result in a small pool (2MB) caching allocator allocation
+    segment_size = 2 * 1024 * 1024
+    # therefore we expect the address of x's allocation to be...
+    expected_address = base_mock_address - segment_size
+    assert (
+        x.untyped_storage().data_ptr() == expected_address
+    ), "monarch mock initialization failed. please import mock_cuda at the top of your imports"
+    logger.info("monarch mock initialization succeeded")
+

 def get_mock_cuda_stream() -> torch.cuda.Stream:
     global _mock_cuda_stream
     if _mock_cuda_stream is None:
         _mock_cuda_stream = torch.cuda.Stream()
+        _mock_init_test()
+    assert _mock_cuda_stream is not None
     return _mock_cuda_stream
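The expected-address check in `_mock_init_test` is plain arithmetic on the constants above; spelled out (a restatement of the diff's own values, with the hex result added for clarity):

    base_mock_address = 1 << 48        # 281_474_976_710_656 == 0x1_0000_0000_0000
    segment_size = 2 * 1024 * 1024     # 2 MiB small-pool segment the 4-float tensor lands in
    expected_address = base_mock_address - segment_size
    assert expected_address == 0xFFFF_FFE0_0000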

python/monarch/common/remote.py

Lines changed: 76 additions & 1 deletion

@@ -41,7 +41,7 @@

 from monarch._src.actor.endpoint import Endpoint
 from monarch.common.device_mesh import RemoteProcessGroup
-from monarch.common.fake import fake_call
+from monarch.common.fake import _fake_mode, fake_call

 from monarch.common.function import (
     Propagator,
@@ -57,9 +57,12 @@
 )
 from monarch.common.messages import Dims

+from monarch.common.mock_cuda import mock_cuda_guard
 from monarch.common.tensor import dtensor_check, dtensor_dispatch, InputChecker
 from monarch.common.tree import flatten, tree_map
 from torch import autograd, distributed as dist
+from torch.utils._python_dispatch import _disable_current_modes
+
 from typing_extensions import ParamSpec

 logger: Logger = logging.getLogger(__name__)
@@ -378,3 +381,75 @@ def _cached_propagation(_cache, rfunction: ResolvableFunction, args, kwargs):

     output_tensors = fake_call(output_pattern.empty, [inputs_group.tensors])
     return unflatten_result(output_tensors)
+
+
+def _mocked_propagation(rfunction: ResolvableFunction, args, kwargs):
+    # need to break out of device_mesh dispatch mode to run local mocked execution
+    with _disable_current_modes():
+        # need to manually enable autograd version tracking as we may be
+        # being called inside torch_dispatch
+        with torch._C._SetExcludeDispatchKeyGuard(
+            torch._C.DispatchKey.ADInplaceOrView, False
+        ):
+            fn = rfunction.resolve()
+
+            input_monarch_tensors, unflatten_monarch_input = flatten(
+                (args, kwargs), lambda x: isinstance(x, torch.Tensor)
+            )
+            if len(input_monarch_tensors) > 0:
+                assert all(
+                    [[t._fake.is_cuda for t in input_monarch_tensors]]
+                ), "all input tensors to mocked should be CUDA"
+            input_fake_tensors = [t._fake for t in input_monarch_tensors]
+            input_fake_group = TensorGroup(input_fake_tensors)
+
+            with mock_cuda_guard():
+                # create mocked real tensor versions of input tensors
+                mocked_real_input_tensors = input_fake_group.pattern.empty([[]])
+                # increment version of mocked tensors to match input tensors.
+                # looping over (fake_version - mocked_version) to account for
+                # potential tensor aliasing.
+                for fake_input_tensor, mocked_input_tensor in zip(
+                    input_fake_tensors, mocked_real_input_tensors
+                ):
+                    fake_input_version = fake_input_tensor._version
+                    mocked_input_version = mocked_input_tensor._version
+                    assert (
+                        mocked_input_version <= fake_input_version
+                    ), "mocked version should be <= than fake version"
+                    for _ in range(fake_input_version - mocked_input_version):
+                        torch.autograd.graph.increment_version(mocked_input_tensor)
+
+                for i in range(len(input_monarch_tensors)):
+                    mocked_real_input_tensors[i].requires_grad = input_monarch_tensors[
+                        i
+                    ].requires_grad
+
+                mocked_input_group = TensorGroup(mocked_real_input_tensors)
+                mocked_args, mocked_kwargs = unflatten_monarch_input(
+                    mocked_real_input_tensors
+                )
+
+                mocked_result = fn(*mocked_args, **mocked_kwargs)
+
+                mocked_result_tensors, unflatten_result = flatten(
+                    mocked_result, lambda x: isinstance(x, torch.Tensor)
+                )
+                mocked_output_group = TensorGroup(
+                    mocked_result_tensors, parent=mocked_input_group
+                )
+
+            output_tensors = fake_call(
+                mocked_output_group.pattern.empty, [input_fake_group.tensors]
+            )
+
+            for mocked_result_tensor, fake_output_tensor in zip(
+                mocked_result_tensors, output_tensors
+            ):
+                mocked_result_version = mocked_result_tensor._version
+                fake_output_version = fake_output_tensor._version
+                for _ in range(mocked_result_version - fake_output_version):
+                    torch.autograd.graph.increment_version(fake_output_tensor)
+                assert mocked_result_tensor._version == fake_output_tensor._version
+
+            return unflatten_result(output_tensors)
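The version bookkeeping above leans on `torch.autograd.graph.increment_version`, which bumps a tensor's `_version` counter exactly as an in-place op would; a tiny standalone illustration (independent of monarch):

    import torch

    t = torch.zeros(2)
    before = t._version                         # 0 for a fresh tensor
    t.add_(1)                                   # in-place ops bump the counter
    torch.autograd.graph.increment_version(t)   # manual bump, as used above
    assert t._version == before + 2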
