Commit 96bf6c6

Unify the flows for finding common axes and building reduction axes out of them.
This is a refactoring. PiperOrigin-RevId: 811395434
1 parent 7870392 commit 96bf6c6


shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 63 additions & 43 deletions
@@ -387,6 +387,42 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
   return false;
 }
 
+// Assumes the results have unreduced axes.
+//
+// Returns `AxesPerFactor` with only its reduction factors populated with their
+// common axes.
+//
+// Hard fails if some reduction factors do not have compatible shardings.
+AxesPerFactor getCommonAxesPerReductionFactor(
+    Operation* op, const ShardingProjection& shardingProjection,
+    OpShardingRuleAttr shardingRule) {
+  // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
+  // factors, and simplify the following logic.
+  AxesPerFactor commonAxesPerFactor =
+      AxesPerFactor(shardingRule.getNumFactors());
+  for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
+    // We only iterate operands since reduction factors are not in results.
+    bool seen = false;
+    SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
+    for (const TensorFactorShardings& tensorFactorSharding :
+         shardingProjection.getOperands()) {
+      if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
+              getFactorSharding(tensorFactorSharding, reductionFactor)) {
+        if (seen) {
+          SDY_CHECK(factorSharding->equals(commonAxes))
+              << "For the operation " << op
+              << ", the result has unreduced axes while the operand has "
+                 "incompatible sharding along reduction factors.";
+        } else {
+          commonAxes = llvm::to_vector(*factorSharding);
+          seen = true;
+        }
+      }
+    }
+  }
+  return commonAxesPerFactor;
+}
+
 // Inserts explicit reshards on the operands and results of `op` such that the
 // sharding of `op` is compatible with its sharding rule.
 //
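
For intuition, here is a minimal standalone sketch of what the new helper computes, using plain standard-library containers instead of the SDY types (`AxesPerFactor`, `AxisRefAttr`, `ShardingProjection`). The function name, axis names, and factor indices below are hypothetical stand-ins, not part of the commit.

```cpp
// Hypothetical stand-in types; the real helper operates on SDY's
// AxesPerFactor, AxisRefAttr, and ShardingProjection.
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

using Axes = std::vector<std::string>;
// Per operand: factor index -> axes that shard that factor.
using FactorShardings = std::map<int, Axes>;

// Mirrors getCommonAxesPerReductionFactor: for every reduction factor, take
// the axes from the first operand that maps the factor, and check that every
// other operand mapping it agrees.
std::map<int, Axes> commonAxesPerReductionFactor(
    const std::vector<FactorShardings>& operands,
    const std::vector<int>& reductionFactors) {
  std::map<int, Axes> common;
  for (int factor : reductionFactors) {
    std::optional<Axes> seen;
    for (const FactorShardings& operand : operands) {
      auto it = operand.find(factor);
      if (it == operand.end()) continue;
      if (seen) {
        // Counterpart of the SDY_CHECK above: operands must agree on the
        // sharding of a reduction factor when the result has unreduced axes.
        assert(*seen == it->second);
      } else {
        seen = it->second;
      }
    }
    if (seen) common[factor] = *seen;
  }
  return common;
}

int main() {
  // Hypothetical dot: factor 2 is the contracting (reduction) factor and is
  // sharded on axis "y" in both operands, so its common axes are {"y"}.
  std::vector<FactorShardings> operands = {
      {{0, Axes{"x"}}, {2, Axes{"y"}}},  // lhs
      {{1, Axes{}}, {2, Axes{"y"}}}};    // rhs
  std::map<int, Axes> common = commonAxesPerReductionFactor(operands, {2});
  assert(common.at(2) == Axes{"y"});
  return 0;
}
```
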
@@ -417,10 +453,13 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
     return {};
   }
 
+  // Checks if factors are sharded the same way across operands and results.
+  AxesPerFactor commonAxesPerFactor =
+      getCompatibleFactorShardings(shardingProjection, shardingRule);
+
+  // TODO(b/446833985): Return common axes per factor also when the sharding
+  // projection has overflow axes.
   if (onFullVersion) {
-    // Checks if factors are sharded the same way across operands and results.
-    AxesPerFactor commonAxesPerFactor =
-        getCompatibleFactorShardings(shardingProjection, shardingRule);
     // Find compatible shardings if it is not already compatible.
     if (commonAxesPerFactor.empty()) {
       commonAxesPerFactor =
@@ -439,49 +478,30 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
     insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
                            updateTensorShardings, rewriter, shardingRule,
                            symbolTable, mesh);
+  } else {
+    TypeSwitch<Operation*>(op)
+        .Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
+          processDot(dotOp, shardingProjection, outShardings, rewriter,
+                     symbolTable, shardingRule, mesh);
+        })
+        .Case<stablehlo::DotGeneralOp>(
+            [&](stablehlo::DotGeneralOp dotGeneralOp) {
+              processDot(dotGeneralOp, shardingProjection, outShardings,
+                         rewriter, symbolTable, shardingRule, mesh);
+            });
+
+    if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
+      return {};
+    }
 
-    return getReductionAxes(commonAxesPerFactor, shardingRule);
-  }
-
-  TypeSwitch<Operation*>(op)
-      .Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
-        processDot(dotOp, shardingProjection, outShardings, rewriter,
-                   symbolTable, shardingRule, mesh);
-      })
-      .Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
-        processDot(dotGeneralOp, shardingProjection, outShardings, rewriter,
-                   symbolTable, shardingRule, mesh);
-      });
-
-  if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
-    return {};
-  }
-
-  // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
-  // factors, and simplify the following logic.
-  SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
-  for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
-    // We only iterate operands since reduction factors are not in results.
-    bool seen = false;
-    SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
-    for (const TensorFactorShardings& tensorFactorSharding :
-         shardingProjection.getOperands()) {
-      if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
-              getFactorSharding(tensorFactorSharding, reductionFactor)) {
-        if (seen) {
-          SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
-              << "For the operation " << op
-              << ", the result has unreduced axes while the operand has "
-                 "incompatible sharding along reduction factors.";
-        } else {
-          axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
-          seen = true;
-        }
-      }
+    if (commonAxesPerFactor.empty()) {
+      // At this point, there are unreduced axes on results.
+      commonAxesPerFactor =
+          getCommonAxesPerReductionFactor(op, shardingProjection, shardingRule);
     }
-    axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
   }
-  return axesAlongAllReductionFactors;
+
+  return getReductionAxes(commonAxesPerFactor, shardingRule);
 }
 
 struct InsertExplicitReshardsPass
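
To show how the per-factor result is consumed at the end of `processOp`, here is a matching sketch of the flattening step. The real `getReductionAxes` takes the `AxesPerFactor` and the `OpShardingRuleAttr`; this stand-in only models the behavior suggested by the removed inline loop, which appended the common axes of each reduction factor into one list. Types and the function name are hypothetical, as in the previous sketch.

```cpp
#include <map>
#include <string>
#include <vector>

using Axes = std::vector<std::string>;

// Assumed behavior of getReductionAxes, based on the removed inline code:
// concatenate the common axes of every reduction factor into a single list,
// as axesAlongAllReductionFactors.append(...) used to do.
Axes reductionAxes(const std::map<int, Axes>& commonAxesPerFactor,
                   const std::vector<int>& reductionFactors) {
  Axes all;
  for (int factor : reductionFactors) {
    auto it = commonAxesPerFactor.find(factor);
    if (it == commonAxesPerFactor.end()) continue;
    all.insert(all.end(), it->second.begin(), it->second.end());
  }
  return all;
}
```
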
