Skip to content

Commit b426498

Browse files
Check for overflow axes regardless of being default/minimal or full version.
Also refactor to run the compatibility check on the default/minimal version as well. This prepares to unify/simplify the logic to find reduction axes. PiperOrigin-RevId: 812733006
1 parent 30983ea commit b426498

File tree

3 files changed

+36
-30
lines changed

3 files changed

+36
-30
lines changed

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,8 @@ bool hasShardedPermutationFactors(
7171
!factorSharding.axisRefs.empty();
7272
});
7373
}
74+
} // namespace
7475

75-
// Checks if factor sharding is compatible, that is, it satisfies:
76-
// 1. Factors are sharded the same way across operands and results.
77-
// 2. Factors that need replication are unsharded.
78-
//
79-
// Returns the common axes per factor if the factor sharding is compatible.
80-
// Otherwise, returns empty AxesPerFactor.
81-
//
82-
// Assumes factor shardings do not have overflow axes.
8376
// TODO(enver): Handle the case when some factor shardings have overflow axes.
8477
AxesPerFactor getCompatibleFactorShardings(
8578
const ShardingProjection& shardingProjection,
@@ -112,6 +105,8 @@ AxesPerFactor getCompatibleFactorShardings(
112105
return commonAxesPerFactor;
113106
}
114107

108+
namespace {
109+
115110
void insertExplicitReshardsOnOperand(
116111
Operation* op, const int64_t operandIndex,
117112
const ShardingProjection& shardingProjection,
@@ -758,22 +753,12 @@ void distributeAxisRefsToBatchingFactors(
758753
}
759754
} // namespace
760755

761-
// Assumes there are no overflow axes.
762-
//
763-
// Guarantees to return a non-empty AxesPerFactor.
764756
AxesPerFactor findCommonAxes(ArrayRef<TensorShardingAttr> inShardings,
765757
ArrayRef<TensorShardingAttr> outShardings,
766758
const ShardingProjection& shardingProjection,
767759
OpShardingRuleAttr shardingRule,
768760
ArrayRef<int64_t> tensorSizes,
769761
const SymbolTable& symbolTable, const Mesh& mesh) {
770-
// Checks if factors are sharded the same way across operands and results.
771-
if (AxesPerFactor commonAxesPerFactor =
772-
getCompatibleFactorShardings(shardingProjection, shardingRule);
773-
!commonAxesPerFactor.empty()) {
774-
return commonAxesPerFactor;
775-
}
776-
777762
// Handle the special case of unary operations without factors that need
778763
// replication. Reshard only one of the tensors.
779764
if (shardingRule.getNonScalarTensorIndices().size() == 2 &&

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,18 @@ SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
8989
// Returns true iff any tensor factor sharding has non-empty overflow axes.
9090
bool hasOverflowAxes(const ShardingProjection& shardingProjection);
9191

92+
// Checks if factor sharding is compatible, that is, it satisfies:
93+
// 1. Factors are sharded the same way across operands and results.
94+
// 2. Factors that need replication are unsharded.
95+
//
96+
// Returns the common axes per factor if the factor sharding is compatible.
97+
// Otherwise, returns empty AxesPerFactor.
98+
//
99+
// Assumes factor shardings do not have overflow axes.
100+
AxesPerFactor getCompatibleFactorShardings(
101+
const ShardingProjection& shardingProjection,
102+
OpShardingRuleAttr shardingRule);
103+
92104
// Insert explicit reshards for operands and results that change by
93105
// the given `shardingProjection` for a given `op`. The reshards are inserted
94106
// only to make the given operation compatible.
@@ -154,6 +166,7 @@ void insertAllReducesForReductionFactors(Operation* op,
154166
// - If the op has no results, none of the operands has unreduced axes.
155167
// - Operand and result meshes are the same ignoring device id order.
156168
// - There are no overflow axes.
169+
// - Some shardings are not compatible.
157170
//
158171
// Guarantees to return a non-empty AxesPerFactor.
159172
AxesPerFactor findCommonAxes(ArrayRef<TensorShardingAttr> inShardings,

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,8 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
401401
// - If the op has no results, none of the operands has unreduced axes.
402402
// - Operand and result meshes are the same ignoring device id order.
403403
//
404-
// Returns the union of common reduction axes which may not be canonicalized.
404+
// Returns the union of axes along all the reduction factors which may not be
405+
// canonicalized.
405406
SmallVector<AxisRefAttr> processOp(Operation* op,
406407
ArrayRef<TensorShardingAttr> inShardings,
407408
ArrayRef<TensorShardingAttr> outShardings,
@@ -413,17 +414,24 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
413414
inShardings, outShardings, shardingRule, mesh.attr(),
414415
/*closedIfMissing=*/true);
415416

417+
// Return without inserting reshards if any factor sharding has overflow
418+
// axes. This case is not handled yet.
419+
// TODO(enver): Handle the case when factor shardings have overflow axes.
420+
if (hasOverflowAxes(shardingProjection)) {
421+
return {};
422+
}
423+
416424
if (onFullVersion) {
417-
// Return without inserting reshards if any factor sharding has overflow
418-
// axes. This case is not handled yet.
419-
// TODO(b/446833985): Handle the case when factor shardings have overflow
420-
// axes.
421-
if (hasOverflowAxes(shardingProjection)) {
422-
return {};
423-
}
425+
// Checks if factors are sharded the same way across operands and results.
424426
AxesPerFactor commonAxesPerFactor =
425-
findCommonAxes(inShardings, outShardings, shardingProjection,
426-
shardingRule, getTensorSizes(op), symbolTable, mesh);
427+
getCompatibleFactorShardings(shardingProjection, shardingRule);
428+
// Find compatible shardings if they are not already compatible.
429+
if (commonAxesPerFactor.empty()) {
430+
commonAxesPerFactor =
431+
findCommonAxes(inShardings, outShardings, shardingProjection,
432+
shardingRule, getTensorSizes(op), symbolTable, mesh);
433+
}
434+
427435
UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
428436
shardingRule.getNumResults());
429437
for (const auto& [index, axes] : llvm::enumerate(commonAxesPerFactor)) {
@@ -453,8 +461,8 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
453461
return {};
454462
}
455463

456-
// TODO(enver): Factor out finding common axes per factor. Share logic with
457-
// getCompatibleFactorShardings.
464+
// TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
465+
// factors, and simplify the following logic.
458466
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
459467
for (int64_t reductionFactor : shardingRule.getReductionFactors()) {
460468
// We only iterate operands since reduction factors are not in results.

0 commit comments

Comments
 (0)