@@ -1314,6 +1314,17 @@ class CollectiveInserter {
13141314 AxisToDimAndIndex outAxisToDimAndIndex;
13151315};
13161316
1317+ // Assumes both `inSharding` and `outSharding` are non-null.
1318+ bool isEquivalentOnMesh (TensorShardingAttr inSharding,
1319+ TensorShardingAttr outSharding, ReshardOp reshardOp) {
1320+ if (inSharding.getMeshName () == outSharding.getMeshName ()) {
1321+ return true ;
1322+ }
1323+ MeshAttr inMesh = inSharding.getMesh (reshardOp);
1324+ MeshAttr outMesh = outSharding.getMesh (reshardOp);
1325+ return inMesh.equals (outMesh, /* ignoreDeviceOrder=*/ true );
1326+ }
1327+
13171328class ReshardPattern : public OpConversionPattern <ReshardOp> {
13181329 public:
13191330 using OpConversionPattern::OpConversionPattern;
@@ -1330,31 +1341,26 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
13301341 return rewriter.notifyMatchFailure (
13311342 op, [](Diagnostic& diag) { diag << " Incompatible shardings" ; });
13321343 }
1333- if (inSharding.isFullyReplicated () && outSharding.isFullyReplicated ()) {
1334- rewriter.replaceOp (op, adaptor.getInput ());
1335- return success ();
1336- }
1337- if (inSharding.getMeshName () != outSharding.getMeshName ()) {
1338- if (outSharding.isFullyReplicated ()) {
1339- // TODO(enver): Hard fail if output sharding has a different unreduced
1340- // axes than the input sharding. Note that the out sharding may be fully
1341- // replicated and still have different unreduced axes than the input
1342- // sharding.
1343- outSharding = TensorShardingAttr::getFullyClosedLike (inSharding);
1344- // TODO(enver): Also check for input sharding is fully replicated.
1345- } else {
1346- MeshAttr inMesh = inSharding.getMesh (op);
1347- MeshAttr outMesh = outSharding.getMesh (op);
1348- // TODO(enver): Use MeshAttr::equals method instead.
1349- if (outMesh.getAxes () != inMesh.getAxes () ||
1350- inMesh.getDeviceIds () == outMesh.getDeviceIds ()) {
1351- // We currently only support a reshard between different meshes if
1352- // they have the same axes and different device ids, and at least one
1353- // of the sharding isn't fully replicated.
1354- return rewriter.notifyMatchFailure (
1355- op, [](Diagnostic& diag) { diag << " Incompatible meshes" ; });
1356- }
1344+ if (outSharding.isFullyReplicated ()) {
1345+ if (inSharding.isFullyReplicated ()) {
1346+ rewriter.replaceOp (op, adaptor.getInput ());
1347+ return success ();
13571348 }
1349+      // TODO(enver): Hard fail if the output sharding has different
1350+      // unreduced axes than the input sharding. Note that the out sharding
1351+      // may be fully replicated and still have different unreduced axes
1352+      // than the input sharding.
1353+ outSharding = TensorShardingAttr::getFullyClosedLike (inSharding);
1354+ }
1355+    // TODO(enver): Set the input mesh to the output mesh if the input
1356+    // sharding is fully replicated. This requires that sdy.all_slice can
1357+    // handle the input and output having different meshes.
1358+ if (!isEquivalentOnMesh (inSharding, outSharding, op)) {
1359+ // We currently only support a reshard between different meshes if
1360+ // they have the same axes and different device ids, and at least one
1361+      // of the shardings isn't fully replicated.
1362+ return rewriter.notifyMatchFailure (
1363+ op, [](Diagnostic& diag) { diag << " Incompatible meshes" ; });
13581364 }
13591365
13601366 // TODO(tomnatan): we should verify that the operand of ReshardOp has a
@@ -1377,13 +1383,24 @@ struct ReshardToCollectivesPass
13771383 target = std::make_shared<ConversionTarget>(*context);
13781384 target->addLegalOp <AllGatherOp, AllSliceOp, AllToAllOp,
13791385 CollectivePermuteOp>();
1380- if (keepRedundantReshards) {
1381- target->addDynamicallyLegalOp <ReshardOp>([](ReshardOp op) {
1382- return isEquivalent (getSharding (op.getInput ()), op.getSharding ());
1383- });
1384- } else {
1385- target->addIllegalOp <ReshardOp>();
1386- }
1386+ target->addDynamicallyLegalOp <ReshardOp>([&](ReshardOp op) {
1387+ TensorShardingAttr inSharding = getSharding (op.getInput ());
1388+ TensorShardingAttr outSharding = op.getSharding ();
1389+ if (keepRedundantReshards && isEquivalent (inSharding, outSharding)) {
1390+ return true ;
1391+ }
1392+      // If the out sharding is fully replicated, the reshard is either
1393+      // erased (if the input sharding is fully replicated too) or lowered to
1394+      // an all-gather (if the input sharding is sharded); either way it is
1395+      // handled, so it should be marked as illegal.
1396+ if (outSharding.isFullyReplicated ()) {
1397+ return false ;
1398+ }
1399+ if (inSharding && !isEquivalentOnMesh (inSharding, outSharding, op)) {
1400+ return true ;
1401+ }
1402+ return false ;
1403+ });
13871404
13881405 RewritePatternSet patternsInternal (context);
13891406 patternsInternal.add <ReshardPattern>(context);
0 commit comments