@@ -391,6 +391,44 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
391
391
return false ;
392
392
}
393
393
394
+ // Assumes the results have unreduced axes.
395
+ //
396
+ // Returns an `AxesPerFactor` in which only the reduction factors are populated,
397
+ // each with the axes shared by all operands that have a sharding for that factor.
398
+ //
399
+ // Hard fails (SDY_CHECK) if operands shard the same reduction factor differently.
400
+ AxesPerFactor getCommonAxesPerReductionFactorOrDie (
401
+ Operation* op, const ShardingProjection& shardingProjection,
402
+ OpShardingRuleAttr shardingRule) {
403
+ // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
404
+ // factors, and simplify the following logic.
405
+ AxesPerFactor commonAxesPerFactor =
406
+ AxesPerFactor (shardingRule.getNumFactors ());
407
+ for (int64_t reductionFactor : shardingRule.getReductionFactors ()) {
408
+ // Only operands are iterated since reduction factors are not in results; the first operand sharding found is the reference for the rest.
409
+ bool seen = false ;
410
+ SmallVector<AxisRefAttr>& commonAxes = commonAxesPerFactor[reductionFactor];
411
+ for (const TensorFactorShardings& tensorFactorSharding :
412
+ shardingProjection.getOperands ()) {
413
+ if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
414
+ getFactorSharding (tensorFactorSharding, reductionFactor)) {
415
+ SmallVector<AxisRefAttr> factorShardingVector =
416
+ llvm::to_vector (*factorSharding);
417
+ if (seen) {
418
+ SDY_CHECK (factorShardingVector == commonAxes)
419
+ << " For the operation " << op
420
+ << " , the result has unreduced axes while the operand has "
421
+ " incompatible sharding along reduction factors." ;
422
+ } else {
423
+ commonAxes = factorShardingVector;
424
+ seen = true ;
425
+ }
426
+ }
427
+ }
428
+ }
429
+ return commonAxesPerFactor;
430
+ }
431
+
394
432
// Inserts explicit reshards on the operands and results of `op` such that the
395
433
// sharding of `op` is compatible with its sharding rule.
396
434
//
@@ -400,31 +438,25 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
400
438
// - All op results have the same unreduced axes.
401
439
// - If the op has no results, none of the operands has unreduced axes.
402
440
// - Operand and result meshes are the same ignoring device id order.
441
+ // - There are no overflow axes.
403
442
//
404
443
// Returns the union of axes along all the reduction factors which may not be
405
444
// canonicalized.
406
- SmallVector<AxisRefAttr> processOp (Operation* op,
407
- ArrayRef<TensorShardingAttr> inShardings,
408
- ArrayRef<TensorShardingAttr> outShardings,
409
- IRRewriter& rewriter,
410
- const SymbolTable& symbolTable,
411
- OpShardingRuleAttr shardingRule,
412
- const Mesh& mesh, const bool onFullVersion) {
413
- ShardingProjection shardingProjection = ShardingProjection::build (
414
- inShardings, outShardings, shardingRule, mesh.attr (),
415
- /* closedIfMissing=*/ true );
416
-
417
- // Return without inserting reshards if any factor sharding has overflow
418
- // axes. This case is not handled yet.
419
- // TODO(enver): Handle the case when factor shardings have overflow axes.
420
- if (hasOverflowAxes (shardingProjection)) {
421
- return {};
422
- }
423
-
445
+ //
446
+ // Guaranteed to return a non-empty `AxesPerFactor` if `onFullVersion` is true.
447
+ AxesPerFactor processOp (Operation* op, ShardingProjection& shardingProjection,
448
+ ArrayRef<TensorShardingAttr> inShardings,
449
+ ArrayRef<TensorShardingAttr> outShardings,
450
+ IRRewriter& rewriter, const SymbolTable& symbolTable,
451
+ OpShardingRuleAttr shardingRule, const Mesh& mesh,
452
+ const bool onFullVersion) {
453
+ // Checks if factors are sharded the same way across operands and results.
454
+ AxesPerFactor commonAxesPerFactor =
455
+ getCompatibleFactorShardings (shardingProjection, shardingRule);
456
+
457
+ // TODO(b/446833985): Return common axes factors also when the sharding
458
+ // projection has overflow axes.
424
459
if (onFullVersion) {
425
- // Checks if factors are sharded the same way across operands and results.
426
- AxesPerFactor commonAxesPerFactor =
427
- getCompatibleFactorShardings (shardingProjection, shardingRule);
428
460
// Find compatible shardings if it is not already compatible.
429
461
if (commonAxesPerFactor.empty ()) {
430
462
commonAxesPerFactor =
@@ -443,49 +475,19 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
443
475
insertExplicitReshards (op, inShardings, outShardings, shardingProjection,
444
476
updateTensorShardings, rewriter, shardingRule,
445
477
symbolTable, mesh);
446
-
447
- return getReductionAxes (commonAxesPerFactor, shardingRule);
478
+ } else {
479
+ TypeSwitch<Operation*>(op)
480
+ .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
481
+ processDot (dotOp, inShardings, outShardings, rewriter, symbolTable,
482
+ shardingRule, mesh);
483
+ })
484
+ .Case <stablehlo::DotGeneralOp>(
485
+ [&](stablehlo::DotGeneralOp dotGeneralOp) {
486
+ processDot (dotGeneralOp, inShardings, outShardings, rewriter,
487
+ symbolTable, shardingRule, mesh);
488
+ });
448
489
}
449
-
450
- TypeSwitch<Operation*>(op)
451
- .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452
- processDot (dotOp, inShardings, outShardings, rewriter, symbolTable,
453
- shardingRule, mesh);
454
- })
455
- .Case <stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456
- processDot (dotGeneralOp, inShardings, outShardings, rewriter,
457
- symbolTable, shardingRule, mesh);
458
- });
459
-
460
- if (outShardings.empty () || getUnreducedAxes (outShardings[0 ]).empty ()) {
461
- return {};
462
- }
463
-
464
- // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
465
- // factors, and simplify the following logic.
466
- SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
467
- for (int64_t reductionFactor : shardingRule.getReductionFactors ()) {
468
- // We only iterate operands since reduction factors are not in results.
469
- bool seen = false ;
470
- SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
471
- for (const TensorFactorShardings& tensorFactorSharding :
472
- shardingProjection.getOperands ()) {
473
- if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
474
- getFactorSharding (tensorFactorSharding, reductionFactor)) {
475
- if (seen) {
476
- SDY_CHECK (axesAlongCurrentReductionFactor == *factorSharding)
477
- << " For the operation " << op
478
- << " , the result has unreduced axes while the operand has "
479
- " incompatible sharding along reduction factors." ;
480
- } else {
481
- axesAlongCurrentReductionFactor = llvm::to_vector (*factorSharding);
482
- seen = true ;
483
- }
484
- }
485
- }
486
- axesAlongAllReductionFactors.append (axesAlongCurrentReductionFactor);
487
- }
488
- return axesAlongAllReductionFactors;
490
+ return commonAxesPerFactor;
489
491
}
490
492
491
493
struct InsertExplicitReshardsPass
@@ -544,11 +546,31 @@ struct InsertExplicitReshardsPass
544
546
return ;
545
547
}
546
548
547
- SmallVector<AxisRefAttr> reductionAxes =
548
- processOp (op, inShardings, outShardings, rewriter, symbolTable,
549
- shardingRule, *mesh, onFullVersion);
549
+ ShardingProjection shardingProjection = ShardingProjection::build (
550
+ inShardings, outShardings, shardingRule, mesh->attr (),
551
+ /* closedIfMissing=*/ true );
552
+ // Return without inserting reshards if any factor sharding has overflow
553
+ // axes. This case is not handled yet.
554
+ // TODO(enver): Handle the case when factor shardings have overflow axes.
555
+ if (hasOverflowAxes (shardingProjection)) {
556
+ return ;
557
+ }
558
+ AxesPerFactor commonAxesPerFactor =
559
+ processOp (op, shardingProjection, inShardings, outShardings, rewriter,
560
+ symbolTable, shardingRule, *mesh, onFullVersion);
561
+ if (outShardings.empty () ||
562
+ (!onFullVersion && getUnreducedAxes (outShardings[0 ]).empty ())) {
563
+ return ;
564
+ }
565
+ if (commonAxesPerFactor.empty ()) {
566
+ // At this point, there are unreduced axes on results.
567
+ commonAxesPerFactor = getCommonAxesPerReductionFactorOrDie (
568
+ op, shardingProjection, shardingRule);
569
+ }
550
570
// TODO(b/440055868): Insert a reshard from unreduced to replicated axes.
551
- insertAllReducesForReductionFactors (op, reductionAxes, *mesh, rewriter);
571
+ insertAllReducesForReductionFactors (
572
+ op, getReductionAxes (commonAxesPerFactor, shardingRule), *mesh,
573
+ rewriter);
552
574
553
575
// TODO(enver): Remove sharding rules from ops.
554
576
});
0 commit comments