Commit d84fe5b

Move logic related to allreduce out of the processOp method.

This is a refactoring that prepares for moving some of this logic into the allreduce logic. It also moves the check for overflow axes out, which guarantees that a non-empty AxesPerFactor is returned in the full version.

PiperOrigin-RevId: 812750842
1 parent d5b0f10 commit d84fe5b
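
For orientation, below is a condensed sketch of the new caller-side flow this change introduces, paraphrased from the hunks that follow. All types and helpers (ShardingProjection, hasOverflowAxes, processOp, getCommonAxesPerReductionFactor, getReductionAxes, insertAllReducesForReductionFactors) belong to insert_explicit_reshards.cc, so this is illustrative rather than a standalone compilable unit, and it omits the early returns for ops without results or without unreduced axes.

  // Sketch only: condensed from the diff below; not a standalone translation
  // unit. The caller now owns the ShardingProjection and the overflow-axes
  // check, and processOp returns an AxesPerFactor instead of reduction axes.
  ShardingProjection shardingProjection = ShardingProjection::build(
      inShardings, outShardings, shardingRule, mesh->attr(),
      /*closedIfMissing=*/true);
  // Overflow axes are still unhandled, so bail out before inserting reshards.
  if (hasOverflowAxes(shardingProjection)) {
    return;
  }
  // Guaranteed non-empty when onFullVersion is true.
  AxesPerFactor commonAxesPerFactor =
      processOp(op, shardingProjection, inShardings, outShardings, rewriter,
                symbolTable, shardingRule, *mesh, onFullVersion);
  if (!onFullVersion && commonAxesPerFactor.empty()) {
    // Results have unreduced axes; recover common axes per reduction factor.
    commonAxesPerFactor =
        getCommonAxesPerReductionFactor(op, shardingProjection, shardingRule);
  }
  insertAllReducesForReductionFactors(
      op, getReductionAxes(commonAxesPerFactor, shardingRule), *mesh, rewriter);

Compare with the old flow, where processOp itself built the projection, checked for overflow axes, and computed the reduction axes inline.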

1 file changed: +92, -72 lines changed

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 92 additions & 72 deletions
@@ -158,18 +158,14 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
 // return %reshard : tensor<4x8xf32>
 // ```
 template <class OpTy>
-void processDot(OpTy op, ArrayRef<TensorShardingAttr> inShardings,
+void processDot(OpTy op, ShardingProjection& shardingProjection,
                 ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
                 const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
                 const Mesh& mesh) {
   if (outShardings.empty()) {
     // Result doesn't have a sharding.
     return;
   }
-  ShardingProjection shardingProjection =
-      ShardingProjection::build(inShardings, outShardings, shardingRule,
-                                mesh.attr(), /*closedIfMissing=*/true);
-
   const TensorFactorShardings& lhsSharding = shardingProjection.getOperand(0);
   const TensorFactorShardings& rhsSharding = shardingProjection.getOperand(1);
   TensorFactorShardings& resultSharding =
@@ -391,6 +387,42 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
   return false;
 }
 
+// Assume the results have unreduced axes.
+//
+// Returns `AxesPerFactor` with only its reduction factors are populated to have
+// common axes.
+//
+// Hard fails if some reduction factors do not have compatible shardings.
+AxesPerFactor getCommonAxesPerReductionFactor(
+    Operation* op, const ShardingProjection& shardingProjection,
+    OpShardingRuleAttr shardingRule) {
+  // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
+  // factors, and simplify the following logic.
+  AxesPerFactor commonAxesPerFactor =
+      AxesPerFactor(shardingRule.getNumFactors());
+  for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
+    // We only iterate operands since reduction factors are not in results.
+    bool seen = false;
+    SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
+    for (const TensorFactorShardings& tensorFactorSharding :
+         shardingProjection.getOperands()) {
+      if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
+              getFactorSharding(tensorFactorSharding, reductionFactor)) {
+        if (seen) {
+          SDY_CHECK(factorSharding->equals(commonAxes))
+              << "For the operation " << op
+              << ", the result has unreduced axes while the operand has "
+                 "incompatible sharding along reduction factors.";
+        } else {
+          commonAxes = llvm::to_vector(*factorSharding);
+          seen = true;
+        }
+      }
+    }
+  }
+  return commonAxesPerFactor;
+}
+
 // Inserts explicit reshards on the operands and results of `op` such that the
 // sharding of `op` is compatible with its sharding rule.
 //
@@ -400,31 +432,25 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
 // - All op results have the same unreduced axes.
 // - If the op has no results, none of the operands has unreduced axes.
 // - Operand and result meshes are the same ignoring device id order.
+// - There are no overflow axes.
 //
 // Returns the union of axes along all the reduction factors which may not be
 // canonicalized.
-SmallVector<AxisRefAttr> processOp(Operation* op,
-                                   ArrayRef<TensorShardingAttr> inShardings,
-                                   ArrayRef<TensorShardingAttr> outShardings,
-                                   IRRewriter& rewriter,
-                                   const SymbolTable& symbolTable,
-                                   OpShardingRuleAttr shardingRule,
-                                   const Mesh& mesh, const bool onFullVersion) {
-  ShardingProjection shardingProjection = ShardingProjection::build(
-      inShardings, outShardings, shardingRule, mesh.attr(),
-      /*closedIfMissing=*/true);
-
-  // Return without inserting reshards if any factor sharding has overflow
-  // axes. This case is not handled yet.
-  // TODO(enver): Handle the case when factor shardings have overflow axes.
-  if (hasOverflowAxes(shardingProjection)) {
-    return {};
-  }
-
+//
+// Guarantees to return non-empty `AxesPerFactor` if `onFullVersion` is true.
+AxesPerFactor processOp(Operation* op, ShardingProjection& shardingProjection,
+                        ArrayRef<TensorShardingAttr> inShardings,
+                        ArrayRef<TensorShardingAttr> outShardings,
+                        IRRewriter& rewriter, const SymbolTable& symbolTable,
+                        OpShardingRuleAttr shardingRule, const Mesh& mesh,
+                        const bool onFullVersion) {
+  // Checks if factors are sharded the same way across operands and results.
+  AxesPerFactor commonAxesPerFactor =
+      getCompatibleFactorShardings(shardingProjection, shardingRule);
+
+  // TODO(b/446833985): Return common axes per factor also when the sharding
+  // projection have overflow axes.
   if (onFullVersion) {
-    // Checks if factors are sharded the same way across operands and results.
-    AxesPerFactor commonAxesPerFactor =
-        getCompatibleFactorShardings(shardingProjection, shardingRule);
     // Find compatible shardings if it is not already compatible.
     if (commonAxesPerFactor.empty()) {
       commonAxesPerFactor =
@@ -443,49 +469,19 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
     insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
                            updateTensorShardings, rewriter, shardingRule,
                            symbolTable, mesh);
-
-    return getReductionAxes(commonAxesPerFactor, shardingRule);
-  }
-
-  TypeSwitch<Operation*>(op)
-      .Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
-        processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
-                   shardingRule, mesh);
-      })
-      .Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
-        processDot(dotGeneralOp, inShardings, outShardings, rewriter,
-                   symbolTable, shardingRule, mesh);
-      });
-
-  if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
-    return {};
+  } else {
+    TypeSwitch<Operation*>(op)
+        .Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
+          processDot(dotOp, shardingProjection, outShardings, rewriter,
+                     symbolTable, shardingRule, mesh);
+        })
+        .Case<stablehlo::DotGeneralOp>(
+            [&](stablehlo::DotGeneralOp dotGeneralOp) {
+              processDot(dotGeneralOp, shardingProjection, outShardings,
+                         rewriter, symbolTable, shardingRule, mesh);
+            });
   }
-
-  // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
-  // factors, and simplify the following logic.
-  SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
-  for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
-    // We only iterate operands since reduction factors are not in results.
-    bool seen = false;
-    SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
-    for (const TensorFactorShardings& tensorFactorSharding :
-         shardingProjection.getOperands()) {
-      if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
-              getFactorSharding(tensorFactorSharding, reductionFactor)) {
-        if (seen) {
-          SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
-              << "For the operation " << op
-              << ", the result has unreduced axes while the operand has "
-                 "incompatible sharding along reduction factors.";
-        } else {
-          axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
-          seen = true;
-        }
-      }
-    }
-    axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
-  }
-  return axesAlongAllReductionFactors;
+  return commonAxesPerFactor;
 }
 
 struct InsertExplicitReshardsPass
@@ -544,11 +540,35 @@ struct InsertExplicitReshardsPass
        return;
      }
 
-     SmallVector<AxisRefAttr> reductionAxes =
-         processOp(op, inShardings, outShardings, rewriter, symbolTable,
-                   shardingRule, *mesh, onFullVersion);
+     ShardingProjection shardingProjection = ShardingProjection::build(
+         inShardings, outShardings, shardingRule, mesh->attr(),
+         /*closedIfMissing=*/true);
+     // Return without inserting reshards if any factor sharding has overflow
+     // axes. This case is not handled yet.
+     // TODO(enver): Handle the case when factor shardings have overflow axes.
+     if (hasOverflowAxes(shardingProjection)) {
+       return;
+     }
+     AxesPerFactor commonAxesPerFactor =
+         processOp(op, shardingProjection, inShardings, outShardings, rewriter,
+                   symbolTable, shardingRule, *mesh, onFullVersion);
+     if (op->getResults().empty()) {
+       return;
+     }
+     if (!onFullVersion) {
+       if (getUnreducedAxes(op->getResult(0)).empty()) {
+         return;
+       }
+       if (commonAxesPerFactor.empty()) {
+         // At this point, there are unreduced axes on results.
+         commonAxesPerFactor = getCommonAxesPerReductionFactor(
+             op, shardingProjection, shardingRule);
+       }
+     }
      // TODO(b/440055868): Insert a reshard from unreduced to replicated axes.
-     insertAllReducesForReductionFactors(op, reductionAxes, *mesh, rewriter);
+     insertAllReducesForReductionFactors(
+         op, getReductionAxes(commonAxesPerFactor, shardingRule), *mesh,
+         rewriter);
 
      // TODO(enver): Remove sharding rules from ops.
    });
