@@ -158,18 +158,14 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
158158// return %reshard : tensor<4x8xf32>
159159// ```
160160template <class OpTy >
161- void processDot (OpTy op, ArrayRef<TensorShardingAttr> inShardings ,
161+ void processDot (OpTy op, ShardingProjection& shardingProjection ,
162162 ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
163163 const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
164164 const Mesh& mesh) {
165165 if (outShardings.empty ()) {
166166 // Result doesn't have a sharding.
167167 return ;
168168 }
169- ShardingProjection shardingProjection =
170- ShardingProjection::build (inShardings, outShardings, shardingRule,
171- mesh.attr (), /* closedIfMissing=*/ true );
172-
173169 const TensorFactorShardings& lhsSharding = shardingProjection.getOperand (0 );
174170 const TensorFactorShardings& rhsSharding = shardingProjection.getOperand (1 );
175171 TensorFactorShardings& resultSharding =
@@ -449,11 +445,11 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
449445
450446 TypeSwitch<Operation*>(op)
451447 .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452- processDot (dotOp, inShardings , outShardings, rewriter, symbolTable ,
453- shardingRule, mesh);
448+ processDot (dotOp, shardingProjection , outShardings, rewriter,
449+ symbolTable, shardingRule, mesh);
454450 })
455451 .Case <stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456- processDot (dotGeneralOp, inShardings , outShardings, rewriter,
452+ processDot (dotGeneralOp, shardingProjection , outShardings, rewriter,
457453 symbolTable, shardingRule, mesh);
458454 });
459455
0 commit comments