@@ -573,7 +573,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
573
573
for (int i = 0 ; i < nodes.size (); ++i) {
574
574
Node &cur = nodes[i];
575
575
Node &next = nodes[(i + 1 ) % nodes.size ()];
576
- if (!samePartition (inBody ( cur.op ), inBody ( next.op ) )) {
576
+ if (!samePartition (cur.op , next.op )) {
577
577
cur.barNext = createBarrierAlloc (loop, numMmaStages);
578
578
next.barPrev = cur.barNext ;
579
579
}
@@ -616,6 +616,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
616
616
continue ;
617
617
b.setInsertionPoint (node.op );
618
618
Value view = createSingleBufferView (b, allocOp, node.index );
619
+ b.assignPartition (view.getDefiningOp (), *partitions.getPartition (node.op ));
619
620
if (auto storeOp = dyn_cast<ttng::TMEMStoreOp>(node.op )) {
620
621
storeOp.getDstMutable ().assign (view);
621
622
storeOp.getDepMutable ().clear ();
@@ -671,7 +672,6 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
671
672
Operation *defOp = operand.getDefiningOp ();
672
673
if (!defOp || loop.isDefinedOutsideOfLoop (operand))
673
674
continue ;
674
- defOp = inBody (defOp);
675
675
676
676
if (partitions.isInRootPartition (defOp)) {
677
677
// If the MMA operand is coming from outside the loop, move the alloc out.
@@ -717,7 +717,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
717
717
}
718
718
719
719
for (Node &node : nodes) {
720
- Partition *partition = partitions.getPartition (inBody ( node.op ) );
720
+ Partition *partition = partitions.getPartition (node.op );
721
721
PartitionBuilder b (node.op ->getLoc (), loop);
722
722
723
723
SmallVector<Operation *> defs;
@@ -746,11 +746,13 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
746
746
domInfo.properlyDominates (mmaOp, userPred.getDefiningOp ())) {
747
747
b.restoreInsertionPoint (*incrementPt);
748
748
Value bar = createSingleBufferView (b, node.barPrev , curIndex);
749
+ b.assignPartition (bar.getDefiningOp (), *partition);
749
750
b.createInto <ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
750
751
curPhase, userPred);
751
752
} else {
752
753
b.setInsertionPoint (domOp);
753
754
Value bar = createSingleBufferView (b, node.barPrev , node.index );
755
+ b.assignPartition (bar.getDefiningOp (), *partition);
754
756
b.createInto <ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
755
757
node.phase , userPred);
756
758
}
@@ -759,6 +761,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
759
761
if (isa<scf::IfOp>(domOp->getParentOp ()) && accIsMultiBuffered)
760
762
b.setInsertionPointToStart (domOp->getBlock ());
761
763
Value bar = createSingleBufferView (b, node.barPrev , node.index );
764
+ b.assignPartition (bar.getDefiningOp (), *partition);
762
765
b.createInto <ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
763
766
node.phase );
764
767
}
@@ -767,13 +770,15 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
767
770
if (mmaOp == node.op ) {
768
771
b.setInsertionPoint (mmaOp);
769
772
Value bar = createSingleBufferView (b, node.barNext , node.index );
773
+ b.assignPartition (bar.getDefiningOp (), *partitions.getPartition (mmaOp));
770
774
mmaOp.addCompletionBarrier (bar, userPred);
771
775
mmaOp.setIsAsync (true );
772
776
} else {
773
777
b.setInsertionPointAfter (lastOp);
774
778
if (isa<scf::IfOp>(lastOp->getParentOp ()) && accIsMultiBuffered)
775
779
b.setInsertionPoint (lastOp->getBlock ()->getTerminator ());
776
780
Value bar = createSingleBufferView (b, node.barNext , node.index );
781
+ b.assignPartition (bar.getDefiningOp (), *partition);
777
782
b.createInto <ttng::ArriveBarrierOp>(*partition, nodeStageCluster, bar,
778
783
1 );
779
784
}
@@ -799,20 +804,26 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
799
804
StageCluster srcStageCluster = getStageCluster (domOp);
800
805
b.setInsertionPoint (domOp);
801
806
Value emptyView = createSingleBufferView (b, emptyBar, index);
807
+ b.assignPartition (emptyView.getDefiningOp (), *partition);
802
808
b.createInto <ttng::WaitBarrierOp>(*partition, srcStageCluster, emptyView,
803
809
phase);
804
810
805
811
b.setInsertionPointAfter (lastOp);
806
812
Value readyView = createSingleBufferView (b, readyBar, index);
813
+ b.assignPartition (readyView.getDefiningOp (), *partition);
807
814
b.createInto <ttng::ArriveBarrierOp>(*partition, srcStageCluster, readyView,
808
815
1 );
809
816
810
817
b.setInsertionPoint (mmaOp);
811
818
Value readyView2 = createSingleBufferView (b, readyBar, index);
819
+ b.assignPartition (readyView2.getDefiningOp (),
820
+ *partitions.getPartition (mmaOp));
812
821
b.createInto <ttng::WaitBarrierOp>(*partitions.getPartition (mmaOp),
813
822
getStageCluster (mmaOp), readyView2,
814
823
phase);
815
824
Value emptyView2 = createSingleBufferView (b, emptyBar, index);
825
+ b.assignPartition (emptyView2.getDefiningOp (),
826
+ *partitions.getPartition (mmaOp));
816
827
mmaOp.addCompletionBarrier (emptyView2, b.boolCst (true ));
817
828
mmaOp.setIsAsync (true );
818
829
}
0 commit comments