diff --git a/python/monarch/__init__.py b/python/monarch/__init__.py
index 7517c1ea8..0ce06b3fe 100644
--- a/python/monarch/__init__.py
+++ b/python/monarch/__init__.py
@@ -50,6 +50,7 @@
 from monarch.common.future import Future
 from monarch.common.invocation import RemoteException
+from monarch.common.mock_cuda import mock_cuda, mock_cuda_guard, unmock_cuda
 from monarch.common.opaque_ref import OpaqueRef
 from monarch.common.pipe import create_pipe, Pipe, remote_generator
 from monarch.common.remote import remote
@@ -84,6 +85,9 @@
     "Shape": ("monarch._src.actor.shape", "Shape"),
     "NDSlice": ("monarch._src.actor.shape", "NDSlice"),
     "Selection": ("monarch.common.selection", "Selection"),
+    "mock_cuda": ("monarch.common.mock_cuda", "mock_cuda"),
+    "mock_cuda_guard": ("monarch.common.mock_cuda", "mock_cuda_guard"),
+    "unmock_cuda": ("monarch.common.mock_cuda", "unmock_cuda"),
     "OpaqueRef": ("monarch.common.opaque_ref", "OpaqueRef"),
     "create_pipe": ("monarch.common.pipe", "create_pipe"),
     "Pipe": ("monarch.common.pipe", "Pipe"),
@@ -153,6 +157,9 @@ def __getattr__(name):
     "RemoteException",
     "Shape",
     "Selection",
+    "mock_cuda",
+    "mock_cuda_guard",
+    "unmock_cuda",
     "NDSlice",
     "OpaqueRef",
     "create_pipe",
diff --git a/python/monarch/_src/actor/endpoint.py b/python/monarch/_src/actor/endpoint.py
index d1361e6cc..8bf650d4f 100644
--- a/python/monarch/_src/actor/endpoint.py
+++ b/python/monarch/_src/actor/endpoint.py
@@ -32,7 +32,11 @@
 )
 
 from monarch._src.actor.future import Future
-from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
+from monarch._src.actor.tensor_engine_shim import (
+    _cached_propagation,
+    _mocked_propagation,
+    fake_call,
+)
 
 if TYPE_CHECKING:
     from monarch._src.actor.actor_mesh import (
@@ -213,7 +217,12 @@ def _propagate(self, args, kwargs, fake_args, fake_kwargs):
         elif self._propagator_arg == "inspect":
             return None
         elif self._propagator_arg == "mocked":
-            raise NotImplementedError("mocked propagation")
+            resolvable = getattr(self, "_resolvable", None)
+            if resolvable is None:
+                raise NotImplementedError(
+                    "Mocked propagation is not implemented for actor endpoints."
+                )
+            return _mocked_propagation(self._resolvable, args, kwargs)
         else:
             return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)
 
diff --git a/python/monarch/_src/actor/tensor_engine_shim.py b/python/monarch/_src/actor/tensor_engine_shim.py
index e229b6f6b..0f50dc270 100644
--- a/python/monarch/_src/actor/tensor_engine_shim.py
+++ b/python/monarch/_src/actor/tensor_engine_shim.py
@@ -55,5 +55,9 @@ def actor_rref(endpoint, args_kwargs_tuple: bytes, refs: Sequence[Any]): ...
 def _cached_propagation(_cache, rfunction, args, kwargs) -> Any: ...
 
 
+@shim(module="monarch.common.remote")
+def _mocked_propagation(rfunction, args, kwargs) -> Any: ...
+
+
 @shim(module="monarch.common.fake")
 def fake_call(fn, *args, **kwargs): ...
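Reviewer sketch (not part of the diff): with the shim above resolving to `monarch.common.remote._mocked_propagation`, a remote function can opt into the new propagator via `propagate="mocked"`. The decorator form, function name, and body below are illustrative assumptions; only the `"mocked"` propagator string comes from this change.

```python
import torch

from monarch import remote


# Hypothetical endpoint, made up for illustration. propagate="mocked" asks the
# controller to learn output metadata (shapes, strides, versions) by running
# the body locally under mock CUDA instead of fake-tensor propagation; the real
# execution still happens on the workers.
@remote(propagate="mocked")
def scaled_matmul(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    return (a @ b) * scale
```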
diff --git a/python/monarch/common/device_mesh.py b/python/monarch/common/device_mesh.py
index 4704dcd9e..7e27603b7 100644
--- a/python/monarch/common/device_mesh.py
+++ b/python/monarch/common/device_mesh.py
@@ -29,6 +29,7 @@
 import monarch.common.messages as messages
 import torch
 from monarch._src.actor.shape import MeshTrait, NDSlice, Shape
+from monarch.common.tree import flatten
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
 
@@ -321,6 +322,7 @@ def get_info(self) -> DeviceMeshInfo:
 
 _active: Optional[DeviceMesh] = None
 _dispatch_enabled = False
+_mock_dispatch = False
 
 
 def get_active_mesh():
@@ -341,6 +343,14 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
             return func(*args, **kwargs)
         if fnstr in self.allowed_local_accessors and not isinstance(args[0], Tensor):
             return func(*args, **kwargs)
+        global _mock_dispatch
+        input_tensors, _ = flatten(
+            (args, kwargs), lambda x: isinstance(x, torch.Tensor)
+        )
+        if _mock_dispatch and (
+            len(input_tensors) == 0 or all(t.is_cuda for t in input_tensors)
+        ):
+            return _remote(func, propagate="mocked")(*args, **kwargs)
         return _remote(func, propagate=func)(*args, **kwargs)
 
 
@@ -362,6 +372,16 @@ def _dispatch():
     _dispatch_enabled = False
 
 
+def enable_mocked_dispatch() -> None:
+    global _mock_dispatch
+    _mock_dispatch = True
+
+
+def disable_mocked_dispatch() -> None:
+    global _mock_dispatch
+    _mock_dispatch = False
+
+
 _on_change: List[Callable] = []
diff --git a/python/monarch/common/mock_cuda.py b/python/monarch/common/mock_cuda.py
index 87fca239e..011e1f46c 100644
--- a/python/monarch/common/mock_cuda.py
+++ b/python/monarch/common/mock_cuda.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
+import logging
 from contextlib import contextmanager
 from typing import Generator, Optional
 
@@ -15,11 +16,32 @@
 _mock_cuda_stream: Optional[torch.cuda.Stream] = None
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def _mock_init_test() -> None:
+    global _mock_cuda_stream
+    base_mock_address = 1 << 48
+    with torch.cuda.stream(_mock_cuda_stream):
+        monarch.common._C.mock_cuda()
+        x = torch.rand(4, dtype=torch.float32, device="cuda")
+        monarch.common._C.unmock_cuda()
+    # x will result in a small pool (2MB) caching allocator allocation
+    segment_size = 2 * 1024 * 1024
+    # therefore we expect the address of x's allocation to be...
+    expected_address = base_mock_address - segment_size
+    assert (
+        x.untyped_storage().data_ptr() == expected_address
+    ), "monarch mock initialization failed. please import mock_cuda at the top of your imports"
+    logger.info("monarch mock initialization succeeded")
+
 
 def get_mock_cuda_stream() -> torch.cuda.Stream:
     global _mock_cuda_stream
     if _mock_cuda_stream is None:
         _mock_cuda_stream = torch.cuda.Stream()
+        _mock_init_test()
+    assert _mock_cuda_stream is not None
     return _mock_cuda_stream
 
diff --git a/python/monarch/common/remote.py b/python/monarch/common/remote.py
index 44b873713..772a1f969 100644
--- a/python/monarch/common/remote.py
+++ b/python/monarch/common/remote.py
@@ -41,7 +41,7 @@
 
 from monarch._src.actor.endpoint import Endpoint
 from monarch.common.device_mesh import RemoteProcessGroup
-from monarch.common.fake import fake_call
+from monarch.common.fake import _fake_mode, fake_call
 
 from monarch.common.function import (
     Propagator,
@@ -57,9 +57,12 @@
 )
 
 from monarch.common.messages import Dims
+from monarch.common.mock_cuda import mock_cuda_guard
 from monarch.common.tensor import dtensor_check, dtensor_dispatch, InputChecker
 from monarch.common.tree import flatten, tree_map
 from torch import autograd, distributed as dist
+from torch.utils._python_dispatch import _disable_current_modes
+
 from typing_extensions import ParamSpec
 
 logger: Logger = logging.getLogger(__name__)
@@ -378,3 +381,75 @@ def _cached_propagation(_cache, rfunction: ResolvableFunction, args, kwargs):
     output_tensors = fake_call(output_pattern.empty, [inputs_group.tensors])
 
     return unflatten_result(output_tensors)
+
+
+def _mocked_propagation(rfunction: ResolvableFunction, args, kwargs):
+    # need to break out of device_mesh dispatch mode to run local mocked execution
+    with _disable_current_modes():
+        # need to manually enable autograd version tracking as we may be
+        # being called inside torch_dispatch
+        with torch._C._SetExcludeDispatchKeyGuard(
+            torch._C.DispatchKey.ADInplaceOrView, False
+        ):
+            fn = rfunction.resolve()
+
+            input_monarch_tensors, unflatten_monarch_input = flatten(
+                (args, kwargs), lambda x: isinstance(x, torch.Tensor)
+            )
+            if len(input_monarch_tensors) > 0:
+                assert all(
+                    t._fake.is_cuda for t in input_monarch_tensors
+                ), "all input tensors to mocked should be CUDA"
+            input_fake_tensors = [t._fake for t in input_monarch_tensors]
+            input_fake_group = TensorGroup(input_fake_tensors)
+
+            with mock_cuda_guard():
+                # create mocked real tensor versions of input tensors
+                mocked_real_input_tensors = input_fake_group.pattern.empty([[]])
+                # increment version of mocked tensors to match input tensors.
+                # looping over (fake_version - mocked_version) to account for
+                # potential tensor aliasing.
+                for fake_input_tensor, mocked_input_tensor in zip(
+                    input_fake_tensors, mocked_real_input_tensors
+                ):
+                    fake_input_version = fake_input_tensor._version
+                    mocked_input_version = mocked_input_tensor._version
+                    assert (
+                        mocked_input_version <= fake_input_version
+                    ), "mocked version should be <= fake version"
+                    for _ in range(fake_input_version - mocked_input_version):
+                        torch.autograd.graph.increment_version(mocked_input_tensor)
+
+                for i in range(len(input_monarch_tensors)):
+                    mocked_real_input_tensors[i].requires_grad = input_monarch_tensors[
+                        i
+                    ].requires_grad
+
+                mocked_input_group = TensorGroup(mocked_real_input_tensors)
+                mocked_args, mocked_kwargs = unflatten_monarch_input(
+                    mocked_real_input_tensors
+                )
+
+                mocked_result = fn(*mocked_args, **mocked_kwargs)
+
+                mocked_result_tensors, unflatten_result = flatten(
+                    mocked_result, lambda x: isinstance(x, torch.Tensor)
+                )
+                mocked_output_group = TensorGroup(
+                    mocked_result_tensors, parent=mocked_input_group
+                )
+
+            output_tensors = fake_call(
+                mocked_output_group.pattern.empty, [input_fake_group.tensors]
+            )
+
+            for mocked_result_tensor, fake_output_tensor in zip(
+                mocked_result_tensors, output_tensors
+            ):
+                mocked_result_version = mocked_result_tensor._version
+                fake_output_version = fake_output_tensor._version
+                for _ in range(mocked_result_version - fake_output_version):
+                    torch.autograd.graph.increment_version(fake_output_tensor)
+                assert mocked_result_tensor._version == fake_output_tensor._version
+
+            return unflatten_result(output_tensors)
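Usage sketch (illustrative only, not part of the diff): the pieces above combine roughly as follows. `enable_mocked_dispatch`, `disable_mocked_dispatch`, and `mock_cuda_guard` come from this change; the mesh argument, tensor shapes, and the use of `mesh.activate()` to install the dispatch mode are placeholder assumptions.

```python
import torch

from monarch import mock_cuda_guard
from monarch.common.device_mesh import (
    disable_mocked_dispatch,
    enable_mocked_dispatch,
)


def run_with_mocked_propagation(mesh) -> None:
    # While the flag is set, ops whose tensor inputs are all CUDA (or that have
    # no tensor inputs) and that are intercepted by the device-mesh dispatch
    # mode are routed through _remote(func, propagate="mocked"), i.e. output
    # metadata is learned by running the op locally under mock CUDA rather
    # than via fake-tensor propagation.
    enable_mocked_dispatch()
    try:
        with mesh.activate():
            x = torch.rand(64, 64, device="cuda")
            y = (x @ x).relu()
    finally:
        disable_mocked_dispatch()


# mock_cuda_guard can also be used directly: inside the guard, CUDA caching
# allocator requests are serviced with mock addresses (see _mock_init_test).
with mock_cuda_guard():
    z = torch.empty(16, device="cuda")
```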