Skip to content

Commit 3b63be5

Browse files
ZixuanJiangcopybara-github
authored andcommitted
Move getPrefixWithoutOverlap from basic_factor_propagation.cc to ir/utils.h.
PiperOrigin-RevId: 814766867
1 parent 7870392 commit 3b63be5

File tree

7 files changed

+27
-33
lines changed

7 files changed

+27
-33
lines changed

shardy/dialect/sdy/ir/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ cc_library(
158158
":op_interface_inc",
159159
":ops_inc",
160160
"//shardy/common:logging",
161+
"//shardy/dialect/sdy/transforms/common:macros",
161162
"@llvm-project//llvm:Support",
162163
"@llvm-project//mlir:BytecodeOpInterface",
163164
"@llvm-project//mlir:FuncDialect",

shardy/dialect/sdy/ir/axis_list_ref.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,8 @@ class AxisListRef {
170170
// TODO(enver): Move this method to utilities.
171171
// TODO(enver): Instead make this a method of AxisRefAttr, after moving
172172
// AxesWithTail to a general data structure in Shardy.
173-
// TODO(enver): Reuse getPrefixOfInputWithout method on
174-
// shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc,
175-
// instead, after an iterater is added.
173+
// TODO(enver): Reuse getPrefixWithoutOverlap method in
174+
// shardy/dialect/sdy/ir/utils.h, after an iterator is added.
176175
std::optional<AxisRefAttr> getPrefixOfInputWithoutOverlap(
177176
AxisRefAttr axisRef) const;
178177

shardy/dialect/sdy/ir/utils.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ limitations under the License.
4545
#include "mlir/Support/LLVM.h"
4646
#include "shardy/dialect/sdy/ir/constants.h"
4747
#include "shardy/dialect/sdy/ir/dialect.h"
48+
#include "shardy/dialect/sdy/transforms/common/macros.h"
4849

4950
namespace mlir {
5051
namespace sdy {
@@ -631,5 +632,16 @@ bool isUsedBy(Value value, Operation* user) {
631632
});
632633
}
633634

635+
// TODO(enver): Use it in AxisListRef methods.
636+
std::optional<AxisRefAttr> getPrefixWithoutOverlap(
637+
AxisRefAttr axisRef, ArrayRef<AxisRefAttr> otherAxisRefs) {
638+
AxisRefAttr result = axisRef;
639+
for (AxisRefAttr otherAxisRef : otherAxisRefs) {
640+
SDY_ASSIGN_OR_RETURN_IF_NULLOPT(
641+
result, result.getPrefixWithoutOverlap(otherAxisRef));
642+
}
643+
return result;
644+
}
645+
634646
} // namespace sdy
635647
} // namespace mlir

