diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
index 2f213741..92b94462 100644
--- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
+++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -387,6 +387,42 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
   return false;
 }
 
+// Assumes the results have unreduced axes.
+//
+// Returns an `AxesPerFactor` with only its reduction factors populated with
+// common axes.
+//
+// Hard fails if some reduction factors do not have compatible shardings.
+AxesPerFactor getCommonAxesPerReductionFactor(
+    Operation* op, const ShardingProjection& shardingProjection,
+    OpShardingRuleAttr shardingRule) {
+  // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
+  // factors, and simplify the following logic.
+  AxesPerFactor commonAxesPerFactor =
+      AxesPerFactor(shardingRule.getNumFactors());
+  for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
+    // We only iterate operands since reduction factors are not in results.
+    bool seen = false;
+    SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
+    for (const TensorFactorShardings& tensorFactorSharding :
+         shardingProjection.getOperands()) {
+      if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
+              getFactorSharding(tensorFactorSharding, reductionFactor)) {
+        if (seen) {
+          SDY_CHECK(factorSharding->equals(commonAxes))
+              << "For the operation " << op
+              << ", the result has unreduced axes while the operand has "
+                 "incompatible sharding along reduction factors.";
+        } else {
+          commonAxes = llvm::to_vector(*factorSharding);
+          seen = true;
+        }
+      }
+    }
+  }
+  return commonAxesPerFactor;
+}
+
 // Inserts explicit reshards on the operands and results of `op` such that the
 // sharding of `op` is compatible with its sharding rule.
 //
@@ -417,10 +453,13 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
     return {};
   }
 
+  // Checks if factors are sharded the same way across operands and results.
+  AxesPerFactor commonAxesPerFactor =
+      getCompatibleFactorShardings(shardingProjection, shardingRule);
+
+  // TODO(b/446833985): Return common axes per factor also when the sharding
+  // projection has overflow axes.
   if (onFullVersion) {
-    // Checks if factors are sharded the same way across operands and results.
-    AxesPerFactor commonAxesPerFactor =
-        getCompatibleFactorShardings(shardingProjection, shardingRule);
     // Find compatible shardings if it is not already compatible.
     if (commonAxesPerFactor.empty()) {
       commonAxesPerFactor =
@@ -439,49 +478,30 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
     insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
                            updateTensorShardings, rewriter, shardingRule,
                            symbolTable, mesh);
 
+  } else {
+    TypeSwitch<Operation*>(op)
+        .Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
+          processDot(dotOp, shardingProjection, outShardings, rewriter,
+                     symbolTable, shardingRule, mesh);
+        })
+        .Case<stablehlo::DotGeneralOp>(
+            [&](stablehlo::DotGeneralOp dotGeneralOp) {
+              processDot(dotGeneralOp, shardingProjection, outShardings,
+                         rewriter, symbolTable, shardingRule, mesh);
+            });
+
+    if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
+      return {};
+    }
-    return getReductionAxes(commonAxesPerFactor, shardingRule);
-  }
-
-  TypeSwitch<Operation*>(op)
-      .Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
-        processDot(dotOp, shardingProjection, outShardings, rewriter,
-                   symbolTable, shardingRule, mesh);
-      })
-      .Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
-        processDot(dotGeneralOp, shardingProjection, outShardings, rewriter,
-                   symbolTable, shardingRule, mesh);
-      });
-
-  if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
-    return {};
-  }
-
-  // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
-  // factors, and simplify the following logic.
-  SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
-  for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
-    // We only iterate operands since reduction factors are not in results.
-    bool seen = false;
-    SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
-    for (const TensorFactorShardings& tensorFactorSharding :
-         shardingProjection.getOperands()) {
-      if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
-              getFactorSharding(tensorFactorSharding, reductionFactor)) {
-        if (seen) {
-          SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
-              << "For the operation " << op
-              << ", the result has unreduced axes while the operand has "
-                 "incompatible sharding along reduction factors.";
-        } else {
-          axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
-          seen = true;
-        }
-      }
+    if (commonAxesPerFactor.empty()) {
+      // At this point, there are unreduced axes on results.
+      commonAxesPerFactor =
+          getCommonAxesPerReductionFactor(op, shardingProjection, shardingRule);
     }
-    axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
   }
-  return axesAlongAllReductionFactors;
+
+  return getReductionAxes(commonAxesPerFactor, shardingRule);
 }
 
 struct InsertExplicitReshardsPass
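The new getCommonAxesPerReductionFactor centralizes the per-reduction-factor compatibility check that previously lived inline at the end of processOp. Below is a minimal standalone sketch of that check; Axis, FactorShardingMap, and commonAxesForFactor are hypothetical stand-ins for Shardy's AxisRefAttr, TensorFactorShardings, and the helper above, an illustration of the logic rather than the Shardy API.

// Standalone sketch (not the Shardy API) of the common-axes check that
// getCommonAxesPerReductionFactor performs for one reduction factor.
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

using Axis = std::string;
// Hypothetical per-operand map from factor index to the axes sharding it.
using FactorShardingMap = std::map<int64_t, std::vector<Axis>>;

// The first operand that maps the factor fixes the common axes; every later
// operand that maps it must match exactly, mirroring the SDY_CHECK in the
// patch, which hard fails on incompatible shardings along a reduction factor.
std::vector<Axis> commonAxesForFactor(
    const std::vector<FactorShardingMap>& operands, int64_t factor) {
  std::vector<Axis> common;
  bool seen = false;
  for (const FactorShardingMap& operand : operands) {
    auto it = operand.find(factor);
    if (it == operand.end()) continue;  // Operand does not map this factor.
    if (seen) {
      assert(it->second == common &&
             "incompatible sharding along reduction factor");
    } else {
      common = it->second;
      seen = true;
    }
  }
  return common;
}

int main() {
  // Two operands of a dot-like op; both shard reduction factor 2 on axis "x".
  std::vector<FactorShardingMap> operands = {
      {{0, {"y"}}, {2, {"x"}}},
      {{1, {}}, {2, {"x"}}},
  };
  assert(commonAxesForFactor(operands, /*factor=*/2) ==
         std::vector<Axis>({"x"}));
  return 0;
}

As in the patch, only operands are consulted because reduction factors do not appear in results, and a mismatch is a hard failure: once a result carries unreduced axes, incompatible operand shardings along a reduction factor cannot be reconciled by inserting a reshard.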