Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion physicsnemo/domain_parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,20 @@
# If minimum versions are met, we can import the shard tensor and spec.

from ._shard_tensor_spec import ShardTensorSpec
from .shard_tensor import ShardTensor, scatter_tensor
from .shard_tensor import (
FSDPOutputTensorAdapter,
ShardTensor,
distribute_over_domain_for_fsdp,
scatter_tensor,
wrap_for_fsdp,
)

def register_custom_ops():
"""Register all custom ShardTensor ops and shard-aware wrappers.

Imports are deferred to this function to avoid an import cycle between
``shard_tensor`` and the individual op modules.
"""
# These imports will register the custom ops with the ShardTensor class.
# It's done here to avoid an import cycle.
from .custom_ops import ( # noqa: F401
Expand All @@ -69,3 +80,6 @@ def register_custom_ops():
ShardTensor = None
ShardTensorSpec = None
scatter_tensor = None
distribute_over_domain_for_fsdp = None
FSDPOutputTensorAdapter = None
wrap_for_fsdp = None
110 changes: 78 additions & 32 deletions physicsnemo/domain_parallel/custom_ops/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@
)

import torch
from torch.distributed.tensor._dtensor_spec import TensorMeta
from torch.distributed.tensor.placement_types import (
Partial,
Shard,
)

# noqa: E402
from physicsnemo.domain_parallel._shard_tensor_spec import (
ShardTensorSpec,
_stride_from_contiguous_shape_C_style,
)
from physicsnemo.domain_parallel.shard_tensor import ShardTensor

aten = torch.ops.aten
Expand Down Expand Up @@ -248,6 +253,62 @@ def compute_result_sharding_shapes(
return result_sharding_shapes


def build_reduction_result(
    local_result: torch.Tensor,
    input_tensor: ShardTensor,
    placements: list[Partial | Shard],
    sharding_shapes: dict[int, list[torch.Size]],
) -> ShardTensor:
    r"""Wrap a local reduction output into a ShardTensor with explicit metadata.

    Assembles the ``ShardTensorSpec`` directly from the precomputed placements
    and per-rank sharding shapes instead of going through
    ``ShardTensor.from_local``, which would add overhead and autograd
    side-effects.

    Parameters
    ----------
    local_result : torch.Tensor
        The locally-computed reduction result.
    input_tensor : ShardTensor
        The original input ShardTensor (supplies the device mesh and
        ``requires_grad`` flag).
    placements : List[Union[Partial, Shard]]
        Result placements from :func:`compute_result_placements`.
    sharding_shapes : Dict[int, List[torch.Size]]
        Result sharding shapes from :func:`compute_result_sharding_shapes`.

    Returns
    -------
    ShardTensor
        Wrapped result with correct sharding metadata.
    """
    # Start from the local shape; for each mesh dimension that shards a
    # tensor dimension, the global extent is the sum of that dimension's
    # extent across all shards on that mesh dimension.
    global_shape = list(local_result.shape)
    for mesh_dim, placement in enumerate(placements):
        if not isinstance(placement, Shard):
            continue
        shard_dim = placement.dim
        global_shape[shard_dim] = sum(
            shape[shard_dim] for shape in sharding_shapes[mesh_dim]
        )

    # Global tensor is treated as contiguous (C-style row-major strides).
    meta = TensorMeta(
        shape=tuple(global_shape),
        stride=_stride_from_contiguous_shape_C_style(global_shape),
        dtype=local_result.dtype,
    )
    spec = ShardTensorSpec(
        mesh=input_tensor.device_mesh,
        placements=tuple(placements),
        tensor_meta=meta,
        _local_shape=local_result.shape,
        _sharding_shapes={
            mesh_dim: tuple(shapes)
            for mesh_dim, shapes in sharding_shapes.items()
        },
    )
    # Bypass __init__/from_local: construct the wrapper directly from the
    # local tensor and the spec built above.
    return ShardTensor.__new__(
        ShardTensor,
        local_tensor=local_result,
        spec=spec,
        requires_grad=input_tensor.requires_grad,
    )


def create_sharded_grad_input(
local_grad_input: torch.Tensor, original_spec: Any
) -> ShardTensor:
Expand All @@ -265,11 +326,15 @@ def create_sharded_grad_input(
ShardTensor
A distributed tensor with the same sharding as the original input.
"""
return ShardTensor.from_local(
local_grad_input,
device_mesh=original_spec.mesh,
placements=original_spec.placements,
sharding_shapes=original_spec.sharding_shapes(),
# In custom autograd backward, return the input gradient directly as a
# ShardTensor value. Avoid ``from_local`` here (which routes through a
# separate autograd Function) so the gradient is attached unambiguously to
# the original ShardTensor input.
return ShardTensor.__new__(
ShardTensor,
local_tensor=local_grad_input,
spec=original_spec,
requires_grad=False,
)


Expand Down Expand Up @@ -361,24 +426,14 @@ def forward(
"""
dim, keepdim = ShardedReductionBase.setup_ctx(ctx, tensor, dim, keepdim)

# Get local tensor
local_tensor = tensor._local_tensor
# Perform local sum
local_result = aten.sum(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype)
local_result = aten.sum(
tensor._local_tensor, dim=dim, keepdim=keepdim, dtype=dtype
)

# Compute placements for the result
placements = compute_result_placements(tensor, dim, "sum")
output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim)

# Create result ShardTensor
result = ShardTensor.from_local(
local_result,
tensor.device_mesh,
placements,
sharding_shapes=output_sharding_shapes,
)
sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim)

return result
return build_reduction_result(local_result, tensor, placements, sharding_shapes)

@staticmethod
def backward(
Expand Down Expand Up @@ -495,23 +550,14 @@ def forward(
for d in reduction_dims:
weight *= local_shape[d] / global_shape[d]

# Perform local mean
# Perform local mean and apply weighting for uneven shards
local_result = aten.mean(local_tensor, dim=dim, keepdim=keepdim, dtype=dtype)
# Apply weighting
local_result = local_result * weight

placements = compute_result_placements(tensor, dim, "sum")
output_sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim)

# Create result ShardTensor
result = ShardTensor.from_local(
local_result,
tensor.device_mesh,
placements,
sharding_shapes=output_sharding_shapes,
)
sharding_shapes = compute_result_sharding_shapes(tensor, dim, keepdim)

return result
return build_reduction_result(local_result, tensor, placements, sharding_shapes)

@staticmethod
def backward(
Expand Down
Loading
Loading