shardy/dialect/sdy/ir/utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,11 @@ class AddAxisOrMergeInserter {
562562
// Returns true if `value` is used by `user`.
563563
bool isUsedBy(Value value, Operation* user);
564564

565+
// Returns the largest prefix of `axisRef` that does not overlap with any axes
566+
// in `otherAxisRefs`.
567+
std::optional<AxisRefAttr> getPrefixWithoutOverlap(
568+
AxisRefAttr axisRef, ArrayRef<AxisRefAttr> otherAxisRefs);
569+
565570
} // namespace sdy
566571
} // namespace mlir
567572

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -671,22 +671,19 @@ int64_t findTensorIndexToPreferOnUnaryOperation(
671671
//
672672
// Guarantees to return a non-empty AxesPerFactor.
673673
AxesPerFactor findCommonAxesOnUnaryOperation(
674-
ArrayRef<TensorShardingAttr> inShardings,
675-
ArrayRef<TensorShardingAttr> outShardings,
676674
const ShardingProjection& shardingProjection,
677675
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
678-
const SymbolTable& symbolTable, const Mesh& mesh) {
676+
const Mesh& mesh) {
679677
int64_t tensorIndexToPrefer = findTensorIndexToPreferOnUnaryOperation(
680678
shardingProjection, shardingRule, tensorSizes, mesh);
681679

682680
// Set factor shardings to make sure factors that do not appear in the
683681
// preferred tensor are sharded on the other tensor.
684682
AxesPerFactor factorAxisRefs(shardingRule.getNumFactors());
685683
// TODO(enver): Add and use forEachFactorSharding helper method.
686-
for (const auto& [tensorIndex, tensorFactorSharding] :
687-
llvm::enumerate(llvm::concat<const TensorFactorShardings>(
688-
shardingProjection.getOperands(),
689-
shardingProjection.getResults()))) {
684+
for (const TensorFactorShardings& tensorFactorSharding :
685+
llvm::concat<const TensorFactorShardings>(
686+
shardingProjection.getOperands(), shardingProjection.getResults())) {
690687
for (const auto& [factorIndex, factorSharding] :
691688
tensorFactorSharding.factorIndexToSharding) {
692689
if (!factorSharding.axisRefs.empty()) {
@@ -764,9 +761,8 @@ AxesPerFactor findCommonAxes(ArrayRef<TensorShardingAttr> inShardings,
764761
if (shardingRule.getNonScalarTensorIndices().size() == 2 &&
765762
shardingRule.getNeedReplicationFactors().empty() &&
766763
!shardingRule.hasDimensionsWithMultipleFactors()) {
767-
return findCommonAxesOnUnaryOperation(inShardings, outShardings,
768-
shardingProjection, shardingRule,
769-
tensorSizes, symbolTable, mesh);
764+
return findCommonAxesOnUnaryOperation(shardingProjection, shardingRule,
765+
tensorSizes, mesh);
770766
}
771767

772768
AxesPerFactor factorCommonAxes = findCommonAxesUsingMajorityVoteHeuristic(

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ ArrayRef<AxisRefAttr> getUnreducedAxes(Value value);
8282
SmallVector<int64_t> getTensorSizes(Operation* op);
8383

8484
// Returns reduction axes that are the union of all axes on reduction factors.
85-
// The result axes are not necessarilly canonicalized.
85+
// The result axes are not necessarily canonicalized.
8686
SmallVector<AxisRefAttr> getReductionAxes(const AxesPerFactor& axesPerFactor,
8787
OpShardingRuleAttr shardingRule);
8888

shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ limitations under the License.
2727

2828
#include "llvm/ADT/STLExtras.h"
2929
#include "llvm/ADT/SmallVector.h"
30-
#include "llvm/Support/FormatVariadic.h"
31-
#include "llvm/Support/Threading.h"
3230
#include "mlir/IR/Diagnostics.h"
3331
#include "mlir/IR/Value.h"
3432
#include "mlir/Support/LLVM.h"
@@ -41,23 +39,6 @@ limitations under the License.
4139
namespace mlir {
4240
namespace sdy {
4341

44-
namespace {
45-
46-
// Returns the largest prefix of `axisRef` that does not overlap with any axes
47-
// in `otherAxisRefs`.
48-
// TODO(enver): Move to ir/utils and use in AxisListRef methods.
49-
std::optional<AxisRefAttr> getPrefixWithoutOverlap(
50-
AxisRefAttr axisRef, ArrayRef<AxisRefAttr> otherAxisRefs) {
51-
AxisRefAttr result = axisRef;
52-
for (AxisRefAttr otherAxisRef : otherAxisRefs) {
53-
SDY_ASSIGN_OR_RETURN_IF_NULLOPT(
54-
result, result.getPrefixWithoutOverlap(otherAxisRef));
55-
}
56-
return result;
57-
}
58-
59-
} // namespace
60-
6142
std::optional<AxisRefAttr>
6243
BasicFactorPropagation::compatiblePrefixNoConflictsAcrossFactors(
6344
AxisRefAttr axisRef, const FactorIndexToSharding& factorIndexToSharding,

0 commit comments

Comments
 (0)