@@ -158,18 +158,14 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
158158// return %reshard : tensor<4x8xf32>
159159// ```
160160template <class OpTy >
161- void processDot (OpTy op, ArrayRef<TensorShardingAttr> inShardings ,
161+ void processDot (OpTy op, ShardingProjection& shardingProjection ,
162162 ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
163163 const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
164164 const Mesh& mesh) {
165165 if (outShardings.empty ()) {
166166 // Result doesn't have a sharding.
167167 return ;
168168 }
169- ShardingProjection shardingProjection =
170- ShardingProjection::build (inShardings, outShardings, shardingRule,
171- mesh.attr (), /* closedIfMissing=*/ true );
172-
173169 const TensorFactorShardings& lhsSharding = shardingProjection.getOperand (0 );
174170 const TensorFactorShardings& rhsSharding = shardingProjection.getOperand (1 );
175171 TensorFactorShardings& resultSharding =
@@ -391,6 +387,42 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
391387 return false ;
392388}
393389
390+ // Assumes the results have unreduced axes.
391+ //
392+ // Returns an `AxesPerFactor` in which only the reduction factors are populated,
393+ // each with the axes shared by every operand that has a sharding for it.
394+ //
395+ // Hard fails if operands disagree on the sharding of any reduction factor.
396+ AxesPerFactor getCommonAxesPerReductionFactor (
397+ Operation* op, const ShardingProjection& shardingProjection,
398+ OpShardingRuleAttr shardingRule) {
399+ // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
400+ // factors, and simplify the following logic.
401+ AxesPerFactor commonAxesPerFactor =
402+ AxesPerFactor (shardingRule.getNumFactors ());  // One entry per factor.
403+ for (int64_t reductionFactor : shardingRule.getReductionFactors ()) {
404+ // Only operands are visited, since reduction factors are not in results.
405+ bool seen = false ;  // True once some operand provided a sharding.
406+ SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
407+ for (const TensorFactorShardings& tensorFactorSharding :
408+ shardingProjection.getOperands ()) {
409+ if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
410+ getFactorSharding (tensorFactorSharding, reductionFactor)) {
411+ if (seen) {
412+ SDY_CHECK (factorSharding->equals (commonAxes))
413+ << " For the operation " << op
414+ << " , the result has unreduced axes while the operand has "
415+ " incompatible sharding along reduction factors." ;
416+ } else {
417+ commonAxes = llvm::to_vector (*factorSharding);  // First one is the reference.
418+ seen = true ;
419+ }
420+ }
421+ }
422+ }
423+ return commonAxesPerFactor;
424+ }
425+
394426// Inserts explicit reshards on the operands and results of `op` such that the
395427// sharding of `op` is compatible with its sharding rule.
396428//
@@ -400,31 +432,25 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
400432// - All op results have the same unreduced axes.
401433// - If the op has no results, none of the operands has unreduced axes.
402434// - Operand and result meshes are the same ignoring device id order.
435+ // - There are no overflow axes.
403436//
404437// Returns the union of axes along all the reduction factors which may not be
405438// canonicalized.
406- SmallVector<AxisRefAttr> processOp (Operation* op,
407- ArrayRef<TensorShardingAttr> inShardings,
408- ArrayRef<TensorShardingAttr> outShardings,
409- IRRewriter& rewriter,
410- const SymbolTable& symbolTable,
411- OpShardingRuleAttr shardingRule,
412- const Mesh& mesh, const bool onFullVersion) {
413- ShardingProjection shardingProjection = ShardingProjection::build (
414- inShardings, outShardings, shardingRule, mesh.attr (),
415- /* closedIfMissing=*/ true );
416-
417- // Return without inserting reshards if any factor sharding has overflow
418- // axes. This case is not handled yet.
419- // TODO(enver): Handle the case when factor shardings have overflow axes.
420- if (hasOverflowAxes (shardingProjection)) {
421- return {};
422- }
423-
439+ //
440+ // Guaranteed to return a non-empty `AxesPerFactor` if `onFullVersion` is true.
441+ AxesPerFactor processOp (Operation* op, ShardingProjection& shardingProjection,
442+ ArrayRef<TensorShardingAttr> inShardings,
443+ ArrayRef<TensorShardingAttr> outShardings,
444+ IRRewriter& rewriter, const SymbolTable& symbolTable,
445+ OpShardingRuleAttr shardingRule, const Mesh& mesh,
446+ const bool onFullVersion) {
447+ // Checks if factors are sharded the same way across operands and results.
448+ AxesPerFactor commonAxesPerFactor =
449+ getCompatibleFactorShardings (shardingProjection, shardingRule);
450+
451+ // TODO(b/446833985): Return common axes per factor also when the sharding
452+ // projection has overflow axes.
424453 if (onFullVersion) {
425- // Checks if factors are sharded the same way across operands and results.
426- AxesPerFactor commonAxesPerFactor =
427- getCompatibleFactorShardings (shardingProjection, shardingRule);
428454 // Find compatible shardings if it is not already compatible.
429455 if (commonAxesPerFactor.empty ()) {
430456 commonAxesPerFactor =
@@ -443,49 +469,19 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
443469 insertExplicitReshards (op, inShardings, outShardings, shardingProjection,
444470 updateTensorShardings, rewriter, shardingRule,
445471 symbolTable, mesh);
446-
447- return getReductionAxes (commonAxesPerFactor, shardingRule);
448- }
449-
450- TypeSwitch<Operation*>(op)
451- .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452- processDot (dotOp, inShardings, outShardings, rewriter, symbolTable,
453- shardingRule, mesh);
454- })
455- .Case <stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456- processDot (dotGeneralOp, inShardings, outShardings, rewriter,
457- symbolTable, shardingRule, mesh);
458- });
459-
460- if (outShardings.empty () || getUnreducedAxes (outShardings[0 ]).empty ()) {
461- return {};
472+ } else {
473+ TypeSwitch<Operation*>(op)
474+ .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
475+ processDot (dotOp, shardingProjection, outShardings, rewriter,
476+ symbolTable, shardingRule, mesh);
477+ })
478+ .Case <stablehlo::DotGeneralOp>(
479+ [&](stablehlo::DotGeneralOp dotGeneralOp) {
480+ processDot (dotGeneralOp, shardingProjection, outShardings,
481+ rewriter, symbolTable, shardingRule, mesh);
482+ });
462483 }
463-
464- // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
465- // factors, and simplify the following logic.
466- SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
467- for (int64_t reductionFactor : shardingRule.getReductionFactors ()) {
468- // We only iterate operands since reduction factors are not in results.
469- bool seen = false ;
470- SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
471- for (const TensorFactorShardings& tensorFactorSharding :
472- shardingProjection.getOperands ()) {
473- if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
474- getFactorSharding (tensorFactorSharding, reductionFactor)) {
475- if (seen) {
476- SDY_CHECK (axesAlongCurrentReductionFactor == *factorSharding)
477- << " For the operation " << op
478- << " , the result has unreduced axes while the operand has "
479- " incompatible sharding along reduction factors." ;
480- } else {
481- axesAlongCurrentReductionFactor = llvm::to_vector (*factorSharding);
482- seen = true ;
483- }
484- }
485- }
486- axesAlongAllReductionFactors.append (axesAlongCurrentReductionFactor);
487- }
488- return axesAlongAllReductionFactors;
484+ return commonAxesPerFactor;
489485}
490486
491487struct InsertExplicitReshardsPass
@@ -544,11 +540,35 @@ struct InsertExplicitReshardsPass
544540 return ;
545541 }
546542
547- SmallVector<AxisRefAttr> reductionAxes =
548- processOp (op, inShardings, outShardings, rewriter, symbolTable,
549- shardingRule, *mesh, onFullVersion);
543+ ShardingProjection shardingProjection = ShardingProjection::build (
544+ inShardings, outShardings, shardingRule, mesh->attr (),
545+ /* closedIfMissing=*/ true );
546+ // Return without inserting reshards if any factor sharding has overflow
547+ // axes. This case is not handled yet.
548+ // TODO(enver): Handle the case when factor shardings have overflow axes.
549+ if (hasOverflowAxes (shardingProjection)) {
550+ return ;
551+ }
552+ AxesPerFactor commonAxesPerFactor =
553+ processOp (op, shardingProjection, inShardings, outShardings, rewriter,
554+ symbolTable, shardingRule, *mesh, onFullVersion);
555+ if (op->getResults ().empty ()) {
556+ return ;
557+ }
558+ if (!onFullVersion) {
559+ if (getUnreducedAxes (op->getResult (0 )).empty ()) {
560+ return ;
561+ }
562+ if (commonAxesPerFactor.empty ()) {
563+ // At this point, there are unreduced axes on results.
564+ commonAxesPerFactor = getCommonAxesPerReductionFactor (
565+ op, shardingProjection, shardingRule);
566+ }
567+ }
550568 // TODO(b/440055868): Insert a reshard from unreduced to replicated axes.
551- insertAllReducesForReductionFactors (op, reductionAxes, *mesh, rewriter);
569+ insertAllReducesForReductionFactors (
570+ op, getReductionAxes (commonAxesPerFactor, shardingRule), *mesh,
571+ rewriter);
552572
553573 // TODO(enver): Remove sharding rules from ops.
554574 });
0 commit comments