
Commit 9ab97bf

Keep reshards that have non-equivalent input and output meshes.

The InsertExplicitReshards pass does not insert reshards whose input and output meshes are non-equivalent, but the reshard-to-collectives pass can still receive them from user sharding constraints.

PiperOrigin-RevId: 815709360
1 parent 84c64a0 commit 9ab97bf
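
For illustration, a minimal MLIR sketch of the case this commit keeps, assuming mesh declarations that mirror the new test (the @mesh1d_6 and @mesh2d_2x3 definitions and their axis sizes are not part of this diff and are guesses from the names): the input is sharded on one mesh and the reshard targets a mesh with different axes, so the meshes are non-equivalent and the reshard op is now left in place rather than rejected.

// Assumed mesh declarations (illustrative axis sizes only).
sdy.mesh @mesh1d_6 = <["x"=6]>
sdy.mesh @mesh2d_2x3 = <["x"=2, "y"=3]>

// The meshes have different axes, so reshard_to_collectives keeps the reshard.
func.func @cross_mesh_reshard(%arg0 : tensor<24x8xf32> {sdy.sharding = #sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> tensor<24x8xf32> {
  %0 = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]> : tensor<24x8xf32>
  return %0 : tensor<24x8xf32>
}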

3 files changed: +64 −31 lines changed

shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc

Lines changed: 48 additions & 31 deletions
@@ -1314,6 +1314,17 @@ class CollectiveInserter {
   AxisToDimAndIndex outAxisToDimAndIndex;
 };
 
+// Assumes both `inSharding` and `outSharding` are non-null.
+bool isEquivalentOnMesh(TensorShardingAttr inSharding,
+                        TensorShardingAttr outSharding, ReshardOp reshardOp) {
+  if (inSharding.getMeshName() == outSharding.getMeshName()) {
+    return true;
+  }
+  MeshAttr inMesh = inSharding.getMesh(reshardOp);
+  MeshAttr outMesh = outSharding.getMesh(reshardOp);
+  return inMesh.equals(outMesh, /*ignoreDeviceOrder=*/true);
+}
+
 class ReshardPattern : public OpConversionPattern<ReshardOp> {
  public:
   using OpConversionPattern::OpConversionPattern;
@@ -1330,31 +1341,26 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
       return rewriter.notifyMatchFailure(
           op, [](Diagnostic& diag) { diag << "Incompatible shardings"; });
     }
-    if (inSharding.isFullyReplicated() && outSharding.isFullyReplicated()) {
-      rewriter.replaceOp(op, adaptor.getInput());
-      return success();
-    }
-    if (inSharding.getMeshName() != outSharding.getMeshName()) {
-      if (outSharding.isFullyReplicated()) {
-        // TODO(enver): Hard fail if output sharding has a different unreduced
-        // axes than the input sharding. Note that the out sharding may be fully
-        // replicated and still have different unreduced axes than the input
-        // sharding.
-        outSharding = TensorShardingAttr::getFullyClosedLike(inSharding);
-        // TODO(enver): Also check for input sharding is fully replicated.
-      } else {
-        MeshAttr inMesh = inSharding.getMesh(op);
-        MeshAttr outMesh = outSharding.getMesh(op);
-        // TODO(enver): Use MeshAttr::equals method instead.
-        if (outMesh.getAxes() != inMesh.getAxes() ||
-            inMesh.getDeviceIds() == outMesh.getDeviceIds()) {
-          // We currently only support a reshard between different meshes if
-          // they have the same axes and different device ids, and at least one
-          // of the sharding isn't fully replicated.
-          return rewriter.notifyMatchFailure(
-              op, [](Diagnostic& diag) { diag << "Incompatible meshes"; });
-        }
+    if (outSharding.isFullyReplicated()) {
+      if (inSharding.isFullyReplicated()) {
+        rewriter.replaceOp(op, adaptor.getInput());
+        return success();
       }
+      // TODO(enver): Hard fail if output sharding has a different unreduced
+      // axes than the input sharding. Note that the out sharding may be fully
+      // replicated and still have different unreduced axes than the input
+      // sharding.
+      outSharding = TensorShardingAttr::getFullyClosedLike(inSharding);
+    }
+    // TODO(enver): Set input mesh to output mesh if input sharding is fully
+    // replicated. It requires sdy.all_slice can handle that input and output
+    // has a different meshes.
+    if (!isEquivalentOnMesh(inSharding, outSharding, op)) {
+      // We currently only support a reshard between different meshes if
+      // they have the same axes and different device ids, and at least one
+      // of the sharding isn't fully replicated.
+      return rewriter.notifyMatchFailure(
+          op, [](Diagnostic& diag) { diag << "Incompatible meshes"; });
     }
 
     // TODO(tomnatan): we should verify that the operand of ReshardOp has a
@@ -1377,13 +1383,24 @@ struct ReshardToCollectivesPass
     target = std::make_shared<ConversionTarget>(*context);
     target->addLegalOp<AllGatherOp, AllSliceOp, AllToAllOp,
                        CollectivePermuteOp>();
-    if (keepRedundantReshards) {
-      target->addDynamicallyLegalOp<ReshardOp>([](ReshardOp op) {
-        return isEquivalent(getSharding(op.getInput()), op.getSharding());
-      });
-    } else {
-      target->addIllegalOp<ReshardOp>();
-    }
+    target->addDynamicallyLegalOp<ReshardOp>([&](ReshardOp op) {
+      TensorShardingAttr inSharding = getSharding(op.getInput());
+      TensorShardingAttr outSharding = op.getSharding();
+      if (keepRedundantReshards && isEquivalent(inSharding, outSharding)) {
+        return true;
+      }
+      // In case out sharding is fully replicated, the reshard is either erased
+      // (if input sharding is fully replicated too) or it adds an all gather
+      // (if input sharding is sharded), either way it is handled and it should
+      // be marked as illegal.
+      if (outSharding.isFullyReplicated()) {
+        return false;
+      }
+      if (inSharding && !isEquivalentOnMesh(inSharding, outSharding, op)) {
+        return true;
+      }
+      return false;
+    });
 
     RewritePatternSet patternsInternal(context);
     patternsInternal.add<ReshardPattern>(context);

shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir

Lines changed: 8 additions & 0 deletions
@@ -86,6 +86,14 @@ func.func @reshard_from_sharded_to_fully_replicated_different_meshes(%arg0 : ten
   return %0 : tensor<24x8xf32>
 }
 
+// CHECK-LABEL: func @reshard_from_sharded_to_sharded_different_meshes
+func.func @reshard_from_sharded_to_sharded_different_meshes(%arg0 : tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> (tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d_2x3, [{"x"}, {}]>}) {
+  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]>
+  // CHECK-NEXT: return %[[RESHARD]]
+  %0 = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]> : tensor<24x8xf32>
+  return %0 : tensor<24x8xf32>
+}
+
 // CHECK-LABEL: func @all_gather_single_axis
 func.func @all_gather_single_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> {
   // CHECK-NEXT: sdy.all_gather [{}, {"x"}] %arg0 out_sharding=<@mesh2d, [{"y"}, {}]>

shardy/dialect/sdy/transforms/export/test/reshard_to_collectives_keep_redundant_reshards_true.mlir

Lines changed: 8 additions & 0 deletions
@@ -10,6 +10,14 @@ func.func @redundant_reshard(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.shardin
   return %0 : tensor<16x8xf32>
 }
 
+// CHECK-LABEL: func @redundant_reshard_on_fully_replicated
+func.func @redundant_reshard_on_fully_replicated(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{}, {}]>}) -> tensor<16x8xf32> {
+  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh2d, [{}, {}]>
+  // CHECK-NEXT: return %[[RESHARD]]
+  %0 = sdy.reshard %arg0 <@mesh2d, [{}, {}]> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
 // CHECK-LABEL: func @non_redundant_reshard
 func.func @non_redundant_reshard(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
   // CHECK-NEXT: %[[CP:.*]] = sdy.collective_permute %arg0 out_sharding=<@mesh2d, [{"y"}, {"x"}]>
