@@ -401,7 +401,8 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
401
401
// - If the op has no results, none of the operands has unreduced axes.
402
402
// - Operand and result meshes are the same ignoring device id order.
403
403
//
404
- // Returns the union of common reduction axes which may not be canonicalized.
404
+ // Returns the union of axes along all the reduction factors which may not be
405
+ // canonicalized.
405
406
SmallVector<AxisRefAttr> processOp (Operation* op,
406
407
ArrayRef<TensorShardingAttr> inShardings,
407
408
ArrayRef<TensorShardingAttr> outShardings,
@@ -413,17 +414,24 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
413
414
inShardings, outShardings, shardingRule, mesh.attr (),
414
415
/* closedIfMissing=*/ true );
415
416
417
+ // Return without inserting reshards if any factor sharding has overflow
418
+ // axes. This case is not handled yet.
419
+ // TODO(enver): Handle the case when factor shardings have overflow axes.
420
+ if (hasOverflowAxes (shardingProjection)) {
421
+ return {};
422
+ }
423
+
416
424
if (onFullVersion) {
417
- // Return without inserting reshards if any factor sharding has overflow
418
- // axes. This case is not handled yet.
419
- // TODO(b/446833985): Handle the case when factor shardings have overflow
420
- // axes.
421
- if (hasOverflowAxes (shardingProjection)) {
422
- return {};
423
- }
425
+ // Checks if factors are sharded the same way across operands and results.
424
426
AxesPerFactor commonAxesPerFactor =
425
- findCommonAxes (inShardings, outShardings, shardingProjection,
426
- shardingRule, getTensorSizes (op), symbolTable, mesh);
427
+ getCompatibleFactorShardings (shardingProjection, shardingRule);
428
+ // Find compatible shardings if they are not already compatible.
429
+ if (commonAxesPerFactor.empty ()) {
430
+ commonAxesPerFactor =
431
+ findCommonAxes (inShardings, outShardings, shardingProjection,
432
+ shardingRule, getTensorSizes (op), symbolTable, mesh);
433
+ }
434
+
427
435
UpdateTensorShardings updateTensorShardings (shardingRule.getNumOperands (),
428
436
shardingRule.getNumResults ());
429
437
for (const auto & [index, axes] : llvm::enumerate (commonAxesPerFactor)) {
@@ -453,8 +461,8 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
453
461
return {};
454
462
}
455
463
456
- // TODO(enver): Factor out finding common axes per factor. Share logic with
457
- // getCompatibleFactorShardings .
464
+ // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
465
+ // factors, and simplify the following logic.
458
466
SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
459
467
for (int64_t reductionFactor : shardingRule.getReductionFactors ()) {
460
468
// We only iterate operands since reduction factors are not in results.
0 commit comments