|
41 | 41 |
|
42 | 42 | from monarch._src.actor.endpoint import Endpoint
|
43 | 43 | from monarch.common.device_mesh import RemoteProcessGroup
|
44 |
| -from monarch.common.fake import fake_call |
| 44 | +from monarch.common.fake import _fake_mode, fake_call |
45 | 45 |
|
46 | 46 | from monarch.common.function import (
|
47 | 47 | Propagator,
|
|
57 | 57 | )
|
58 | 58 | from monarch.common.messages import Dims
|
59 | 59 |
|
| 60 | +from monarch.common.mock_cuda import mock_cuda_guard |
60 | 61 | from monarch.common.tensor import dtensor_check, dtensor_dispatch, InputChecker
|
61 | 62 | from monarch.common.tree import flatten, tree_map
|
62 | 63 | from torch import autograd, distributed as dist
|
| 64 | +from torch.utils._python_dispatch import _disable_current_modes |
| 65 | + |
63 | 66 | from typing_extensions import ParamSpec
|
64 | 67 |
|
65 | 68 | logger: Logger = logging.getLogger(__name__)
|
@@ -378,3 +381,77 @@ def _cached_propagation(_cache, rfunction: ResolvableFunction, args, kwargs):
|
378 | 381 |
|
379 | 382 | output_tensors = fake_call(output_pattern.empty, [inputs_group.tensors])
|
380 | 383 | return unflatten_result(output_tensors)
|
| 384 | + |
| 385 | + |
def _mocked_propagation(rfunction: ResolvableFunction, args, kwargs):
    """Propagate output metadata for *rfunction* by executing it for real on
    mocked CUDA tensors.

    Materializes mocked (real, CUDA-mocked) copies of the fake input tensors,
    runs the resolved function on them, and converts the mocked results back
    into fake tensors whose autograd ``_version`` counters match the mocked
    execution.

    Args:
        rfunction: Resolvable reference to the function to execute.
        args / kwargs: Call arguments; any ``torch.Tensor`` leaves are monarch
            tensors carrying a ``_fake`` CUDA tensor — TODO confirm with callers.

    Returns:
        The function's result structure with every tensor leaf replaced by a
        fake tensor.
    """
    # need to break out of device_mesh dispatch mode to run local mocked execution
    with _disable_current_modes():
        # need to manually enable autograd version tracking as we may be
        # being called inside torch_dispatch
        with torch._C._SetExcludeDispatchKeyGuard(
            torch._C.DispatchKey.ADInplaceOrView, False
        ):
            fn = rfunction.resolve()

            input_monarch_tensors, unflatten_monarch_input = flatten(
                (args, kwargs), lambda x: isinstance(x, torch.Tensor)
            )
            if len(input_monarch_tensors) > 0:
                # FIX: was `all([[...comprehension...]])`, which handed `all`
                # a single-element list containing a (truthy) list, so the
                # assertion could never fail. Check each tensor directly.
                assert all(
                    t._fake.is_cuda for t in input_monarch_tensors
                ), "all input tensors to mocked should be CUDA"
            input_fake_tensors = [t._fake for t in input_monarch_tensors]
            input_fake_group = TensorGroup(input_fake_tensors)

            with mock_cuda_guard():
                # create mocked real tensor versions of input tensors
                mocked_real_input_tensors = input_fake_group.pattern.empty([[]])
                # increment version of mocked tensors to match input tensors.
                # looping over (fake_version - mocked_version) to account for
                # potential tensor aliasing.
                for fake_input_tensor, mocked_input_tensor in zip(
                    input_fake_tensors, mocked_real_input_tensors
                ):
                    fake_input_version = fake_input_tensor._version
                    mocked_input_version = mocked_input_tensor._version
                    assert (
                        mocked_input_version <= fake_input_version
                    ), "mocked version should be <= than fake version"
                    for _ in range(fake_input_version - mocked_input_version):
                        torch.autograd.graph.increment_version(mocked_input_tensor)

                # mirror requires_grad so autograd behaves as it would on the
                # real inputs
                for monarch_tensor, mocked_tensor in zip(
                    input_monarch_tensors, mocked_real_input_tensors
                ):
                    mocked_tensor.requires_grad = monarch_tensor.requires_grad

                # TODO confirm if this is actually needed; or can just
                # use the input tensor group later directly
                mocked_input_group = TensorGroup(mocked_real_input_tensors)
                mocked_args, mocked_kwargs = unflatten_monarch_input(
                    mocked_real_input_tensors
                )

                # run the real (mocked-CUDA) computation
                mocked_result = fn(*mocked_args, **mocked_kwargs)

                mocked_result_tensors, unflatten_result = flatten(
                    mocked_result, lambda x: isinstance(x, torch.Tensor)
                )
                mocked_output_group = TensorGroup(
                    mocked_result_tensors, parent=mocked_input_group
                )

            # rebuild fake output tensors matching the mocked results' layout
            output_tensors = fake_call(
                mocked_output_group.pattern.empty, [input_fake_group.tensors]
            )

            # bring each fake output's autograd version up to its mocked
            # counterpart so downstream version checks agree
            for mocked_result_tensor, fake_output_tensor in zip(
                mocked_result_tensors, output_tensors
            ):
                mocked_result_version = mocked_result_tensor._version
                fake_output_version = fake_output_tensor._version
                for _ in range(mocked_result_version - fake_output_version):
                    torch.autograd.graph.increment_version(fake_output_tensor)
                assert mocked_result_tensor._version == fake_output_tensor._version

            return unflatten_result(output_tensors)
0 commit comments