diff --git a/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc b/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc
index 845eeab1..8757e4a2 100644
--- a/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc
+++ b/shardy/dialect/sdy/transforms/export/reshard_to_collectives.cc
@@ -1314,6 +1314,17 @@ class CollectiveInserter {
   AxisToDimAndIndex outAxisToDimAndIndex;
 };
 
+// Assumes both `inSharding` and `outSharding` are non-null.
+bool isEquivalentOnMesh(TensorShardingAttr inSharding,
+                        TensorShardingAttr outSharding, ReshardOp reshardOp) {
+  if (inSharding.getMeshName() == outSharding.getMeshName()) {
+    return true;
+  }
+  MeshAttr inMesh = inSharding.getMesh(reshardOp);
+  MeshAttr outMesh = outSharding.getMesh(reshardOp);
+  return inMesh.equals(outMesh, /*ignoreDeviceOrder=*/true);
+}
+
 class ReshardPattern : public OpConversionPattern<ReshardOp> {
  public:
   using OpConversionPattern::OpConversionPattern;
@@ -1330,31 +1341,26 @@ class ReshardPattern : public OpConversionPattern<ReshardOp> {
       return rewriter.notifyMatchFailure(
           op, [](Diagnostic& diag) { diag << "Incompatible shardings"; });
     }
-    if (inSharding.isFullyReplicated() && outSharding.isFullyReplicated()) {
-      rewriter.replaceOp(op, adaptor.getInput());
-      return success();
-    }
-    if (inSharding.getMeshName() != outSharding.getMeshName()) {
-      if (outSharding.isFullyReplicated()) {
-        // TODO(enver): Hard fail if output sharding has a different unreduced
-        // axes than the input sharding. Note that the out sharding may be fully
-        // replicated and still have different unreduced axes than the input
-        // sharding.
-        outSharding = TensorShardingAttr::getFullyClosedLike(inSharding);
-        // TODO(enver): Also check for input sharding is fully replicated.
-      } else {
-        MeshAttr inMesh = inSharding.getMesh(op);
-        MeshAttr outMesh = outSharding.getMesh(op);
-        // TODO(enver): Use MeshAttr::equals method instead.
-        if (outMesh.getAxes() != inMesh.getAxes() ||
-            inMesh.getDeviceIds() == outMesh.getDeviceIds()) {
-          // We currently only support a reshard between different meshes if
-          // they have the same axes and different device ids, and at least one
-          // of the sharding isn't fully replicated.
-          return rewriter.notifyMatchFailure(
-              op, [](Diagnostic& diag) { diag << "Incompatible meshes"; });
-        }
-      }
+    if (outSharding.isFullyReplicated()) {
+      if (inSharding.isFullyReplicated()) {
+        rewriter.replaceOp(op, adaptor.getInput());
+        return success();
+      }
+      // TODO(enver): Hard fail if the output sharding has different unreduced
+      // axes than the input sharding. Note that the out sharding may be fully
+      // replicated and still have different unreduced axes than the input
+      // sharding.
+      outSharding = TensorShardingAttr::getFullyClosedLike(inSharding);
+    }
+    // TODO(enver): Set the input mesh to the output mesh if the input
+    // sharding is fully replicated. This requires that sdy.all_slice can
+    // handle an input and output with different meshes.
+    if (!isEquivalentOnMesh(inSharding, outSharding, op)) {
+      // We currently only support a reshard between different meshes if they
+      // have the same axes and different device ids, and at least one of the
+      // shardings isn't fully replicated.
+      return rewriter.notifyMatchFailure(
+          op, [](Diagnostic& diag) { diag << "Incompatible meshes"; });
     }
 
     // TODO(tomnatan): we should verify that the operand of ReshardOp has a
@@ -1377,13 +1383,24 @@ struct ReshardToCollectivesPass
     target = std::make_shared<ConversionTarget>(*context);
     target->addLegalOp<AllGatherOp, AllSliceOp, AllToAllOp, AllReduceOp,
                        CollectivePermuteOp>();
-    if (keepRedundantReshards) {
-      target->addDynamicallyLegalOp<ReshardOp>([](ReshardOp op) {
-        return isEquivalent(getSharding(op.getInput()), op.getSharding());
-      });
-    } else {
-      target->addIllegalOp<ReshardOp>();
-    }
+    target->addDynamicallyLegalOp<ReshardOp>([&](ReshardOp op) {
+      TensorShardingAttr inSharding = getSharding(op.getInput());
+      TensorShardingAttr outSharding = op.getSharding();
+      if (keepRedundantReshards && isEquivalent(inSharding, outSharding)) {
+        return true;
+      }
+      // If the out sharding is fully replicated, the reshard is either erased
+      // (if the input sharding is also fully replicated) or lowered to an
+      // all-gather (if the input is sharded); either way it is handled, so it
+      // should be marked illegal.
+      if (outSharding.isFullyReplicated()) {
+        return false;
+      }
+      if (inSharding && !isEquivalentOnMesh(inSharding, outSharding, op)) {
+        return true;
+      }
+      return false;
+    });
 
     RewritePatternSet patternsInternal(context);
     patternsInternal.add<ReshardPattern>(context);
diff --git a/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir b/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir
index 8fffbbc8..d1674829 100644
--- a/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives.mlir
@@ -86,6 +86,14 @@ func.func @reshard_from_sharded_to_fully_replicated_different_meshes(%arg0 : ten
   return %0 : tensor<24x8xf32>
 }
 
+// CHECK-LABEL: func @reshard_from_sharded_to_sharded_different_meshes
+func.func @reshard_from_sharded_to_sharded_different_meshes(%arg0 : tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh1d_6, [{"x"}, {}]>}) -> (tensor<24x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d_2x3, [{"x"}, {}]>}) {
+  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]>
+  // CHECK-NEXT: return %[[RESHARD]]
+  %0 = sdy.reshard %arg0 <@mesh2d_2x3, [{"x"}, {}]> : tensor<24x8xf32>
+  return %0 : tensor<24x8xf32>
+}
+
 // CHECK-LABEL: func @all_gather_single_axis
 func.func @all_gather_single_axis(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"y"}, {"x"}]>}) -> tensor<16x8xf32> {
   // CHECK-NEXT: sdy.all_gather [{}, {"x"}] %arg0 out_sharding=<@mesh2d, [{"y"}, {}]>
diff --git a/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives_keep_redundant_reshards_true.mlir b/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives_keep_redundant_reshards_true.mlir
index 6b1ebd44..bea97d45 100644
--- a/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives_keep_redundant_reshards_true.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/reshard_to_collectives_keep_redundant_reshards_true.mlir
@@ -10,6 +10,14 @@ func.func @redundant_reshard(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.shardin
   return %0 : tensor<16x8xf32>
 }
 
+// CHECK-LABEL: func @redundant_reshard_on_fully_replicated
+func.func @redundant_reshard_on_fully_replicated(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{}, {}]>}) -> tensor<16x8xf32> {
+  // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh2d, [{}, {}]>
+  // CHECK-NEXT: return %[[RESHARD]]
+  %0 = sdy.reshard %arg0 <@mesh2d, [{}, {}]> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
 // CHECK-LABEL: func @non_redundant_reshard
 func.func @non_redundant_reshard(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh2d, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
   // CHECK-NEXT: %[[CP:.*]] = sdy.collective_permute %arg0 out_sharding=<@mesh2d, [{"y"}, {"x"}]>
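
Note the complementary case to the new @reshard_from_sharded_to_sharded_different_meshes test: `isEquivalentOnMesh` treats meshes that differ only in device order as equivalent, so a reshard between such meshes stays illegal and is lowered to collectives (the "same axes, different device ids" case from the removed comment, which this pass lowers via sdy.collective_permute). A minimal sketch of that case; the mesh names @mesh_iota and @mesh_non_iota and the test body are illustrative, not part of this change:

    sdy.mesh @mesh_iota = <["x"=2, "y"=2]>
    sdy.mesh @mesh_non_iota = <["x"=2, "y"=2], device_ids=[3, 2, 1, 0]>

    // Same axes, different device order: isEquivalentOnMesh returns true, so
    // the reshard is not kept and is expected to lower to a collective_permute.
    func.func @reshard_only_device_order_differs(%arg0 : tensor<16x8xf32> {sdy.sharding=#sdy.sharding<@mesh_iota, [{"x"}, {"y"}]>}) -> tensor<16x8xf32> {
      %0 = sdy.reshard %arg0 <@mesh_non_iota, [{"x"}, {"y"}]> : tensor<16x8xf32>
      return %0 : tensor<16x8xf32>
    }

By contrast, @reshard_from_sharded_to_sharded_different_meshes above uses meshes with different axes (@mesh1d_6 vs. @mesh2d_2x3), so isEquivalentOnMesh returns false and the reshard op is kept as-is.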