Skip to content

Commit b3529ed

Browse files
Refactor to move logic related to allreduce out of process op method.
This prepares for moving some of that logic into the allreduce logic. Also move the check for overflow axes out of the method, so that it is guaranteed to return a non-empty AxesPerFactor in the full version. PiperOrigin-RevId: 812750842
1 parent b426498 commit b3529ed

File tree

1 file changed

+89
-67
lines changed

1 file changed

+89
-67
lines changed

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 89 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,44 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
391391
return false;
392392
}
393393

394+
// Assumes the results have unreduced axes.
395+
//
396+
// Returns `AxesPerFactor` in which only reduction factors are populated with
397+
// common axes.
398+
//
399+
// Hard fails if some reduction factors do not have compatible shardings.
400+
AxesPerFactor getCommonAxesPerReductionFactorOrDie(
401+
Operation* op, const ShardingProjection& shardingProjection,
402+
OpShardingRuleAttr shardingRule) {
403+
// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
404+
// factors, and simplify the following logic.
405+
AxesPerFactor commonAxesPerFactor =
406+
AxesPerFactor(shardingRule.getNumFactors());
407+
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
408+
// We only iterate operands since reduction factors are not in results.
409+
bool seen = false;
410+
SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
411+
for (const TensorFactorShardings& tensorFactorSharding :
412+
shardingProjection.getOperands()) {
413+
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
414+
getFactorSharding(tensorFactorSharding, reductionFactor)) {
415+
SmallVector<AxisRefAttr> factorShardingVector =
416+
llvm::to_vector(*factorSharding);
417+
if (seen) {
418+
SDY_CHECK(factorShardingVector == commonAxes)
419+
<< "For the operation " << op
420+
<< ", the result has unreduced axes while the operand has "
421+
"incompatible sharding along reduction factors.";
422+
} else {
423+
commonAxes = factorShardingVector;
424+
seen = true;
425+
}
426+
}
427+
}
428+
}
429+
return commonAxesPerFactor;
430+
}
431+
394432
// Inserts explicit reshards on the operands and results of `op` such that the
395433
// sharding of `op` is compatible with its sharding rule.
396434
//
@@ -400,31 +438,25 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
400438
// - All op results have the same unreduced axes.
401439
// - If the op has no results, none of the operands has unreduced axes.
402440
// - Operand and result meshes are the same ignoring device id order.
441+
// - There are no overflow axes.
403442
//
404443
// Returns the union of axes along all the reduction factors which may not be
405444
// canonicalized.
406-
SmallVector<AxisRefAttr> processOp(Operation* op,
407-
ArrayRef<TensorShardingAttr> inShardings,
408-
ArrayRef<TensorShardingAttr> outShardings,
409-
IRRewriter& rewriter,
410-
const SymbolTable& symbolTable,
411-
OpShardingRuleAttr shardingRule,
412-
const Mesh& mesh, const bool onFullVersion) {
413-
ShardingProjection shardingProjection = ShardingProjection::build(
414-
inShardings, outShardings, shardingRule, mesh.attr(),
415-
/*closedIfMissing=*/true);
416-
417-
// Return without inserting reshards if any factor sharding has overflow
418-
// axes. This case is not handled yet.
419-
// TODO(enver): Handle the case when factor shardings have overflow axes.
420-
if (hasOverflowAxes(shardingProjection)) {
421-
return {};
422-
}
423-
445+
//
446+
// Guaranteed to return a non-empty `AxesPerFactor` if `onFullVersion` is true.
447+
AxesPerFactor processOp(Operation* op, ShardingProjection& shardingProjection,
448+
ArrayRef<TensorShardingAttr> inShardings,
449+
ArrayRef<TensorShardingAttr> outShardings,
450+
IRRewriter& rewriter, const SymbolTable& symbolTable,
451+
OpShardingRuleAttr shardingRule, const Mesh& mesh,
452+
const bool onFullVersion) {
453+
// Checks if factors are sharded the same way across operands and results.
454+
AxesPerFactor commonAxesPerFactor =
455+
getCompatibleFactorShardings(shardingProjection, shardingRule);
456+
457+
// TODO(b/446833985): Return common axes factors also when the sharding
458+
// projection has overflow axes.
424459
if (onFullVersion) {
425-
// Checks if factors are sharded the same way across operands and results.
426-
AxesPerFactor commonAxesPerFactor =
427-
getCompatibleFactorShardings(shardingProjection, shardingRule);
428460
// Find compatible shardings if it is not already compatible.
429461
if (commonAxesPerFactor.empty()) {
430462
commonAxesPerFactor =
@@ -443,49 +475,19 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
443475
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
444476
updateTensorShardings, rewriter, shardingRule,
445477
symbolTable, mesh);
446-
447-
return getReductionAxes(commonAxesPerFactor, shardingRule);
478+
} else {
479+
TypeSwitch<Operation*>(op)
480+
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
481+
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
482+
shardingRule, mesh);
483+
})
484+
.Case<stablehlo::DotGeneralOp>(
485+
[&](stablehlo::DotGeneralOp dotGeneralOp) {
486+
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
487+
symbolTable, shardingRule, mesh);
488+
});
448489
}
449-
450-
TypeSwitch<Operation*>(op)
451-
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452-
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
453-
shardingRule, mesh);
454-
})
455-
.Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456-
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
457-
symbolTable, shardingRule, mesh);
458-
});
459-
460-
if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
461-
return {};
462-
}
463-
464-
// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
465-
// factors, and simplify the following logic.
466-
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
467-
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
468-
// We only iterate operands since reduction factors are not in results.
469-
bool seen = false;
470-
SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
471-
for (const TensorFactorShardings& tensorFactorSharding :
472-
shardingProjection.getOperands()) {
473-
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
474-
getFactorSharding(tensorFactorSharding, reductionFactor)) {
475-
if (seen) {
476-
SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
477-
<< "For the operation " << op
478-
<< ", the result has unreduced axes while the operand has "
479-
"incompatible sharding along reduction factors.";
480-
} else {
481-
axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
482-
seen = true;
483-
}
484-
}
485-
}
486-
axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
487-
}
488-
return axesAlongAllReductionFactors;
490+
return commonAxesPerFactor;
489491
}
490492

