diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py
index 4d950c7e9..c735f98e1 100644
--- a/torchrec/distributed/comm_ops.py
+++ b/torchrec/distributed/comm_ops.py
@@ -54,11 +54,6 @@ def get_gradient_division() -> bool:
 
 
 def set_use_sync_collectives(val: bool) -> None:
-    if val and torch._running_with_deploy():
-        raise RuntimeError(
-            "TorchRec sync_collectives are not supported in torch.deploy."
-        )
-
     global USE_SYNC_COLLECTIVES
     USE_SYNC_COLLECTIVES = val
 
@@ -2356,43 +2351,42 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
         return (None, None, myreq.dummy_tensor)
 
 
-if not torch._running_with_deploy():  # noqa C901
-    # Torch Library op def can not be used in Deploy
-    class AllToAllSingle(torch.autograd.Function):
-        @staticmethod
-        # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
-        def forward(
-            # pyre-fixme[2]: Parameter must be annotated.
-            ctx,
-            input: Tensor,
-            output_split_sizes: List[int],
-            input_split_sizes: List[int],
-            group_name: str,
-            group_size: int,
-            gradient_division: bool,
-        ) -> Tensor:
-            ctx.output_split_sizes = input_split_sizes
-            ctx.input_split_sizes = output_split_sizes
-            ctx.group_name = group_name
-            ctx.group_size = group_size
-            ctx.gradient_division = gradient_division
-            return torch.distributed._functional_collectives.all_to_all_single(
-                input, output_split_sizes, input_split_sizes, group_name
-            )
+# Torch Library op def
+class AllToAllSingle(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
+    def forward(
+        # pyre-fixme[2]: Parameter must be annotated.
+        ctx,
+        input: Tensor,
+        output_split_sizes: List[int],
+        input_split_sizes: List[int],
+        group_name: str,
+        group_size: int,
+        gradient_division: bool,
+    ) -> Tensor:
+        ctx.output_split_sizes = input_split_sizes
+        ctx.input_split_sizes = output_split_sizes
+        ctx.group_name = group_name
+        ctx.group_size = group_size
+        ctx.gradient_division = gradient_division
+        return torch.distributed._functional_collectives.all_to_all_single(
+            input, output_split_sizes, input_split_sizes, group_name
+        )
 
-        @staticmethod
-        # pyre-ignore
-        def backward(ctx, grad):
-            grad = torch.distributed._functional_collectives.all_to_all_single(
-                grad,
-                ctx.output_split_sizes,
-                ctx.input_split_sizes,
-                ctx.group_name,
-            )
-            if ctx.gradient_division:
-                grad.div_(ctx.group_size)
+    @staticmethod
+    # pyre-ignore
+    def backward(ctx, grad):
+        grad = torch.distributed._functional_collectives.all_to_all_single(
+            grad,
+            ctx.output_split_sizes,
+            ctx.input_split_sizes,
+            ctx.group_name,
+        )
+        if ctx.gradient_division:
+            grad.div_(ctx.group_size)
 
-            return grad, None, None, None, None, None
+        return grad, None, None, None, None, None
 
 # torchrec::reduce_scatter_tensor
 @torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
diff --git a/torchrec/distributed/train_pipeline/tracing.py b/torchrec/distributed/train_pipeline/tracing.py
index 946348785..ab4267868 100644
--- a/torchrec/distributed/train_pipeline/tracing.py
+++ b/torchrec/distributed/train_pipeline/tracing.py
@@ -13,13 +13,7 @@
 
 import torch
 
-if not torch._running_with_deploy():
-    from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
-else:
-
-    class FSDP2:
-        pass
-
+from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from torch.fx.immutable_collections import (
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 9007f55bd..195cc1115 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -80,8 +80,7 @@ except ImportError:
     logger.warning("torchrec_use_sync_collectives is not available")
 
 
-if not torch._running_with_deploy():
-    torch.ops.import_module("fbgemm_gpu.sparse_ops")
+torch.ops.import_module("fbgemm_gpu.sparse_ops")
 
 
 # Note: doesn't make much sense but better than throwing.
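
With the torch.deploy guard removed, AllToAllSingle and set_use_sync_collectives are defined unconditionally at module level. The sketch below is not part of the patch: it assumes a recent PyTorch where ProcessGroup exposes a group_name property, and the demo_all_to_all/pg/splits names are purely illustrative.

# Hypothetical usage sketch of the now guard-free comm_ops surface.
import torch
import torch.distributed as dist

from torchrec.distributed.comm_ops import AllToAllSingle, set_use_sync_collectives


def demo_all_to_all(pg: dist.ProcessGroup, local: torch.Tensor) -> torch.Tensor:
    # Previously raised RuntimeError under torch.deploy; now it just sets the flag.
    set_use_sync_collectives(True)
    world_size = dist.get_world_size(pg)
    # Assume the first dim divides evenly across ranks for this illustration.
    splits = [local.shape[0] // world_size] * world_size
    return AllToAllSingle.apply(
        local,          # input
        splits,         # output_split_sizes
        splits,         # input_split_sizes
        pg.group_name,  # functional collectives address process groups by name
        world_size,     # group_size, used to scale gradients in backward
        True,           # gradient_division: backward divides grads by group_size
    )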