79 changes: 48 additions & 31 deletions shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc
@@ -1314,6 +1314,17 @@ class CollectiveInserter {
AxisToDimAndIndex outAxisToDimAndIndex;
};

// Assumes both `inSharding` and `outSharding` are non-null.
bool isEquivalentOnMesh(TensorShardingAttr inSharding,
TensorShardingAttr outSharding, ReshardOp reshardOp) {
if (inSharding.getMeshName() == outSharding.getMeshName()) {
return true;
}
MeshAttr inMesh = inSharding.getMesh(reshardOp);
MeshAttr outMesh = outSharding.getMesh(reshardOp);
return inMesh.equals(outMesh, /*ignoreDeviceOrder=*/true);
}

class ReshardPattern : public OpConversionPattern<ReshardOp> {
public:
using OpConversionPattern::OpConversionPattern;
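As context for the `ignoreDeviceOrder` comparison in the new `isEquivalentOnMesh` helper above, here is a minimal, self-contained sketch of the intended equivalence rule. It uses hypothetical plain-C++ stand-ins rather than the actual sdy attribute types: two shardings are treated as living on the same mesh when the mesh symbol names match, or when the resolved meshes have the same axes once device order is ignored.

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the sdy mesh/sharding attributes; illustration only.
struct Mesh {
  std::vector<std::pair<std::string, int64_t>> axes;  // (axis name, size)
  std::vector<int64_t> deviceIds;                     // explicit device order, may be empty
};
struct Sharding {
  std::string meshName;
  Mesh mesh;
};

// Sketch of the equivalence rule: same mesh symbol name, or same axes once
// device order is ignored (device ids are allowed to differ).
bool isEquivalentOnMeshSketch(const Sharding& in, const Sharding& out) {
  if (in.meshName == out.meshName) return true;
  return in.mesh.axes == out.mesh.axes;
}

In particular, meshes that differ only in their device ordering are treated as equivalent, which is presumably what allows a reshard across such meshes to still be lowered to collectives.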
@@ -1330,31 +1341,26 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
return rewriter.notifyMatchFailure(
op, [](Diagnostic& diag) { diag << "Incompatible shardings"; });
}
if (inSharding.isFullyReplicated() && outSharding.isFullyReplicated()) {
rewriter.replaceOp(op, adaptor.getInput());
return success();
}
if (inSharding.getMeshName() != outSharding.getMeshName()) {
if (outSharding.isFullyReplicated()) {
// TODO(enver): Hard fail if output sharding has a different unreduced
// axes than the input sharding. Note that the out sharding may be fully
// replicated and still have different unreduced axes than the input
// sharding.
outSharding = TensorShardingAttr::getFullyClosedLike(inSharding);
// TODO(enver): Also check for input sharding is fully replicated.
} else {
MeshAttr inMesh = inSharding.getMesh(op);
MeshAttr outMesh = outSharding.getMesh(op);
// TODO(enver): Use MeshAttr::equals method instead.
if (outMesh.getAxes() != inMesh.getAxes() ||
inMesh.getDeviceIds() == outMesh.getDeviceIds()) {
// We currently only support a reshard between different meshes if
// they have the same axes and different device ids, and at least one
// of the sharding isn't fully replicated.
return rewriter.notifyMatchFailure(
op, [](Diagnostic& diag) { diag << "Incompatible meshes"; });
}
if (outSharding.isFullyReplicated()) {
if (inSharding.isFullyReplicated()) {
rewriter.replaceOp(op, adaptor.getInput());
return success();
}
// TODO(enver): Hard fail if output sharding has different unreduced
// axes than the input sharding. Note that the out sharding may be fully
// replicated and still have different unreduced axes than the input
// sharding.
outSharding = TensorShardingAttr::getFullyClosedLike(inSharding);
}
// TODO(enver): Set input mesh to output mesh if input sharding is fully
// replicated. It requires that sdy.all_slice can handle the input and
// output having different meshes.
if (!isEquivalentOnMesh(inSharding, outSharding, op)) {
// We currently only support a reshard between different meshes if
// they have the same axes and different device ids, and at least one
// of the shardings isn't fully replicated.
return rewriter.notifyMatchFailure(
op, [](Diagnostic& diag) { diag << "Incompatible meshes"; });
}

// TODO(tomnatan): we should verify that the operand of ReshardOp has a
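To restate the rewritten control flow of `matchAndRewrite` above in one place, here is a hedged, standalone summary; the enum and function below are invented for illustration and are not part of the patch or the sdy API.

// Invented names for illustration; mirrors the decisions made above.
enum class ReshardLowering {
  EraseOp,             // out and in shardings are both fully replicated
  GatherOnInputMesh,   // out is fully replicated; rewritten as fully closed on the input's mesh
  KeepAsReshard,       // meshes are not equivalent (ignoring device order): bail out
  InsertCollectives    // same or equivalent mesh: lower via CollectiveInserter
};

ReshardLowering classify(bool inFullyReplicated, bool outFullyReplicated,
                         bool meshesEquivalentIgnoringDeviceOrder) {
  if (outFullyReplicated) {
    return inFullyReplicated ? ReshardLowering::EraseOp
                             : ReshardLowering::GatherOnInputMesh;
  }
  if (!meshesEquivalentIgnoringDeviceOrder) return ReshardLowering::KeepAsReshard;
  return ReshardLowering::InsertCollectives;
}

The second case relies on `getFullyClosedLike(inSharding)`, which, as used here, keeps the input's mesh, so the subsequent mesh-equivalence check passes and the reshard is lowered within the input's mesh.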
@@ -1377,13 +1383,24 @@ struct ReshardToCollectivesPass
target = std::make_shared<ConversionTarget>(*context);
target->addLegalOp<AllGatherOp, AllSliceOp, AllToAllOp,
CollectivePermuteOp>();
if (keepRedundantReshards) {
target->addDynamicallyLegalOp<ReshardOp>([](ReshardOp op) {
return isEquivalent(getSharding(op.getInput()), op.getSharding());
});
} else {
target->addIllegalOp<ReshardOp>();
}
target->addDynamicallyLegalOp<ReshardOp>([&](ReshardOp op) {
TensorShardingAttr inSharding = getSharding(op.getInput());
TensorShardingAttr outSharding = op.getSharding();
if (keepRedundantReshards && isEquivalent(inSharding, outSharding)) {
return true;
}
// If the out sharding is fully replicated, the reshard is either erased
// (if the input sharding is also fully replicated) or lowered to an
// all-gather (if the input is sharded); either way it is handled, so it
// should be marked as illegal.
if (outSharding.isFullyReplicated()) {
return false;
}
if (inSharding && !isEquivalentOnMesh(inSharding, outSharding, op)) {
return true;
}
return false;
});

RewritePatternSet patternsInternal(context);
patternsInternal.add<ReshardPattern>(context);
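The dynamic legality callback above decides which reshard ops are handed to `ReshardPattern`. A hedged sketch of that decision, with invented names (not sdy API), where returning true keeps the op as an sdy.reshard and returning false marks it for lowering:

// Invented name for illustration; mirrors the lambda registered above.
bool isReshardLegalSketch(bool keepRedundantReshards, bool shardingsEquivalent,
                          bool outFullyReplicated, bool hasInSharding,
                          bool meshesEquivalentIgnoringDeviceOrder) {
  if (keepRedundantReshards && shardingsEquivalent) return true;  // keep redundant reshards
  if (outFullyReplicated) return false;       // erased or lowered to an all-gather
  if (hasInSharding && !meshesEquivalentIgnoringDeviceOrder) return true;  // unsupported mesh change: keep
  return false;                               // lower to collectives
}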
@@ -86,6 +86,14 @@ func.func @reshard_from_sharded_to_fully_replicated_different_meshes(%arg0 : ten
return %0 : tensor<24x8xf32>
}

// CHECK-LABEL: func @reshard_from_sharded_to_sharded_different_meshes
func.func @reshard_from_sharded_to_sharded_different_meshes(%arg0 : tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> (tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d_2x3, [{"x"}, {}]>}) {
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]>
// CHECK-NEXT: return %[[RESHARD]]
%0 = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]> : tensor<24x8xf32>
return %0 : tensor<24x8xf32>
}
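
For the @reshard_from_sharded_to_sharded_different_meshes test above, an illustration of why the reshard survives lowering; the axis lists below are assumptions inferred only from the mesh names, not copied from the test file's mesh definitions. The two meshes have different axes, so they are not equivalent even when device order is ignored, and the legality callback leaves the op alone.

#include <cassert>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Assumed axis lists for the two meshes named in the test above.
  std::vector<std::pair<std::string, int>> mesh1d_6 = {{"x", 6}};
  std::vector<std::pair<std::string, int>> mesh2d_2x3 = {{"x", 2}, {"y", 3}};
  // The axes differ, so the meshes are not equivalent even ignoring device
  // order; the reshard is therefore kept as-is rather than lowered.
  assert(mesh1d_6 != mesh2d_2x3);
  return 0;
}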

// CHECK-LABEL: func @all_gather_single_axis
func.func @all_gather_single_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: sdy.all_gather [{}, {"x"}] %arg0 out_sharding=<@mesh2d, [{"y"}, {}]>
@@ -10,6 +10,14 @@ func.func @redundant_reshard(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.shardin
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @redundant_reshard_on_fully_replicated
func.func @redundant_reshard_on_fully_replicated(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{}, {}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh2d, [{}, {}]>
// CHECK-NEXT: return %[[RESHARD]]
%0 = sdy.reshard %arg0 <@mesh2d, [{}, {}]> : tensor<16x8xf32>
return %0 : tensor<16x8xf32>
}

// CHECK-LABEL: func @non_redundant_reshard
func.func @non_redundant_reshard(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
// CHECK-NEXT: %[[CP:.*]] = sdy.collective_permute %arg0 out_sharding=<@mesh2d, [{"y"}, {"x"}]>