Skip to content

Commit 7870392

Browse files
Refactor to use already built sharding projection on process dot.
PiperOrigin-RevId: 814658702
1 parent c63d0c5 commit 7870392

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,18 +158,14 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
158158
// return %reshard : tensor<4x8xf32>
159159
// ```
160160
template <class OpTy>
161-
void processDot(OpTy op, ArrayRef<TensorShardingAttr> inShardings,
161+
void processDot(OpTy op, ShardingProjection& shardingProjection,
162162
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
163163
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
164164
const Mesh& mesh) {
165165
if (outShardings.empty()) {
166166
// Result doesn't have a sharding.
167167
return;
168168
}
169-
ShardingProjection shardingProjection =
170-
ShardingProjection::build(inShardings, outShardings, shardingRule,
171-
mesh.attr(), /*closedIfMissing=*/true);
172-
173169
const TensorFactorShardings& lhsSharding = shardingProjection.getOperand(0);
174170
const TensorFactorShardings& rhsSharding = shardingProjection.getOperand(1);
175171
TensorFactorShardings& resultSharding =
@@ -449,11 +445,11 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
449445

450446
TypeSwitch<Operation*>(op)
451447
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452-
processDot(dotOp, inShardings, outShardings, rewriter, symbolTable,
453-
shardingRule, mesh);
448+
processDot(dotOp, shardingProjection, outShardings, rewriter,
449+
symbolTable, shardingRule, mesh);
454450
})
455451
.Case<stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456-
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
452+
processDot(dotGeneralOp, shardingProjection, outShardings, rewriter,
457453
symbolTable, shardingRule, mesh);
458454
});
459455

0 commit comments

Comments
 (0)