mocked propagation #748

Open · wants to merge 1 commit into main
7 changes: 7 additions & 0 deletions python/monarch/__init__.py
@@ -50,6 +50,7 @@
from monarch.common.future import Future

from monarch.common.invocation import RemoteException
from monarch.common.mock_cuda import mock_cuda, mock_cuda_guard, unmock_cuda
from monarch.common.opaque_ref import OpaqueRef
from monarch.common.pipe import create_pipe, Pipe, remote_generator
from monarch.common.remote import remote
@@ -84,6 +85,9 @@
"Shape": ("monarch._src.actor.shape", "Shape"),
"NDSlice": ("monarch._src.actor.shape", "NDSlice"),
"Selection": ("monarch.common.selection", "Selection"),
"mock_cuda": ("monarch.common.mock_cuda", "mock_cuda"),
"mock_cuda_guard": ("monarch.common.mock_cuda", "mock_cuda_guard"),
"unmock_cuda": ("monarch.common.mock_cuda", "unmock_cuda"),
"OpaqueRef": ("monarch.common.opaque_ref", "OpaqueRef"),
"create_pipe": ("monarch.common.pipe", "create_pipe"),
"Pipe": ("monarch.common.pipe", "Pipe"),
@@ -153,6 +157,9 @@ def __getattr__(name):
"RemoteException",
"Shape",
"Selection",
"mock_cuda",
"mock_cuda_guard",
"unmock_cuda",
"NDSlice",
"OpaqueRef",
"create_pipe",
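
Reviewer note: a minimal usage sketch of the newly re-exported helpers. It assumes `mock_cuda_guard` is the context-manager form of the `mock_cuda`/`unmock_cuda` pair (as `monarch/common/mock_cuda.py` below suggests) and that mocked allocations hand back placeholder device pointers rather than real CUDA memory; the shapes and function names here are illustrative only.

```python
import torch

import monarch


def allocate_under_mocked_cuda() -> torch.Tensor:
    # Allocations made inside the guard come from the mock CUDA allocator.
    with monarch.mock_cuda_guard():
        return torch.empty(1024, device="cuda")


def allocate_with_explicit_calls() -> torch.Tensor:
    # Equivalent explicit form using the paired enable/disable functions.
    monarch.mock_cuda()
    try:
        return torch.empty(1024, device="cuda")
    finally:
        monarch.unmock_cuda()
```
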
13 changes: 11 additions & 2 deletions python/monarch/_src/actor/endpoint.py
@@ -32,7 +32,11 @@
)

from monarch._src.actor.future import Future
from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
from monarch._src.actor.tensor_engine_shim import (
    _cached_propagation,
    _mocked_propagation,
    fake_call,
)

if TYPE_CHECKING:
    from monarch._src.actor.actor_mesh import (
@@ -213,7 +217,12 @@ def _propagate(self, args, kwargs, fake_args, fake_kwargs):
        elif self._propagator_arg == "inspect":
            return None
        elif self._propagator_arg == "mocked":
            raise NotImplementedError("mocked propagation")
            resolvable = getattr(self, "_resolvable", None)
            if resolvable is None:
                raise NotImplementedError(
                    "Mocked propagation is not implemented for actor endpoints."
                )
            return _mocked_propagation(resolvable, args, kwargs)
        else:
            return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)

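
Reviewer note: a hedged sketch of how a caller might select the new propagator, mirroring the `_remote(func, propagate="mocked")` pattern used in `device_mesh.py` below; it assumes the public `remote` accepts the same `propagate=` argument as the internal `_remote`, and the wrapper function is a placeholder, not part of this PR.

```python
import torch

from monarch import remote


def mocked_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Wrap a torch op so its output metadata is discovered by actually running
    # it under mocked CUDA on the client, instead of a user-written propagator.
    return remote(torch.ops.aten.add.Tensor, propagate="mocked")(x, y)
```
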
4 changes: 4 additions & 0 deletions python/monarch/_src/actor/tensor_engine_shim.py
@@ -55,5 +55,9 @@ def actor_rref(endpoint, args_kwargs_tuple: bytes, refs: Sequence[Any]): ...
def _cached_propagation(_cache, rfunction, args, kwargs) -> Any: ...


@shim(module="monarch.common.remote")
def _mocked_propagation(rfunction, args, kwargs) -> Any: ...


@shim(module="monarch.common.fake")
def fake_call(fn, *args, **kwargs): ...
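
Reviewer note: the new `_mocked_propagation` stub follows the existing `@shim` pattern, which, as used here, appears to bind the stub lazily to the same-named implementation in the target module so the actor package does not import the tensor engine eagerly. A rough, hypothetical illustration of that indirection follows; it is not the actual decorator implementation.

```python
import importlib
from functools import wraps


def lazy_shim(*, module: str):
    # Hypothetical stand-in for the real @shim: look up the implementation by
    # name in `module` on first call and forward the arguments to it.
    def decorator(stub):
        @wraps(stub)
        def wrapper(*args, **kwargs):
            impl = getattr(importlib.import_module(module), stub.__name__)
            return impl(*args, **kwargs)

        return wrapper

    return decorator
```
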
20 changes: 20 additions & 0 deletions python/monarch/common/device_mesh.py
@@ -29,6 +29,7 @@
import monarch.common.messages as messages
import torch
from monarch._src.actor.shape import MeshTrait, NDSlice, Shape
from monarch.common.tree import flatten

from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
@@ -321,6 +322,7 @@ def get_info(self) -> DeviceMeshInfo:

_active: Optional[DeviceMesh] = None
_dispatch_enabled = False
_mock_dispatch = False


def get_active_mesh():
@@ -341,6 +343,14 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            return func(*args, **kwargs)
        if fnstr in self.allowed_local_accessors and not isinstance(args[0], Tensor):
            return func(*args, **kwargs)
        global _mock_dispatch
        input_tensors, _ = flatten(
            (args, kwargs), lambda x: isinstance(x, torch.Tensor)
        )
        if _mock_dispatch and (
            len(input_tensors) == 0 or all(t.is_cuda for t in input_tensors)
        ):
            return _remote(func, propagate="mocked")(*args, **kwargs)
        return _remote(func, propagate=func)(*args, **kwargs)


@@ -362,6 +372,16 @@ def _dispatch():
        _dispatch_enabled = False


def enable_mocked_dispatch() -> None:
    global _mock_dispatch
    _mock_dispatch = True


def disable_mocked_dispatch() -> None:
    global _mock_dispatch
    _mock_dispatch = False


_on_change: List[Callable] = []


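
Reviewer note: a minimal sketch of driving the new toggle, assuming work issued under an active mesh flows through the `__torch_dispatch__` hook above; the mesh and tensor setup are elided, and `t` is a placeholder.

```python
from monarch.common import device_mesh


def run_with_mocked_dispatch(t):
    # While the flag is set, ops whose tensor inputs are all CUDA (or that take
    # no tensor inputs) are routed through _remote(func, propagate="mocked")
    # instead of the usual per-op propagation path.
    device_mesh.enable_mocked_dispatch()
    try:
        return t + 1
    finally:
        device_mesh.disable_mocked_dispatch()
```
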
22 changes: 22 additions & 0 deletions python/monarch/common/mock_cuda.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
import logging
from contextlib import contextmanager
from typing import Generator, Optional

@@ -15,11 +16,32 @@

_mock_cuda_stream: Optional[torch.cuda.Stream] = None

logger: logging.Logger = logging.getLogger(__name__)


def _mock_init_test() -> None:
    global _mock_cuda_stream
    base_mock_address = 1 << 48
    with torch.cuda.stream(_mock_cuda_stream):
        monarch.common._C.mock_cuda()
        x = torch.rand(4, dtype=torch.float32, device="cuda")
        monarch.common._C.unmock_cuda()
    # x will result in a small pool (2MB) caching allocator allocation
    segment_size = 2 * 1024 * 1024
    # therefore we expect the address of x's allocation to be...
    expected_address = base_mock_address - segment_size
    assert (
        x.untyped_storage().data_ptr() == expected_address
    ), "monarch mock initialization failed. please import mock_cuda at the top of your imports"
    logger.info("monarch mock initialization succeeded")


def get_mock_cuda_stream() -> torch.cuda.Stream:
    global _mock_cuda_stream
    if _mock_cuda_stream is None:
        _mock_cuda_stream = torch.cuda.Stream()
        _mock_init_test()
    assert _mock_cuda_stream is not None
    return _mock_cuda_stream


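
Reviewer note: the expected address in `_mock_init_test` is straightforward arithmetic on the constants above; a standalone check of the value:

```python
base_mock_address = 1 << 48         # 0x1000000000000, mock allocator base
segment_size = 2 * 1024 * 1024      # 2 MiB small-pool segment
expected_address = base_mock_address - segment_size

assert expected_address == 0xFFFFFFE00000
print(hex(expected_address))        # 0xffffffe00000
```
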
77 changes: 76 additions & 1 deletion python/monarch/common/remote.py
@@ -41,7 +41,7 @@

from monarch._src.actor.endpoint import Endpoint
from monarch.common.device_mesh import RemoteProcessGroup
from monarch.common.fake import fake_call
from monarch.common.fake import _fake_mode, fake_call

from monarch.common.function import (
    Propagator,
@@ -57,9 +57,12 @@
)
from monarch.common.messages import Dims

from monarch.common.mock_cuda import mock_cuda_guard
from monarch.common.tensor import dtensor_check, dtensor_dispatch, InputChecker
from monarch.common.tree import flatten, tree_map
from torch import autograd, distributed as dist
from torch.utils._python_dispatch import _disable_current_modes

from typing_extensions import ParamSpec

logger: Logger = logging.getLogger(__name__)
@@ -378,3 +381,75 @@ def _cached_propagation(_cache, rfunction: ResolvableFunction, args, kwargs):

    output_tensors = fake_call(output_pattern.empty, [inputs_group.tensors])
    return unflatten_result(output_tensors)


def _mocked_propagation(rfunction: ResolvableFunction, args, kwargs):
    # need to break out of device_mesh dispatch mode to run local mocked execution
    with _disable_current_modes():
        # need to manually enable autograd version tracking as we may be
        # being called inside torch_dispatch
        with torch._C._SetExcludeDispatchKeyGuard(
            torch._C.DispatchKey.ADInplaceOrView, False
        ):
            fn = rfunction.resolve()

            input_monarch_tensors, unflatten_monarch_input = flatten(
                (args, kwargs), lambda x: isinstance(x, torch.Tensor)
            )
            if len(input_monarch_tensors) > 0:
                assert all(
                    t._fake.is_cuda for t in input_monarch_tensors
                ), "all input tensors to mocked should be CUDA"
            input_fake_tensors = [t._fake for t in input_monarch_tensors]
            input_fake_group = TensorGroup(input_fake_tensors)

            with mock_cuda_guard():
                # create mocked real tensor versions of input tensors
                mocked_real_input_tensors = input_fake_group.pattern.empty([[]])
                # increment version of mocked tensors to match input tensors.
                # looping over (fake_version - mocked_version) to account for
                # potential tensor aliasing.
                for fake_input_tensor, mocked_input_tensor in zip(
                    input_fake_tensors, mocked_real_input_tensors
                ):
                    fake_input_version = fake_input_tensor._version
                    mocked_input_version = mocked_input_tensor._version
                    assert (
                        mocked_input_version <= fake_input_version
                    ), "mocked version should be <= fake version"
                    for _ in range(fake_input_version - mocked_input_version):
                        torch.autograd.graph.increment_version(mocked_input_tensor)

                for i in range(len(input_monarch_tensors)):
                    mocked_real_input_tensors[i].requires_grad = input_monarch_tensors[
                        i
                    ].requires_grad

                mocked_input_group = TensorGroup(mocked_real_input_tensors)
                mocked_args, mocked_kwargs = unflatten_monarch_input(
                    mocked_real_input_tensors
                )

                mocked_result = fn(*mocked_args, **mocked_kwargs)

                mocked_result_tensors, unflatten_result = flatten(
                    mocked_result, lambda x: isinstance(x, torch.Tensor)
                )
                mocked_output_group = TensorGroup(
                    mocked_result_tensors, parent=mocked_input_group
                )

            output_tensors = fake_call(
                mocked_output_group.pattern.empty, [input_fake_group.tensors]
            )

            for mocked_result_tensor, fake_output_tensor in zip(
                mocked_result_tensors, output_tensors
            ):
                mocked_result_version = mocked_result_tensor._version
                fake_output_version = fake_output_tensor._version
                for _ in range(mocked_result_version - fake_output_version):
                    torch.autograd.graph.increment_version(fake_output_tensor)
                assert mocked_result_tensor._version == fake_output_tensor._version

            return unflatten_result(output_tensors)
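
Reviewer note: for intuition on the `_version` bookkeeping in `_mocked_propagation`, a small PyTorch-only illustration of how in-place ops and `torch.autograd.graph.increment_version` advance a tensor's version counter (independent of monarch):

```python
import torch

t = torch.zeros(2)
assert t._version == 0

t.add_(1)                                  # in-place op bumps the version counter
assert t._version == 1

torch.autograd.graph.increment_version(t)  # manual bump, as used above to keep the
assert t._version == 2                     # mocked and fake tensors' versions in sync
```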