491493
struct InsertExplicitReshardsPass
@@ -544,11 +546,31 @@ struct InsertExplicitReshardsPass
544546
return;
545547
}
546548

547-
SmallVector<AxisRefAttr> reductionAxes =
548-
processOp(op, inShardings, outShardings, rewriter, symbolTable,
549-
shardingRule, *mesh, onFullVersion);
549+
ShardingProjection shardingProjection = ShardingProjection::build(
550+
inShardings, outShardings, shardingRule, mesh->attr(),
551+
/*closedIfMissing=*/true);
552+
// Return without inserting reshards if any factor sharding has overflow
553+
// axes. This case is not handled yet.
554+
// TODO(enver): Handle the case when factor shardings have overflow axes.
555+
if (hasOverflowAxes(shardingProjection)) {
556+
return;
557+
}
558+
AxesPerFactor commonAxesPerFactor =
559+
processOp(op, shardingProjection, inShardings, outShardings, rewriter,
560+
symbolTable, shardingRule, *mesh, onFullVersion);
561+
if (outShardings.empty() ||
562+
(!onFullVersion && getUnreducedAxes(outShardings[0]).empty())) {
563+
return;
564+
}
565+
if (commonAxesPerFactor.empty()) {
566+
// At this point, there are unreduced axes on results.
567+
commonAxesPerFactor = getCommonAxesPerReductionFactorOrDie(
568+
op, shardingProjection, shardingRule);
569+
}
550570
// TODO(b/440055868): Insert a reshard from unreduced to replicated axes.
551-
insertAllReducesForReductionFactors(op, reductionAxes, *mesh, rewriter);
571+
insertAllReducesForReductionFactors(
572+
op, getReductionAxes(commonAxesPerFactor, shardingRule), *mesh,
573+
rewriter);
552574

553575
// TODO(enver): Remove sharding rules from ops.
554576
});

0 commit comments

Comments
 (0)