Skip to content

Commit c794fd2

Browse files
committed
Bump LLVM
1 parent fe95821 commit c794fd2

File tree

6 files changed

+74
-10
lines changed

6 files changed

+74
-10
lines changed

lib/polygeist/Ops.cpp

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,11 +675,73 @@ struct SelectOfSubIndex : public OpRewritePattern<SelectOp> {
675675
}
676676
};
677677

678+
/// Simplify select subindex(x), subindex(y) to subindex(select x, y)
679+
template<typename T>
680+
struct LoadSelect : public OpRewritePattern<T> {
681+
using OpRewritePattern<T>::OpRewritePattern;
682+
683+
static Value ptr(T op);
684+
static MutableOperandRange ptrMutable(T op);
685+
686+
LogicalResult matchAndRewrite(T op,
687+
PatternRewriter &rewriter) const override {
688+
auto mem0 = ptr(op);
689+
SelectOp mem = dyn_cast_or_null<SelectOp>(mem0.getDefiningOp());
690+
if (!mem)
691+
return failure();
692+
693+
Type tys[] = {op.getType()};
694+
auto iop = rewriter.create<scf::IfOp>(mem.getLoc(), tys, mem.getCondition(), /*hasElse*/true);
695+
696+
auto vop = cast<T>(op->clone());
697+
iop.thenBlock()->push_front(vop);
698+
ptrMutable(vop).assign(mem.getTrueValue());
699+
rewriter.setInsertionPointToEnd(iop.thenBlock());
700+
rewriter.create<scf::YieldOp>(op.getLoc(), vop->getResults());
701+
702+
auto eop = cast<T>(op->clone());
703+
iop.elseBlock()->push_front(eop);
704+
ptrMutable(eop).assign(mem.getFalseValue());
705+
rewriter.setInsertionPointToEnd(iop.elseBlock());
706+
rewriter.create<scf::YieldOp>(op.getLoc(), eop->getResults());
707+
708+
rewriter.replaceOp(op, iop.getResults());
709+
return success();
710+
}
711+
};
712+
713+
template<>
714+
Value LoadSelect<memref::LoadOp>::ptr(memref::LoadOp op) {
715+
return op.memref();
716+
}
717+
template<>
718+
MutableOperandRange LoadSelect<memref::LoadOp>::ptrMutable(memref::LoadOp op) {
719+
return op.memrefMutable();
720+
}
721+
template<>
722+
Value LoadSelect<AffineLoadOp>::ptr(AffineLoadOp op) {
723+
return op.memref();
724+
}
725+
template<>
726+
MutableOperandRange LoadSelect<AffineLoadOp>::ptrMutable(AffineLoadOp op) {
727+
return op.memrefMutable();
728+
}
729+
template<>
730+
Value LoadSelect<LLVM::LoadOp>::ptr(LLVM::LoadOp op) {
731+
return op.getAddr();
732+
}
733+
template<>
734+
MutableOperandRange LoadSelect<LLVM::LoadOp>::ptrMutable(LLVM::LoadOp op) {
735+
return op.getAddrMutable();
736+
}
737+
678738
void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
679739
MLIRContext *context) {
680740
results.insert<CastOfSubIndex, SubIndex2, SubToCast, SimplifySubViewUsers,
681741
SimplifySubIndexUsers, SelectOfCast, SelectOfSubIndex,
682-
RedundantDynSubIndex>(context);
742+
RedundantDynSubIndex, LoadSelect<memref::LoadOp>,
743+
LoadSelect<AffineLoadOp>, LoadSelect<LLVM::LoadOp>
744+
>(context);
683745
// Disabled: SubToSubView
684746
}
685747

lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ struct Pointer2MemrefOpLowering
157157
auto result = getStridesAndOffset(op.getType(), strides, offset);
158158
(void)result;
159159
assert(succeeded(result) && "unexpected failure in stride computation");
160-
assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
160+
assert(offset != ShapedType::kDynamicStrideOrOffset &&
161161
"expected static offset");
162162

163163
bool first = true;
@@ -166,7 +166,7 @@ struct Pointer2MemrefOpLowering
166166
first = false;
167167
return false;
168168
}
169-
return MemRefType::isDynamicStrideOrOffset(stride);
169+
return stride == ShapedType::kDynamicStrideOrOffset;
170170
}) && "expected static strides except first element");
171171

