@@ -387,6 +387,42 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
387387  return  false ;
388388}
389389
390+ //  Assume the results have unreduced axes.
391+ // 
392+ //  Returns `AxesPerFactor` with only its reduction factors are populated to have
393+ //  common axes.
394+ // 
395+ //  Hard fails if some reduction factors do not have compatible shardings.
396+ AxesPerFactor getCommonAxesPerReductionFactor (
397+     Operation* op, const  ShardingProjection& shardingProjection,
398+     OpShardingRuleAttr shardingRule) {
399+   //  TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
400+   //  factors, and simplify the following logic.
401+   AxesPerFactor commonAxesPerFactor =
402+       AxesPerFactor (shardingRule.getNumFactors ());
403+   for  (int64_t  reductionFactor : shardingRule.getReductionFactors ()) {
404+     //  We only iterate operands since reduction factors are not in results.
405+     bool  seen = false ;
406+     SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
407+     for  (const  TensorFactorShardings& tensorFactorSharding :
408+          shardingProjection.getOperands ()) {
409+       if  (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
410+               getFactorSharding (tensorFactorSharding, reductionFactor)) {
411+         if  (seen) {
412+           SDY_CHECK (factorSharding->equals (commonAxes))
413+               << " For the operation " 
414+               << " , the result has unreduced axes while the operand has " 
415+                  " incompatible sharding along reduction factors." 
416+         } else  {
417+           commonAxes = llvm::to_vector (*factorSharding);
418+           seen = true ;
419+         }
420+       }
421+     }
422+   }
423+   return  commonAxesPerFactor;
424+ }
425+ 
390426//  Inserts explicit reshards on the operands and results of `op` such that the
391427//  sharding of `op` is compatible with its sharding rule.
392428// 
@@ -417,10 +453,13 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
417453    return  {};
418454  }
419455
456+   //  Checks if factors are sharded the same way across operands and results.
457+   AxesPerFactor commonAxesPerFactor =
458+       getCompatibleFactorShardings (shardingProjection, shardingRule);
459+ 
460+   //  TODO(b/446833985): Return common axes per factor also when the sharding
461+   //  projection have overflow axes.
420462  if  (onFullVersion) {
421-     //  Checks if factors are sharded the same way across operands and results.
422-     AxesPerFactor commonAxesPerFactor =
423-         getCompatibleFactorShardings (shardingProjection, shardingRule);
424463    //  Find compatible shardings if it is not already compatible.
425464    if  (commonAxesPerFactor.empty ()) {
426465      commonAxesPerFactor =
@@ -439,49 +478,30 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
439478    insertExplicitReshards (op, inShardings, outShardings, shardingProjection,
440479                           updateTensorShardings, rewriter, shardingRule,
441480                           symbolTable, mesh);
481+   } else  {
482+     TypeSwitch<Operation*>(op)
483+         .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
484+           processDot (dotOp, shardingProjection, outShardings, rewriter,
485+                      symbolTable, shardingRule, mesh);
486+         })
487+         .Case <stablehlo::DotGeneralOp>(
488+             [&](stablehlo::DotGeneralOp dotGeneralOp) {
489+               processDot (dotGeneralOp, shardingProjection, outShardings,
490+                          rewriter, symbolTable, shardingRule, mesh);
491+             });
492+ 
493+     if  (outShardings.empty () || getUnreducedAxes (outShardings[0 ]).empty ()) {
494+       return  {};
495+     }
442496
443-     return  getReductionAxes (commonAxesPerFactor, shardingRule);
444-   }
445- 
446-   TypeSwitch<Operation*>(op)
447-       .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
448-         processDot (dotOp, shardingProjection, outShardings, rewriter,
449-                    symbolTable, shardingRule, mesh);
450-       })
451-       .Case <stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
452-         processDot (dotGeneralOp, shardingProjection, outShardings, rewriter,
453-                    symbolTable, shardingRule, mesh);
454-       });
455- 
456-   if  (outShardings.empty () || getUnreducedAxes (outShardings[0 ]).empty ()) {
457-     return  {};
458-   }
459- 
460-   //  TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
461-   //  factors, and simplify the following logic.
462-   SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
463-   for  (int64_t  reductionFactor : shardingRule.getReductionFactors ()) {
464-     //  We only iterate operands since reduction factors are not in results.
465-     bool  seen = false ;
466-     SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
467-     for  (const  TensorFactorShardings& tensorFactorSharding :
468-          shardingProjection.getOperands ()) {
469-       if  (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
470-               getFactorSharding (tensorFactorSharding, reductionFactor)) {
471-         if  (seen) {
472-           SDY_CHECK (axesAlongCurrentReductionFactor == *factorSharding)
473-               << " For the operation " 
474-               << " , the result has unreduced axes while the operand has " 
475-                  " incompatible sharding along reduction factors." 
476-         } else  {
477-           axesAlongCurrentReductionFactor = llvm::to_vector (*factorSharding);
478-           seen = true ;
479-         }
480-       }
497+     if  (commonAxesPerFactor.empty ()) {
498+       //  At this point, there are unreduced axes on results.
499+       commonAxesPerFactor =
500+           getCommonAxesPerReductionFactor (op, shardingProjection, shardingRule);
481501    }
482-     axesAlongAllReductionFactors.append (axesAlongCurrentReductionFactor);
483502  }
484-   return  axesAlongAllReductionFactors;
503+ 
504+   return  getReductionAxes (commonAxesPerFactor, shardingRule);
485505}
486506
487507struct  InsertExplicitReshardsPass 
0 commit comments