106 changes: 63 additions & 43 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -387,6 +387,42 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
return false;
}

// Assumes the results have unreduced axes.
//
// Returns an `AxesPerFactor` in which only the reduction factors are
// populated with their common axes.
//
// Hard-fails if any reduction factor does not have compatible shardings.
AxesPerFactor getCommonAxesPerReductionFactor(
Operation* op, const ShardingProjection& shardingProjection,
OpShardingRuleAttr shardingRule) {
// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
// factors, and simplify the following logic.
AxesPerFactor commonAxesPerFactor =
AxesPerFactor(shardingRule.getNumFactors());
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
// We iterate only over operands, since reduction factors do not appear in
// results.
bool seen = false;
SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
for (const TensorFactorShardings& tensorFactorSharding :
shardingProjection.getOperands()) {
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
getFactorSharding(tensorFactorSharding, reductionFactor)) {
if (seen) {
SDY_CHECK(factorSharding->equals(commonAxes))
<< "For the operation " << op
<< ", the result has unreduced axes while the operand has "
"incompatible sharding along reduction factors.";
} else {
commonAxes = llvm::to_vector(*factorSharding);
seen = true;
}
}
}
}
return commonAxesPerFactor;
}
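// Illustrative example (hypothetical shardings, not from the source): consider
// a dot_general whose contracting dimension corresponds to reduction factor 2,
// where both operands shard that factor on mesh axis "x". The helper returns
// an AxesPerFactor whose entry 2 is {"x"} and whose other entries are empty.
// Had one operand sharded the factor on "y" instead, the SDY_CHECK above would
// fire, since unreduced results require the operands to agree along every
// reduction factor.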

// Inserts explicit reshards on the operands and results of `op` such that the
// sharding of `op` is compatible with its sharding rule.
//
@@ -417,10 +453,13 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
return {};
}

// Checks if factors are sharded the same way across operands and results.
AxesPerFactor commonAxesPerFactor =
getCompatibleFactorShardings(shardingProjection, shardingRule);

// TODO(b/446833985): Return common axes per factor also when the sharding
// projection has overflow axes.
if (onFullVersion) {
// Checks if factors are sharded the same way across operands and results.
AxesPerFactor commonAxesPerFactor =
getCompatibleFactorShardings(shardingProjection, shardingRule);
// Find compatible shardings if the current shardings are not already
// compatible.
if (commonAxesPerFactor.empty()) {
commonAxesPerFactor =
@@ -439,49 +478,30 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
insertExplicitReshards(op, inShardings, outShardings, shardingProjection,
updateTensorShardings, rewriter, shardingRule,
symbolTable, mesh);
} else {
TypeSwitch<Operation*>(op)
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
processDot(dotOp, shardingProjection, outShardings, rewriter,
symbolTable, shardingRule, mesh);
})
.Case<stablehlo::DotGeneralOp>(
[&](stablehlo::DotGeneralOp dotGeneralOp) {
processDot(dotGeneralOp, shardingProjection, outShardings,
rewriter, symbolTable, shardingRule, mesh);
});

if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
return {};
}

return getReductionAxes(commonAxesPerFactor, shardingRule);
}

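// Not on the full version: only dot ops get dedicated reshard handling here.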
TypeSwitch<Operation*>(op)
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
processDot(dotOp, shardingProjection, outShardings, rewriter,
symbolTable, shardingRule, mesh);
})
.Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
processDot(dotGeneralOp, shardingProjection, outShardings, rewriter,
symbolTable, shardingRule, mesh);
});

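// No unreduced axes on the result means there are no reduction axes to report
// to the caller.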
if (outShardings.empty() || getUnreducedAxes(outShardings[0]).empty()) {
return {};
}

// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
// factors, and simplify the following logic.
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
// We only iterate operands since reduction factors are not in results.
bool seen = false;
SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
for (const TensorFactorShardings& tensorFactorSharding :
shardingProjection.getOperands()) {
if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
getFactorSharding(tensorFactorSharding, reductionFactor)) {
if (seen) {
SDY_CHECK(axesAlongCurrentReductionFactor == *factorSharding)
<< "For the operation " << op
<< ", the result has unreduced axes while the operand has "
"incompatible sharding along reduction factors.";
} else {
axesAlongCurrentReductionFactor = llvm::to_vector(*factorSharding);
seen = true;
}
}
if (commonAxesPerFactor.empty()) {
// At this point, there are unreduced axes on results.
commonAxesPerFactor =
getCommonAxesPerReductionFactor(op, shardingProjection, shardingRule);
}
axesAlongAllReductionFactors.append(axesAlongCurrentReductionFactor);
}
return axesAlongAllReductionFactors;

return getReductionAxes(commonAxesPerFactor, shardingRule);
}
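// For context, a minimal sketch of what the `getReductionAxes` helper (defined
// elsewhere in this file, not shown in this diff) is assumed to do, inferred
// from the inline loop it replaces: flatten the common axes of every reduction
// factor into one list for the caller to all-reduce over. A sketch under that
// assumption, not the actual implementation:
//
//   SmallVector<AxisRefAttr> getReductionAxes(
//       const AxesPerFactor& commonAxesPerFactor,
//       OpShardingRuleAttr shardingRule) {
//     SmallVector<AxisRefAttr> reductionAxes;
//     for (int64_t factor : shardingRule.getReductionFactors()) {
//       reductionAxes.append(commonAxesPerFactor[factor].begin(),
//                            commonAxesPerFactor[factor].end());
//     }
//     return reductionAxes;
//   }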

struct InsertExplicitReshardsPass