Skip to content

Commit aa4972f

Browse files
committed
Cleanup formatting
1 parent 4c60713 commit aa4972f

File tree

4 files changed

+91
-81
lines changed

4 files changed

+91
-81
lines changed

lib/polygeist/Ops.cpp

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -676,8 +676,7 @@ struct SelectOfSubIndex : public OpRewritePattern<SelectOp> {
676676
};
677677

678678
/// Simplify select subindex(x), subindex(y) to subindex(select x, y)
679-
template<typename T>
680-
struct LoadSelect : public OpRewritePattern<T> {
679+
template <typename T> struct LoadSelect : public OpRewritePattern<T> {
681680
using OpRewritePattern<T>::OpRewritePattern;
682681

683682
static Value ptr(T op);
@@ -691,14 +690,15 @@ struct LoadSelect : public OpRewritePattern<T> {
691690
return failure();
692691

693692
Type tys[] = {op.getType()};
694-
auto iop = rewriter.create<scf::IfOp>(mem.getLoc(), tys, mem.getCondition(), /*hasElse*/true);
693+
auto iop = rewriter.create<scf::IfOp>(mem.getLoc(), tys, mem.getCondition(),
694+
/*hasElse*/ true);
695695

696696
auto vop = cast<T>(op->clone());
697697
iop.thenBlock()->push_front(vop);
698698
ptrMutable(vop).assign(mem.getTrueValue());
699699
rewriter.setInsertionPointToEnd(iop.thenBlock());
700700
rewriter.create<scf::YieldOp>(op.getLoc(), vop->getResults());
701-
701+
702702
auto eop = cast<T>(op->clone());
703703
iop.elseBlock()->push_front(eop);
704704
ptrMutable(eop).assign(mem.getFalseValue());
@@ -710,38 +710,34 @@ struct LoadSelect : public OpRewritePattern<T> {
710710
}
711711
};
712712

713-
template<>
714-
Value LoadSelect<memref::LoadOp>::ptr(memref::LoadOp op) {
715-
return op.memref();
716-
}
717-
template<>
718-
MutableOperandRange LoadSelect<memref::LoadOp>::ptrMutable(memref::LoadOp op) {
719-
return op.memrefMutable();
720-
}
721-
template<>
722-
Value LoadSelect<AffineLoadOp>::ptr(AffineLoadOp op) {
723-
return op.memref();
724-
}
725-
template<>
726-
MutableOperandRange LoadSelect<AffineLoadOp>::ptrMutable(AffineLoadOp op) {
727-
return op.memrefMutable();
728-
}
729-
template<>
730-
Value LoadSelect<LLVM::LoadOp>::ptr(LLVM::LoadOp op) {
731-
return op.getAddr();
732-
}
733-
template<>
734-
MutableOperandRange LoadSelect<LLVM::LoadOp>::ptrMutable(LLVM::LoadOp op) {
735-
return op.getAddrMutable();
736-
}
713+
template <> Value LoadSelect<memref::LoadOp>::ptr(memref::LoadOp op) {
714+
return op.memref();
715+
}
716+
template <>
717+
MutableOperandRange LoadSelect<memref::LoadOp>::ptrMutable(memref::LoadOp op) {
718+
return op.memrefMutable();
719+
}
720+
template <> Value LoadSelect<AffineLoadOp>::ptr(AffineLoadOp op) {
721+
return op.memref();
722+
}
723+
template <>
724+
MutableOperandRange LoadSelect<AffineLoadOp>::ptrMutable(AffineLoadOp op) {
725+
return op.memrefMutable();
726+
}
727+
template <> Value LoadSelect<LLVM::LoadOp>::ptr(LLVM::LoadOp op) {
728+
return op.getAddr();
729+
}
730+
template <>
731+
MutableOperandRange LoadSelect<LLVM::LoadOp>::ptrMutable(LLVM::LoadOp op) {
732+
return op.getAddrMutable();
733+
}
737734

738735
void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
739736
MLIRContext *context) {
740737
results.insert<CastOfSubIndex, SubIndex2, SubToCast, SimplifySubViewUsers,
741738
SimplifySubIndexUsers, SelectOfCast, SelectOfSubIndex,
742739
RedundantDynSubIndex, LoadSelect<memref::LoadOp>,
743-
LoadSelect<AffineLoadOp>, LoadSelect<LLVM::LoadOp>
744-
>(context);
740+
LoadSelect<AffineLoadOp>, LoadSelect<LLVM::LoadOp>>(context);
745741
// Disabled: SubToSubView
746742
}
747743

lib/polygeist/Passes/CanonicalizeFor.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,8 +1498,7 @@ struct ReturnSq : public OpRewritePattern<ReturnOp> {
14981498
void CanonicalizeFor::runOnFunction() {
14991499
mlir::RewritePatternSet rpl(getFunction().getContext());
15001500
rpl.add<PropagateInLoopBody, DetectTrivialIndVarInArgs,
1501-
ForOpInductionReplacement,
1502-
RemoveUnusedArgs, MoveWhileToFor,
1501+
ForOpInductionReplacement, RemoveUnusedArgs, MoveWhileToFor,
15031502

15041503
MoveWhileDown, MoveWhileDown2
15051504

lib/polygeist/Passes/ParallelLoopDistribute.cpp

Lines changed: 62 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,13 @@ struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
238238

239239
Value difference = rewriter.create<SubIOp>(op.getLoc(), op.getUpperBound(),
240240
op.getLowerBound());
241-
Value tripCount = rewriter.create<AddIOp>(op.getLoc(), rewriter.create<DivUIOp>(op.getLoc(),
242-
rewriter.create<SubIOp>(op.getLoc(), difference, one), op.getStep()), one);
243-
// rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
241+
Value tripCount = rewriter.create<AddIOp>(
242+
op.getLoc(),
243+
rewriter.create<DivUIOp>(
244+
op.getLoc(), rewriter.create<SubIOp>(op.getLoc(), difference, one),
245+
op.getStep()),
246+
one);
247+
// rewriter.create<CeilDivSIOp>(op.getLoc(), difference, op.getStep());
244248
auto newForOp =
245249
rewriter.create<scf::ForOp>(op.getLoc(), zero, tripCount, one);
246250
rewriter.setInsertionPointToStart(newForOp.getBody());
@@ -455,34 +459,38 @@ static void moveBodies(PatternRewriter &rewriter, scf::ParallelOp op,
455459
scf::IfOp ifOp, scf::IfOp newIf) {
456460
rewriter.startRootUpdate(op);
457461
{
458-
OpBuilder::InsertionGuard guard(rewriter);
459-
rewriter.setInsertionPointToStart(newIf.thenBlock());
460-
auto newParallel = rewriter.create<scf::ParallelOp>(
461-
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
462+
OpBuilder::InsertionGuard guard(rewriter);
463+
rewriter.setInsertionPointToStart(newIf.thenBlock());
464+
auto newParallel = rewriter.create<scf::ParallelOp>(
465+
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
462466

463-
for (auto tup : llvm::zip(newParallel.getInductionVars(), op.getInductionVars())) {
467+
for (auto tup :
468+
llvm::zip(newParallel.getInductionVars(), op.getInductionVars())) {
464469
std::get<1>(tup).replaceUsesWithIf(std::get<0>(tup), [&](OpOperand &op) {
465-
return ifOp.getThenRegion().isAncestor(op.getOwner()->getParentRegion());
470+
return ifOp.getThenRegion().isAncestor(
471+
op.getOwner()->getParentRegion());
466472
});
467-
}
473+
}
468474

469-
rewriter.mergeBlockBefore(ifOp.thenBlock(), &newParallel.getBody()->back());
470-
rewriter.eraseOp(&newParallel.getBody()->back());
475+
rewriter.mergeBlockBefore(ifOp.thenBlock(), &newParallel.getBody()->back());
476+
rewriter.eraseOp(&newParallel.getBody()->back());
471477
}
472478

473479
if (ifOp.getElseRegion().getBlocks().size() > 0) {
474-
OpBuilder::InsertionGuard guard(rewriter);
475-
rewriter.setInsertionPointToStart(newIf.elseBlock());
476-
auto newParallel = rewriter.create<scf::ParallelOp>(
477-
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
480+
OpBuilder::InsertionGuard guard(rewriter);
481+
rewriter.setInsertionPointToStart(newIf.elseBlock());
482+
auto newParallel = rewriter.create<scf::ParallelOp>(
483+
op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep());
478484

479-
for (auto tup : llvm::zip(newParallel.getInductionVars(), op.getInductionVars())) {
485+
for (auto tup :
486+
llvm::zip(newParallel.getInductionVars(), op.getInductionVars())) {
480487
std::get<1>(tup).replaceUsesWithIf(std::get<0>(tup), [&](OpOperand &op) {
481-
return ifOp.getElseRegion().isAncestor(op.getOwner()->getParentRegion());
488+
return ifOp.getElseRegion().isAncestor(
489+
op.getOwner()->getParentRegion());
482490
});
483-
}
484-
rewriter.mergeBlockBefore(ifOp.elseBlock(), &newParallel.getBody()->back());
485-
rewriter.eraseOp(&newParallel.getBody()->back());
491+
}
492+
rewriter.mergeBlockBefore(ifOp.elseBlock(), &newParallel.getBody()->back());
493+
rewriter.eraseOp(&newParallel.getBody()->back());
486494
}
487495

488496
rewriter.eraseOp(ifOp);
@@ -518,8 +526,9 @@ struct InterchangeIfPFor : public OpRewritePattern<scf::ParallelOp> {
518526
return failure();
519527
}
520528

521-
auto newIf =
522-
rewriter.create<scf::IfOp>(ifOp.getLoc(), TypeRange(), ifOp.getCondition(), ifOp.getElseRegion().getBlocks().size() > 0);
529+
auto newIf = rewriter.create<scf::IfOp>(
530+
ifOp.getLoc(), TypeRange(), ifOp.getCondition(),
531+
ifOp.getElseRegion().getBlocks().size() > 0);
523532
moveBodies(rewriter, op, ifOp, newIf);
524533
return success();
525534
}
@@ -563,9 +572,10 @@ struct InterchangeIfPForLoad : public OpRewritePattern<scf::ParallelOp> {
563572
Value condition = rewriter.create<memref::LoadOp>(
564573
loadOp.getLoc(), loadOp.getMemRef(),
565574
SmallVector<Value>(loadOp.getMemRefType().getRank(), zero));
566-
575+
567576
auto newIf =
568-
rewriter.create<scf::IfOp>(ifOp.getLoc(), TypeRange(), condition, ifOp.getElseRegion().getBlocks().size() > 0);
577+
rewriter.create<scf::IfOp>(ifOp.getLoc(), TypeRange(), condition,
578+
ifOp.getElseRegion().getBlocks().size() > 0);
569579
moveBodies(rewriter, op, ifOp, newIf);
570580
return success();
571581
}
@@ -1072,9 +1082,11 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
10721082
allocated.reserve(op.getNumIterOperands());
10731083
for (Value operand : op.getIterOperands()) {
10741084
Value alloc = rewriter.create<memref::AllocaOp>(
1075-
op.getLoc(), MemRefType::get(ArrayRef<int64_t>(), operand.getType()), ValueRange());
1085+
op.getLoc(), MemRefType::get(ArrayRef<int64_t>(), operand.getType()),
1086+
ValueRange());
10761087
allocated.push_back(alloc);
1077-
rewriter.create<memref::StoreOp>(op.getLoc(), operand, alloc, ValueRange());
1088+
rewriter.create<memref::StoreOp>(op.getLoc(), operand, alloc,
1089+
ValueRange());
10781090
}
10791091

10801092
auto newOp = rewriter.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
@@ -1098,7 +1110,8 @@ struct Reg2MemFor : public OpRewritePattern<scf::ForOp> {
10981110
rewriter.setInsertionPointAfter(op);
10991111
SmallVector<Value> loaded;
11001112
for (Value alloc : allocated) {
1101-
loaded.push_back(rewriter.create<memref::LoadOp>(op.getLoc(), alloc, ValueRange()));
1113+
loaded.push_back(
1114+
rewriter.create<memref::LoadOp>(op.getLoc(), alloc, ValueRange()));
11021115
}
11031116
rewriter.replaceOp(op, loaded);
11041117
return success();
@@ -1112,25 +1125,26 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
11121125
PatternRewriter &rewriter) const override {
11131126
if (!op.getResults().size() || !hasNestedBarrier(op))
11141127
return failure();
1115-
11161128

11171129
SmallVector<Value> allocated;
11181130
allocated.reserve(op.getNumResults());
11191131
for (Type opType : op.getResultTypes()) {
11201132
Value alloc = rewriter.create<memref::AllocaOp>(
1121-
op.getLoc(), MemRefType::get(ArrayRef<int64_t>(), opType), ValueRange());
1133+
op.getLoc(), MemRefType::get(ArrayRef<int64_t>(), opType),
1134+
ValueRange());
11221135
allocated.push_back(alloc);
11231136
}
1124-
1125-
auto newOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(), op.getCondition(), true);
1137+
1138+
auto newOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
1139+
op.getCondition(), true);
11261140

11271141
rewriter.setInsertionPoint(op.thenYield());
11281142
for (auto en : llvm::enumerate(op.thenYield().getOperands())) {
11291143
rewriter.create<memref::StoreOp>(op.getLoc(), en.value(),
11301144
allocated[en.index()], ValueRange());
11311145
}
11321146
op.thenYield()->setOperands(ValueRange());
1133-
1147+
11341148
rewriter.setInsertionPoint(op.elseYield());
11351149
for (auto en : llvm::enumerate(op.elseYield().getOperands())) {
11361150
rewriter.create<memref::StoreOp>(op.getLoc(), en.value(),
@@ -1140,14 +1154,15 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
11401154

11411155
rewriter.eraseOp(&newOp.thenBlock()->back());
11421156
rewriter.mergeBlocks(op.thenBlock(), newOp.thenBlock());
1143-
1157+
11441158
rewriter.eraseOp(&newOp.elseBlock()->back());
11451159
rewriter.mergeBlocks(op.elseBlock(), newOp.elseBlock());
11461160

11471161
rewriter.setInsertionPointAfter(op);
11481162
SmallVector<Value> loaded;
11491163
for (Value alloc : allocated) {
1150-
loaded.push_back(rewriter.create<memref::LoadOp>(op.getLoc(), alloc, ValueRange()));
1164+
loaded.push_back(
1165+
rewriter.create<memref::LoadOp>(op.getLoc(), alloc, ValueRange()));
11511166
}
11521167
rewriter.replaceOp(op, loaded);
11531168
return success();
@@ -1157,16 +1172,19 @@ struct Reg2MemIf : public OpRewritePattern<scf::IfOp> {
11571172
static void storeValues(Location loc, ValueRange values, ValueRange pointers,
11581173
PatternRewriter &rewriter) {
11591174
for (auto pair : llvm::zip(values, pointers)) {
1160-
rewriter.create<memref::StoreOp>(loc, std::get<0>(pair), std::get<1>(pair), ValueRange());
1175+
rewriter.create<memref::StoreOp>(loc, std::get<0>(pair), std::get<1>(pair),
1176+
ValueRange());
11611177
}
11621178
}
11631179

1164-
static void allocaValues(Location loc, ValueRange values, PatternRewriter &rewriter,
1180+
static void allocaValues(Location loc, ValueRange values,
1181+
PatternRewriter &rewriter,
11651182
SmallVector<Value> &allocated) {
11661183
allocated.reserve(values.size());
11671184
for (Value value : values) {
11681185
Value alloc = rewriter.create<memref::AllocaOp>(
1169-
loc, MemRefType::get(ArrayRef<int64_t>(), value.getType()), ValueRange());
1186+
loc, MemRefType::get(ArrayRef<int64_t>(), value.getType()),
1187+
ValueRange());
11701188
allocated.push_back(alloc);
11711189
}
11721190
}
@@ -1184,8 +1202,7 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
11841202
// Value stackPtr = rewriter.create<LLVM::StackSaveOp>(
11851203
// op.getLoc(), LLVM::LLVMPointerType::get(rewriter.getIntegerType(8)));
11861204
SmallVector<Value> beforeAllocated, afterAllocated;
1187-
allocaValues(op.getLoc(), op.getOperands(), rewriter,
1188-
beforeAllocated);
1205+
allocaValues(op.getLoc(), op.getOperands(), rewriter, beforeAllocated);
11891206
storeValues(op.getLoc(), op.getOperands(), beforeAllocated, rewriter);
11901207
allocaValues(op.getLoc(), op.getResults(), rewriter, afterAllocated);
11911208

@@ -1194,8 +1211,7 @@ struct Reg2MemWhile : public OpRewritePattern<scf::WhileOp> {
11941211
Block *newBefore =
11951212
rewriter.createBlock(&newOp.getBefore(), newOp.getBefore().begin());
11961213
SmallVector<Value> newBeforeArguments;
1197-
loadValues(op.getLoc(), beforeAllocated, rewriter,
1198-
newBeforeArguments);
1214+
loadValues(op.getLoc(), beforeAllocated, rewriter, newBeforeArguments);
11991215
rewriter.mergeBlocks(&op.getBefore().front(), newBefore,
12001216
newBeforeArguments);
12011217

@@ -1240,12 +1256,10 @@ struct CPUifyPass : public SCFCPUifyBase<CPUifyPass> {
12401256
OwningRewritePatternList patterns(&getContext());
12411257
patterns
12421258
.insert<Reg2MemFor, Reg2MemWhile, Reg2MemIf,
1243-
//ReplaceIfWithFors,
1244-
WrapForWithBarrier, WrapIfWithBarrier,
1245-
WrapWhileWithBarrier,
1246-
InterchangeForPFor, InterchangeForPForLoad,
1247-
InterchangeIfPFor, InterchangeIfPForLoad,
1248-
InterchangeWhilePFor, NormalizeLoop,
1259+
// ReplaceIfWithFors,
1260+
WrapForWithBarrier, WrapIfWithBarrier, WrapWhileWithBarrier,
1261+
InterchangeForPFor, InterchangeForPForLoad, InterchangeIfPFor,
1262+
InterchangeIfPForLoad, InterchangeWhilePFor, NormalizeLoop,
12491263
NormalizeParallel, RotateWhile, DistributeAroundBarrier>(
12501264
&getContext());
12511265
GreedyRewriteConfig config;

tools/mlir-clang/Lib/clang-mlir.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1142,7 +1142,8 @@ ValueCategory MLIRScanner::VisitCXXNewExpr(clang::CXXNewExpr *expr) {
11421142
.create<mlir::LLVM::CallOp>(loc, Glob.GetOrCreateMallocFunction(),
11431143
args)
11441144
->getResult(0));
1145-
arrayCons = builder.create<mlir::LLVM::BitcastOp>(loc, LLVM::LLVMArrayType::get(ty, 0), alloc);
1145+
arrayCons = builder.create<mlir::LLVM::BitcastOp>(
1146+
loc, LLVM::LLVMArrayType::get(ty, 0), alloc);
11461147
}
11471148
assert(alloc);
11481149

0 commit comments

Comments
 (0)