@@ -238,9 +238,13 @@ struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
238
238
239
239
Value difference = rewriter.create <SubIOp>(op.getLoc (), op.getUpperBound (),
240
240
op.getLowerBound ());
241
- Value tripCount = rewriter.create <AddIOp>(op.getLoc (), rewriter.create <DivUIOp>(op.getLoc (),
242
- rewriter.create <SubIOp>(op.getLoc (), difference, one), op.getStep ()), one);
243
- // rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
241
+ Value tripCount = rewriter.create <AddIOp>(
242
+ op.getLoc (),
243
+ rewriter.create <DivUIOp>(
244
+ op.getLoc (), rewriter.create <SubIOp>(op.getLoc (), difference, one),
245
+ op.getStep ()),
246
+ one);
247
+ // rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
244
248
auto newForOp =
245
249
rewriter.create <scf::ForOp>(op.getLoc (), zero, tripCount, one);
246
250
rewriter.setInsertionPointToStart (newForOp.getBody ());
@@ -455,34 +459,38 @@ static void moveBodies(PatternRewriter &rewriter, scf::ParallelOp op,
455
459
scf::IfOp ifOp, scf::IfOp newIf) {
456
460
rewriter.startRootUpdate (op);
457
461
{
458
- OpBuilder::InsertionGuard guard (rewriter);
459
- rewriter.setInsertionPointToStart (newIf.thenBlock ());
460
- auto newParallel = rewriter.create <scf::ParallelOp>(
461
- op.getLoc (), op.getLowerBound (), op.getUpperBound (), op.getStep ());
462
+ OpBuilder::InsertionGuard guard (rewriter);
463
+ rewriter.setInsertionPointToStart (newIf.thenBlock ());
464
+ auto newParallel = rewriter.create <scf::ParallelOp>(
465
+ op.getLoc (), op.getLowerBound (), op.getUpperBound (), op.getStep ());
462
466
463
- for (auto tup : llvm::zip (newParallel.getInductionVars (), op.getInductionVars ())) {
467
+ for (auto tup :
468
+ llvm::zip (newParallel.getInductionVars (), op.getInductionVars ())) {
464
469
std::get<1 >(tup).replaceUsesWithIf (std::get<0 >(tup), [&](OpOperand &op) {
465
- return ifOp.getThenRegion ().isAncestor (op.getOwner ()->getParentRegion ());
470
+ return ifOp.getThenRegion ().isAncestor (
471
+ op.getOwner ()->getParentRegion ());
466
472
});
467
- }
473
+ }
468
474
469
- rewriter.mergeBlockBefore (ifOp.thenBlock (), &newParallel.getBody ()->back ());
470
- rewriter.eraseOp (&newParallel.getBody ()->back ());
475
+ rewriter.mergeBlockBefore (ifOp.thenBlock (), &newParallel.getBody ()->back ());
476
+ rewriter.eraseOp (&newParallel.getBody ()->back ());
471
477
}
472
478
473
479
if (ifOp.getElseRegion ().getBlocks ().size () > 0 ) {
474
- OpBuilder::InsertionGuard guard (rewriter);
475
- rewriter.setInsertionPointToStart (newIf.elseBlock ());
476
- auto newParallel = rewriter.create <scf::ParallelOp>(
477
- op.getLoc (), op.getLowerBound (), op.getUpperBound (), op.getStep ());
480
+ OpBuilder::InsertionGuard guard (rewriter);
481
+ rewriter.setInsertionPointToStart (newIf.elseBlock ());
482
+ auto newParallel = rewriter.create <scf::ParallelOp>(
483
+ op.getLoc (), op.getLowerBound (), op.getUpperBound (), op.getStep ());
478
484
479
- for (auto tup : llvm::zip (newParallel.getInductionVars (), op.getInductionVars ())) {
485
+ for (auto tup :
486
+ llvm::zip (newParallel.getInductionVars (), op.getInductionVars ())) {
480
487
std::get<1 >(tup).replaceUsesWithIf (std::get<0 >(tup), [&](OpOperand &op) {
481
- return ifOp.getElseRegion ().isAncestor (op.getOwner ()->getParentRegion ());
488
+ return ifOp.getElseRegion ().isAncestor (
489
+ op.getOwner ()->getParentRegion ());
482
490
});
483
- }
484
- rewriter.mergeBlockBefore (ifOp.elseBlock (), &newParallel.getBody ()->back ());
485
- rewriter.eraseOp (&newParallel.getBody ()->back ());
491
+ }
492
+ rewriter.mergeBlockBefore (ifOp.elseBlock (), &newParallel.getBody ()->back ());
493
+ rewriter.eraseOp (&newParallel.getBody ()->back ());
486
494
}
487
495
488
496
rewriter.eraseOp (ifOp);
@@ -518,8 +526,9 @@ struct InterchangeIfPFor : public OpRewritePattern<scf::ParallelOp> {
518
526
return failure ();
519
527
}
520
528
521
- auto newIf =
522
- rewriter.create <scf::IfOp>(ifOp.getLoc (), TypeRange (), ifOp.getCondition (), ifOp.getElseRegion ().getBlocks ().size () > 0 );
529
+ auto newIf = rewriter.create <scf::IfOp>(
530
+ ifOp.getLoc (), TypeRange (), ifOp.getCondition (),
531
+ ifOp.getElseRegion ().getBlocks ().size () > 0 );
523
532
moveBodies (rewriter, op, ifOp, newIf);
524
533
return success ();
525
534
}
@@ -563,9 +572,10 @@ struct InterchangeIfPForLoad : public OpRewritePattern<scf::ParallelOp> {
563
572
Value condition = rewriter.create <memref::LoadOp>(
564
573
loadOp.getLoc (), loadOp.getMemRef (),
565
574
SmallVector<Value>(loadOp.getMemRefType ().getRank (), zero));
566
-
575
+
567
576
auto newIf =
568
- rewriter.create <scf::IfOp>(ifOp.getLoc (), TypeRange (), condition, ifOp.getElseRegion ().getBlocks ().size () > 0 );
577
+ rewriter.create <scf::IfOp>(ifOp.getLoc (), TypeRange (), condition,
578
+ ifOp.getElseRegion ().getBlocks ().size () > 0 );
569
579
moveBodies (rewriter, op, ifOp, newIf);
570
580
return success ();
571
581
}
@@ -1072,9 +1082,11 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
1072
1082
allocated.reserve (op.getNumIterOperands ());
1073
1083
for (Value operand : op.getIterOperands ()) {
1074
1084
Value alloc = rewriter.create <memref::AllocaOp>(
1075
- op.getLoc (), MemRefType::get (ArrayRef<int64_t >(), operand.getType ()), ValueRange ());
1085
+ op.getLoc (), MemRefType::get (ArrayRef<int64_t >(), operand.getType ()),
1086
+ ValueRange ());
1076
1087
allocated.push_back (alloc);
1077
- rewriter.create <memref::StoreOp>(op.getLoc (), operand, alloc, ValueRange ());
1088
+ rewriter.create <memref::StoreOp>(op.getLoc (), operand, alloc,
1089
+ ValueRange ());
1078
1090
}
1079
1091
1080
1092
auto newOp = rewriter.create <scf::ForOp>(op.getLoc (), op.getLowerBound (),
@@ -1098,7 +1110,8 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
1098
1110
rewriter.setInsertionPointAfter (op);
1099
1111
SmallVector<Value> loaded;
1100
1112
for (Value alloc : allocated) {
1101
- loaded.push_back (rewriter.create <memref::LoadOp>(op.getLoc (), alloc, ValueRange ()));
1113
+ loaded.push_back (
1114
+ rewriter.create <memref::LoadOp>(op.getLoc (), alloc, ValueRange ()));
1102
1115
}
1103
1116
rewriter.replaceOp (op, loaded);
1104
1117
return success ();
@@ -1112,25 +1125,26 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
1112
1125
PatternRewriter &rewriter) const override {
1113
1126
if (!op.getResults ().size () || !hasNestedBarrier (op))
1114
1127
return failure ();
1115
-
1116
1128
1117
1129
SmallVector<Value> allocated;
1118
1130
allocated.reserve (op.getNumResults ());
1119
1131
for (Type opType : op.getResultTypes ()) {
1120
1132
Value alloc = rewriter.create <memref::AllocaOp>(
1121
- op.getLoc (), MemRefType::get (ArrayRef<int64_t >(), opType), ValueRange ());
1133
+ op.getLoc (), MemRefType::get (ArrayRef<int64_t >(), opType),
1134
+ ValueRange ());
1122
1135
allocated.push_back (alloc);
1123
1136
}
1124
-
1125
- auto newOp = rewriter.create <scf::IfOp>(op.getLoc (), TypeRange (), op.getCondition (), true );
1137
+
1138
+ auto newOp = rewriter.create <scf::IfOp>(op.getLoc (), TypeRange (),
1139
+ op.getCondition (), true );
1126
1140
1127
1141
rewriter.setInsertionPoint (op.thenYield ());
1128
1142
for (auto en : llvm::enumerate (op.thenYield ().getOperands ())) {
1129
1143
rewriter.create <memref::StoreOp>(op.getLoc (), en.value (),
1130
1144
allocated[en.index ()], ValueRange ());
1131
1145
}
1132
1146
op.thenYield ()->setOperands (ValueRange ());
1133
-
1147
+
1134
1148
rewriter.setInsertionPoint (op.elseYield ());
1135
1149
for (auto en : llvm::enumerate (op.elseYield ().getOperands ())) {
1136
1150
rewriter.create <memref::StoreOp>(op.getLoc (), en.value (),
@@ -1140,14 +1154,15 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
1140
1154
1141
1155
rewriter.eraseOp (&newOp.thenBlock ()->back ());
1142
1156
rewriter.mergeBlocks (op.thenBlock (), newOp.thenBlock ());
1143
-
1157
+
1144
1158
rewriter.eraseOp (&newOp.elseBlock ()->back ());
1145
1159
rewriter.mergeBlocks (op.elseBlock (), newOp.elseBlock ());
1146
1160
1147
1161
rewriter.setInsertionPointAfter (op);
1148
1162
SmallVector<Value> loaded;
1149
1163
for (Value alloc : allocated) {
1150
- loaded.push_back (rewriter.create <memref::LoadOp>(op.getLoc (), alloc, ValueRange ()));
1164
+ loaded.push_back (
1165
+ rewriter.create <memref::LoadOp>(op.getLoc (), alloc, ValueRange ()));
1151
1166
}
1152
1167
rewriter.replaceOp (op, loaded);
1153
1168
return success ();
@@ -1157,16 +1172,19 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
1157
1172
static void storeValues (Location loc, ValueRange values, ValueRange pointers,
1158
1173
PatternRewriter &rewriter) {
1159
1174
for (auto pair : llvm::zip (values, pointers)) {
1160
- rewriter.create <memref::StoreOp>(loc, std::get<0 >(pair), std::get<1 >(pair), ValueRange ());
1175
+ rewriter.create <memref::StoreOp>(loc, std::get<0 >(pair), std::get<1 >(pair),
1176
+ ValueRange ());
1161
1177
}
1162
1178
}
1163
1179
1164
- static void allocaValues (Location loc, ValueRange values, PatternRewriter &rewriter,
1180
+ static void allocaValues (Location loc, ValueRange values,
1181
+ PatternRewriter &rewriter,
1165
1182
SmallVector<Value> &allocated) {
1166
1183
allocated.reserve (values.size ());
1167
1184
for (Value value : values) {
1168
1185
Value alloc = rewriter.create <memref::AllocaOp>(
1169
- loc, MemRefType::get (ArrayRef<int64_t >(), value.getType ()), ValueRange ());
1186
+ loc, MemRefType::get (ArrayRef<int64_t >(), value.getType ()),
1187
+ ValueRange ());
1170
1188
allocated.push_back (alloc);
1171
1189
}
1172
1190
}
@@ -1184,8 +1202,7 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
1184
1202
// Value stackPtr = rewriter.create<LLVM::StackSaveOp>(
1185
1203
// op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getIntegerType(8)));
1186
1204
SmallVector<Value> beforeAllocated, afterAllocated;
1187
- allocaValues (op.getLoc (), op.getOperands (), rewriter,
1188
- beforeAllocated);
1205
+ allocaValues (op.getLoc (), op.getOperands (), rewriter, beforeAllocated);
1189
1206
storeValues (op.getLoc (), op.getOperands (), beforeAllocated, rewriter);
1190
1207
allocaValues (op.getLoc (), op.getResults (), rewriter, afterAllocated);
1191
1208
@@ -1194,8 +1211,7 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
1194
1211
Block *newBefore =
1195
1212
rewriter.createBlock (&newOp.getBefore (), newOp.getBefore ().begin ());
1196
1213
SmallVector<Value> newBeforeArguments;
1197
- loadValues (op.getLoc (), beforeAllocated, rewriter,
1198
- newBeforeArguments);
1214
+ loadValues (op.getLoc (), beforeAllocated, rewriter, newBeforeArguments);
1199
1215
rewriter.mergeBlocks (&op.getBefore ().front (), newBefore,
1200
1216
newBeforeArguments);
1201
1217
@@ -1240,12 +1256,10 @@ struct CPUifyPass : public SCFCPUifyBase<CPUifyPass> {
1240
1256
OwningRewritePatternList patterns (&getContext ());
1241
1257
patterns
1242
1258
.insert <Reg2MemFor, Reg2MemWhile, Reg2MemIf,
1243
- // ReplaceIfWithFors,
1244
- WrapForWithBarrier, WrapIfWithBarrier,
1245
- WrapWhileWithBarrier,
1246
- InterchangeForPFor, InterchangeForPForLoad,
1247
- InterchangeIfPFor, InterchangeIfPForLoad,
1248
- InterchangeWhilePFor, NormalizeLoop,
1259
+ // ReplaceIfWithFors,
1260
+ WrapForWithBarrier, WrapIfWithBarrier, WrapWhileWithBarrier,
1261
+ InterchangeForPFor, InterchangeForPForLoad, InterchangeIfPFor,
1262
+ InterchangeIfPForLoad, InterchangeWhilePFor, NormalizeLoop,
1249
1263
NormalizeParallel, RotateWhile, DistributeAroundBarrier>(
1250
1264
&getContext ());
1251
1265
GreedyRewriteConfig config;
0 commit comments