172172
descr.setAllocatedPtr(rewriter, loc, ptr);

lib/polygeist/Passes/ParallelLower.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ mlir::LLVM::LLVMFuncOp GetOrCreateMallocFunction(ModuleOp module) {
148148
mlir::OpBuilder builder(module.getContext());
149149
SymbolTableCollection symbolTable;
150150
if (auto fn = dyn_cast_or_null<LLVM::LLVMFuncOp>(
151-
symbolTable.lookupSymbolIn(module, builder.getIdentifier("malloc"))))
151+
symbolTable.lookupSymbolIn(module, builder.getStringAttr("malloc"))))
152152
return fn;
153153
auto ctx = module->getContext();
154154
mlir::Type types[] = {mlir::IntegerType::get(ctx, 64)};
@@ -164,7 +164,7 @@ mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module) {
164164
mlir::OpBuilder builder(module.getContext());
165165
SymbolTableCollection symbolTable;
166166
if (auto fn = dyn_cast_or_null<LLVM::LLVMFuncOp>(
167-
symbolTable.lookupSymbolIn(module, builder.getIdentifier("free"))))
167+
symbolTable.lookupSymbolIn(module, builder.getStringAttr("free"))))
168168
return fn;
169169
auto ctx = module->getContext();
170170
auto llvmFnType = LLVM::LLVMFunctionType::get(

llvm-project

Submodule llvm-project updated 2004 files

tools/mlir-clang/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ set( LLVM_LINK_COMPONENTS
1717
Vectorize
1818
)
1919

20-
add_clang_tool(mlir-clang
20+
add_clang_executable(mlir-clang
2121
mlir-clang.cc
2222
Lib/CGStmt.cc
2323
"${LLVM_SOURCE_DIR}/../clang/tools/driver/cc1_main.cpp"

tools/mlir-clang/Lib/clang-mlir.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,10 +1126,11 @@ ValueCategory MLIRScanner::VisitCXXNewExpr(clang::CXXNewExpr *expr) {
11261126
auto ty = getMLIRType(expr->getType());
11271127

11281128
mlir::Value alloc;
1129+
mlir::Value arrayCons;
11291130
if (auto mt = ty.dyn_cast<mlir::MemRefType>()) {
11301131
auto shape = std::vector<int64_t>(mt.getShape());
11311132
mlir::Value args[1] = {count};
1132-
alloc = builder.create<mlir::memref::AllocOp>(loc, mt, args);
1133+
arrayCons = alloc = builder.create<mlir::memref::AllocOp>(loc, mt, args);
11331134
} else {
11341135
auto i64 = mlir::IntegerType::get(count.getContext(), 64);
11351136
auto typeSize = getTypeSize(expr->getAllocatedType());
@@ -1141,13 +1142,14 @@ ValueCategory MLIRScanner::VisitCXXNewExpr(clang::CXXNewExpr *expr) {
11411142
.create<mlir::LLVM::CallOp>(loc, Glob.GetOrCreateMallocFunction(),
11421143
args)
11431144
->getResult(0));
1145+
arrayCons = builder.create<mlir::LLVM::BitcastOp>(loc, LLVM::LLVMArrayType::get(ty, 0), alloc);
11441146
}
11451147
assert(alloc);
11461148

11471149
if (expr->getConstructExpr()) {
11481150
VisitConstructCommon(
11491151
const_cast<CXXConstructExpr *>(expr->getConstructExpr()),
1150-
/*name*/ nullptr, /*memtype*/ 0, alloc, count);
1152+
/*name*/ nullptr, /*memtype*/ 0, arrayCons, count);
11511153
}
11521154
return ValueCategory(alloc, /*isRefererence*/ false);
11531155
}
@@ -1305,7 +1307,7 @@ ValueCategory MLIRScanner::VisitConstructCommon(clang::CXXConstructExpr *cons,
13051307

13061308
builder.setInsertionPointToStart(&forOp.getLoopBody().front());
13071309
assert(obj.isReference);
1308-
obj.isReference = false;
1310+
obj = CommonArrayToPointer(obj);
13091311
obj = CommonArrayLookup(obj, forOp.getInductionVar(),
13101312
/*isImplicitRef*/ false, /*removeIndex*/ false);
13111313
assert(obj.isReference);

0 commit comments

Comments
 (0)