From 0b97da9779400937a4d0ade034ae6147077008cc Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 6 Jun 2024 08:41:23 +0000 Subject: [PATCH 01/77] Unfinished changes with prototype function --- lib/polygeist/Passes/RaiseToLinalg.cpp | 35 +- test/polygeist-opt/linalgraise.mlir | 822 +++++++++++++------------ 2 files changed, 446 insertions(+), 411 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 254d3a11881b..32af67d0b397 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -23,6 +23,7 @@ using namespace mlir; using namespace mlir::arith; using namespace polygeist; using namespace affine; +using namespace linalg; namespace { struct RaiseAffineToLinalg : public AffineRaiseToLinalgBase { @@ -111,6 +112,7 @@ bool isLinearInIndex(AffineMap map, size_t idx) { std::pair remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, Value val, Value idx, Value idx_size, int loopLowerBound, int loopStepSize, mlir::OperandRange vals) { // First we need to remove any dependence on the loop index from the affine map SmallVector vals_without_idx; + //This tracks the index corresponding to the for loop if present in load/store operands else it's -1 ssize_t dim_idx = -1; //To check if induction variable of for loop in an operand of this op (load/store) for (auto &&[i, v] : llvm::enumerate(vals)) { @@ -207,6 +209,7 @@ struct AffineForOpRaising : public OpRewritePattern { SmallVector, AffineLoadOp>> loads; SmallVector, AffineStoreOp>> stores; + SmallVector, GenericOp>> linalgGenerics; // TODO Also collect all the linalg generics! // Check that the only operations within the region are either: @@ -220,7 +223,7 @@ struct AffineForOpRaising : public OpRewritePattern { if (isa(op)) { return WalkResult::advance(); } - if (isa(op)) { + if (isa(op) || isa(op)) { Operation *cur = op->getParentOp(); std::vector conditions; while (cur != loop) { @@ -232,7 +235,10 @@ struct AffineForOpRaising : public OpRewritePattern { conditions.emplace_back(ifTrue, ifstmt); cur = ifstmt->getParentOp(); } - if (auto load = dyn_cast(op)) { + if (auto linalgGeneric = dyn_cast(op)) { + linalgGenerics.emplace_back(conditions, linalgGeneric); + } + else if (auto load = dyn_cast(op)) { loads.emplace_back(conditions, load); } else { auto store = cast(op); @@ -240,6 +246,7 @@ struct AffineForOpRaising : public OpRewritePattern { } return WalkResult::advance(); } + //IsReadNone takes care of apply and subview too? if (isReadNone(op)) { return WalkResult::advance(); } @@ -261,6 +268,9 @@ struct AffineForOpRaising : public OpRewritePattern { if (load.getMemref() == store.getMemref() && load.getAffineMap() == store.getAffineMap() && load.getIndices() == store.getIndices() && DI.dominates((Operation*)load,(Operation*)store)) { + //Example case where load does not dominate stores - if the load was conditional. + //Or, store followed by load? + //Q. Can't we still overlook the aliasing? stores_map[load] = store; continue; } @@ -331,6 +341,24 @@ struct AffineForOpRaising : public OpRewritePattern { //Value loopSize = rewriter.create(loop.getLoc(), loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), *ub, *lb); + for (auto &&[conds, lg] : linalgGenerics) { + for(auto &x : lg.args?) + //Is this needed? 
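+      // (Presumably yes: mirroring the unconditional-load restriction further
+      // down, a linalg.generic that only executes under an enclosing affine.if
+      // guard cannot simply be fused into the surrounding loop, so bail out
+      // whenever any conditions were collected for it.)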
+ if (conds.size() != 0) return failure(); + + getLinalgArgMap(x, lgMap, lgOperands, lgMemref); + bool legal = true; + + auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), + loopSize, lbConst.getValue(), step, lgOperands); + + if (!legal) return failure(); + + //TODO: need to mergre previous indexing maps and new affine maps + affineMaps.push_back(newAffineMap); + inputs.push_back(newMemref); + } + // current spec is going to be indexed off of the loop var in isolation for (auto &&[conds, load] : loads) { // Only support unconditional loads for the moment @@ -372,7 +400,10 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO Push all of the outputs to the linalg generics // TODO presently if linalg generic exists, assert there are no load/stores + assert((linalgGenerics.size() > 0) ? ((loads.size() == 0 ) && (stores.size() == 0)) : 1); // TODO assert only zero or one linalg generic exists + assert(linalgGenerics.size() == 1 || linalgGenerics.size() == 0); + SmallVector iteratorTypes; // TODO if linalg generic exists, make this iterator type prepend to the existing iterators iteratorTypes.push_back((stores_map.size() == 0) ? utils::IteratorType::parallel : utils::IteratorType::reduction); diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index e0ceffa1849c..27b0a843dddb 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -1,380 +1,409 @@ -// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s +//// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s +// +//module { +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// scf.if %12 { +// affine.for %arg4 = 0 to %17 { +// %ld = affine.load %18[%arg4] : memref +// affine.store %ld, %19[%arg4] : memref +// } +// } +// return +// } +// +// +// func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// scf.if %12 { +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[3 * %arg4] : memref +// %ld2 = affine.load %18[0] : memref +// %fadd = arith.addf %ld, %ld2 : f32 +// affine.store %fadd, %19[%arg4 + 17] : memref +// } +// } +// return +// } +// +//} +// +//// CHECK: #map = affine_map<(d0) -> (d0)> +//// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { +//// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index +//// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index +//// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index +//// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index +//// CHECK-NEXT: scf.if %[[arg0]] { +//// TODO note that presently we do not ensure that the memrefs are sliced to the right size as the space requires +//// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg2 : memref) outs(%alloca : memref) { +//// CHECK-NEXT: ^bb0(%in: f32, %out: f32): +//// 
CHECK-NEXT: linalg.yield %in : f32 +//// CHECK-NEXT: } +//// CHECK-NEXT: } +//// CHECK-NEXT: } +// +////constant-access +//module @constant_access{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %ci324 = arith.constant 4.0 : f32 +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ci324 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////constant-mem-access +//module @constant_mem_access{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 4 to 17 step 2 { +// %ld = affine.load %18[3*%arg4] : memref +// %ld2 = affine.load %18[%c4] : memref +// %mul = arith.mulf %ld, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////without-if +//module @no_if{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// affine.store %ld, %19[%arg4] : memref +// } +// return +// } +//} +// +////arith.mul +//module @arith_mul{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ld : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////arith.add +//module @arith_add{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg4] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////Conditional arith +//module @cond_arith{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %if = scf.if %12 -> f32 { +// %mul = arith.mulf %ld, %ld : f32 +// scf.yield %mul : 
f32 +// } else { +// scf.yield %ld : f32 +// } +// affine.store %if, %19[%arg4] : memref +// } +// return +// } +//} +// +////reduction +//module @reduction{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// %sum_0 = arith.constant 0.0 : f32 +// %red = affine.for %arg4 = 0 to 17 step 1 iter_args(%sum_iter = %sum_0) -> f32 { +// %ld1 = affine.load %18[%arg4] : memref +// %sum_next = arith.addf %sum_iter, %ld1 : f32 +// affine.yield %sum_next : f32 +// } +// affine.store %red, %19[0] : memref +// return +// } +//} +// +////Conditional store-1 +//module @cond_store_1 { +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ld : f32 +// scf.if %12 { +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////Conditional store-2 +//module @cond_store_2{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// scf.if %12 { +// %mul = arith.mulf %ld, %ld : f32 +// affine.store %mul, %19[%arg4] : memref +// } else { +// affine.store %ld, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////Parallel for +//module @parallel_for{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ld : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg4] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////Fors inside for +//module @for_within_for{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +// 
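+//A sketch (not verified output of this pass) of the fully raised form the
+//matmul_1 test below is ultimately aiming for, assuming all three loops fuse
+//into one linalg.generic and @out is zero-initialized beforehand:
+//  linalg.generic {
+//      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+//                       affine_map<(d0, d1, d2) -> (d2, d1)>,
+//                       affine_map<(d0, d1, d2) -> (d0, d1)>],
+//      iterator_types = ["parallel", "parallel", "reduction"]}
+//      ins(%0, %1 : memref<32x8xi32>, memref<8x8xi32>)
+//      outs(%2 : memref<32x8xi32>) {
+//  ^bb0(%a: i32, %b: i32, %c: i32):
+//    %p = arith.muli %a, %b : i32
+//    %s = arith.addi %c, %p : i32
+//    linalg.yield %s : i32
+//  }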
+////Parallel fors inside for +//module @parallel_fors_inside_for { +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 17 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////matrix-mul iter arg +//module @matmul_1 { +// memref.global @out : memref<32x8xi32> = uninitialized +// memref.global @im2 : memref<8x8xi32> = uninitialized +// memref.global @im1 : memref<32x8xi32> = uninitialized +// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im1 : memref<32x8xi32> +// %1 = memref.get_global @im2 : memref<8x8xi32> +// %2 = memref.get_global @out : memref<32x8xi32> +// affine.for %arg0 = 0 to 32 { +// affine.for %arg1 = 0 to 8 { +// %3 = affine.for %arg2 = 0 to 8 iter_args(%arg3 = %c0_i32) -> (i32) { +// %4 = affine.load %0[%arg0, %arg2] : memref<32x8xi32> +// %5 = affine.load %1[%arg2, %arg1] : memref<8x8xi32> +// %6 = arith.muli %4, %5 : i32 +// %7 = arith.addi %arg3, %6 : i32 +// affine.yield %7 : i32 +// } +// affine.store %3, %2[%arg0, %arg1] : memref<32x8xi32> +// } +// } +// return %c0_i32 : i32 +// } +//} +// +////matrix-mul alias issue +//module @matmul_2 { +// memref.global @out : memref<128x32xi32> = uninitialized +// memref.global @im2 : memref<64x32xi32> = uninitialized +// memref.global @im1 : memref<128x64xi32> = uninitialized +// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im1 : memref<128x64xi32> +// %1 = memref.get_global @im2 : memref<64x32xi32> +// %2 = memref.get_global @out : memref<128x32xi32> +// affine.for %arg0 = 0 to 128 { +// affine.for %arg1 = 0 to 32 { +// affine.for %arg2 = 0 to 64 { +// %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> +// %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> +// %5 = arith.muli %3, %4 : i32 +// %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> +// %7 = arith.addi %6, %5 : i32 +// affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> +// } +// } +// } +// return %c0_i32 : i32 +// } +//} +// +////conv (with inner loop accumulate) +////How to deal with IR in outer loops as well? 
+//module @conv_1{ +// memref.global @out : memref<512x64xi32> = uninitialized +// memref.global @filter : memref<4x4xi32> = uninitialized +// memref.global @im : memref<515x67xi32> = uninitialized +// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im : memref<515x67xi32> +// %1 = memref.get_global @filter : memref<4x4xi32> +// %2 = memref.get_global @out : memref<512x64xi32> +// affine.for %arg0 = 0 to 512 { +// affine.for %arg1 = 0 to 64 { +// %3 = affine.for %arg2 = 0 to 4 iter_args(%arg3 = %c0_i32) -> (i32) { +// %4 = affine.for %arg4 = 0 to 4 iter_args(%arg5 = %arg3) -> (i32) { +// %5 = affine.load %0[%arg0 + %arg2, %arg1 + %arg4] : memref<515x67xi32> +// %6 = affine.load %1[%arg2, %arg4] : memref<4x4xi32> +// %7 = arith.muli %5, %6 : i32 +// %8 = arith.addi %arg5, %7 : i32 +// affine.yield %8 : i32 +// } +// affine.yield %4 : i32 +// } +// affine.store %3, %2[%arg0, %arg1] : memref<512x64xi32> +// } +// } +// return %c0_i32 : i32 +// } +//} -module { - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - scf.if %12 { - affine.for %arg4 = 0 to %17 { - %ld = affine.load %18[%arg4] : memref - affine.store %ld, %19[%arg4] : memref - } - } - return - } - - - func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - scf.if %12 { - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[3 * %arg4] : memref - %ld2 = affine.load %18[0] : memref - %fadd = arith.addf %ld, %ld2 : f32 - affine.store %fadd, %19[%arg4 + 17] : memref - } - } - return - } - -} - -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { -// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index -// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index -// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index -// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index -// CHECK-NEXT: scf.if %[[arg0]] { -// TODO note that presently we do not ensure that the memrefs are sliced to the right size as the space requires -// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg2 : memref) outs(%alloca : memref) { -// CHECK-NEXT: ^bb0(%in: f32, %out: f32): -// CHECK-NEXT: linalg.yield %in : f32 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } - -//constant-access -module @constant_access{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %ci324 = arith.constant 4.0 : f32 - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ci324 : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//constant-mem-access -module @constant_mem_access{ - func.func @main(%12 : i1, %14 : i32, 
%18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 4 to 17 step 2 { - %ld = affine.load %18[3*%arg4] : memref - %ld2 = affine.load %18[%c4] : memref - %mul = arith.mulf %ld, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//without-if -module @no_if{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - affine.store %ld, %19[%arg4] : memref - } - return - } -} - -//arith.mul -module @arith_mul{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ld : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//arith.add -module @arith_add{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg4] : memref - %ld2 = affine.load %20[%arg4] : memref - %add = arith.addf %ld1, %ld2 : f32 - %mul = arith.mulf %add, %add : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//Conditional arith -module @cond_arith{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %if = scf.if %12 -> f32 { - %mul = arith.mulf %ld, %ld : f32 - scf.yield %mul : f32 - } else { - scf.yield %ld : f32 - } - affine.store %if, %19[%arg4] : memref - } - return - } -} - -//reduction -module @reduction{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - %sum_0 = arith.constant 0.0 : f32 - %red = affine.for %arg4 = 0 to 17 step 1 iter_args(%sum_iter = %sum_0) -> f32 { - %ld1 = affine.load %18[%arg4] : memref - %sum_next = arith.addf %sum_iter, %ld1 : f32 - affine.yield %sum_next : f32 - } - affine.store %red, %19[0] : memref - return - } -} - -//Conditional store-1 -module @cond_store_1 { - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = 
arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ld : f32 - scf.if %12 { - affine.store %mul, %19[%arg4] : memref - } - } - return - } -} - -//Conditional store-2 -module @cond_store_2{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - scf.if %12 { - %mul = arith.mulf %ld, %ld : f32 - affine.store %mul, %19[%arg4] : memref - } else { - affine.store %ld, %19[%arg4] : memref - } - } - return - } -} - -//Parallel for -module @parallel_for{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ld : f32 - affine.store %mul, %19[%arg4] : memref - } - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg4] : memref - %ld2 = affine.load %20[%arg4] : memref - %add = arith.addf %ld1, %ld2 : f32 - %mul = arith.mulf %add, %add : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//Fors inside for -module @for_within_for{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - } - return - } -} - -//Parallel fors inside for -module @parallel_fors_inside_for { - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 17 { - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %add = arith.addf %ld1, %ld2 : f32 - %mul = arith.mulf %add, %add : f32 - affine.store %mul, %19[%arg4] : memref - } - } - return - } -} - -//matrix-mul iter arg -module @matmul_1 { - memref.global @out : memref<32x8xi32> = uninitialized - memref.global @im2 : memref<8x8xi32> = uninitialized - memref.global @im1 : memref<32x8xi32> = uninitialized - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im1 : 
memref<32x8xi32> - %1 = memref.get_global @im2 : memref<8x8xi32> - %2 = memref.get_global @out : memref<32x8xi32> - affine.for %arg0 = 0 to 32 { - affine.for %arg1 = 0 to 8 { - %3 = affine.for %arg2 = 0 to 8 iter_args(%arg3 = %c0_i32) -> (i32) { - %4 = affine.load %0[%arg0, %arg2] : memref<32x8xi32> - %5 = affine.load %1[%arg2, %arg1] : memref<8x8xi32> - %6 = arith.muli %4, %5 : i32 - %7 = arith.addi %arg3, %6 : i32 - affine.yield %7 : i32 - } - affine.store %3, %2[%arg0, %arg1] : memref<32x8xi32> - } - } - return %c0_i32 : i32 - } -} - -//matrix-mul alias issue -module @matmul_2 { - memref.global @out : memref<128x32xi32> = uninitialized - memref.global @im2 : memref<64x32xi32> = uninitialized - memref.global @im1 : memref<128x64xi32> = uninitialized - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im1 : memref<128x64xi32> - %1 = memref.get_global @im2 : memref<64x32xi32> - %2 = memref.get_global @out : memref<128x32xi32> - affine.for %arg0 = 0 to 128 { - affine.for %arg1 = 0 to 32 { - affine.for %arg2 = 0 to 64 { - %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> - %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> - %5 = arith.muli %3, %4 : i32 - %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> - %7 = arith.addi %6, %5 : i32 - affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> - } - } - } - return %c0_i32 : i32 - } -} - -//conv (with inner loop accumulate) -//How to deal with IR in outer loops as well? -module @conv_1{ +//conv (direct store) +module @conv_2 { memref.global @out : memref<512x64xi32> = uninitialized memref.global @filter : memref<4x4xi32> = uninitialized memref.global @im : memref<515x67xi32> = uninitialized @@ -383,46 +412,21 @@ module @conv_1{ %0 = memref.get_global @im : memref<515x67xi32> %1 = memref.get_global @filter : memref<4x4xi32> %2 = memref.get_global @out : memref<512x64xi32> - affine.for %arg0 = 0 to 512 { - affine.for %arg1 = 0 to 64 { - %3 = affine.for %arg2 = 0 to 4 iter_args(%arg3 = %c0_i32) -> (i32) { - %4 = affine.for %arg4 = 0 to 4 iter_args(%arg5 = %arg3) -> (i32) { - %5 = affine.load %0[%arg0 + %arg2, %arg1 + %arg4] : memref<515x67xi32> - %6 = affine.load %1[%arg2, %arg4] : memref<4x4xi32> - %7 = arith.muli %5, %6 : i32 - %8 = arith.addi %arg5, %7 : i32 - affine.yield %8 : i32 - } - affine.yield %4 : i32 - } - affine.store %3, %2[%arg0, %arg1] : memref<512x64xi32> - } - } - return %c0_i32 : i32 - } -} - -//conv (direct store) -module @conv_2{ - memref.global @out : memref<512x64xi32> = uninitialized - memref.global @filter : memref<4x4xi32> = uninitialized - memref.global @im : memref<515x67xi32> = uninitialized - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im : memref<515x67xi32> - %1 = memref.get_global @out : memref<512x64xi32> affine.for %arg0 = 0 to 512 { affine.for %arg1 = 0 to 64 { affine.for %arg2 = 0 to 4 { affine.for %arg3 = 0 to 4 { - %2 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> - %3 = affine.load %1[%arg0, %arg1] : memref<512x64xi32> - %4 = arith.addi %3, %2 : i32 - affine.store %4, %1[%arg0, %arg1] : memref<512x64xi32> + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> } } } } 
return %c0_i32 : i32 } -} \ No newline at end of file +} + \ No newline at end of file From 69ef423830c4130acbef575ac58bd4e5f3bc67d1 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 6 Jun 2024 15:10:24 +0000 Subject: [PATCH 02/77] Loop over linalg.generic's input and output ops --- lib/polygeist/Passes/RaiseToLinalg.cpp | 43 +++++++++++++++++++------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 32af67d0b397..0d65ddd577af 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -342,21 +342,42 @@ struct AffineForOpRaising : public OpRewritePattern { //Value loopSize = rewriter.create(loop.getLoc(), loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), *ub, *lb); for (auto &&[conds, lg] : linalgGenerics) { - for(auto &x : lg.args?) - //Is this needed? - if (conds.size() != 0) return failure(); + // Iterate over input arguments + for (Value input : lg.getInputs()) { + //Is this needed? + if (conds.size() != 0) return failure(); + + //TODO: Implement this + getLinalgArgMap(inout, lgMap, lgOperands, lgMemref); + bool legal = true; + + auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), + loopSize, lbConst.getValue(), step, lgOperands); - getLinalgArgMap(x, lgMap, lgOperands, lgMemref); - bool legal = true; + if (!legal) return failure(); + + //TODO: need to mergre previous indexing maps and new affine maps + affineMaps.push_back(newAffineMap); + inputs.push_back(newMemref); + } + + // Iterate over output arguments + for (Value output : lg.getOutputs()) { + //Is this needed? + if (conds.size() != 0) return failure(); + + getLinalgArgMap(output, lgMap, lgOperands, lgMemref); + bool legal = true; - auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), - loopSize, lbConst.getValue(), step, lgOperands); + auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), + loopSize, lbConst.getValue(), step, lgOperands); - if (!legal) return failure(); + if (!legal) return failure(); - //TODO: need to mergre previous indexing maps and new affine maps - affineMaps.push_back(newAffineMap); - inputs.push_back(newMemref); + //TODO: need to mergre previous indexing maps and new affine maps + affineMaps.push_back(newAffineMap); + inputs.push_back(newMemref); + } } // current spec is going to be indexed off of the loop var in isolation From 7678a05f5b86e8eda32b476914dc4baf457baea4 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 6 Jun 2024 15:31:57 +0000 Subject: [PATCH 03/77] Some comments --- lib/polygeist/Passes/RaiseToLinalg.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 0d65ddd577af..87f8b0da87b3 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -289,6 +289,7 @@ struct AffineForOpRaising : public OpRewritePattern { SmallVector inputs; SmallVector affineMaps; + SmallVector indexingMaps; //if (loop.getStep() != 1) { // return failure(); @@ -342,12 +343,18 @@ struct AffineForOpRaising : public OpRewritePattern { //Value loopSize = rewriter.create(loop.getLoc(), loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), *ub, *lb); for (auto &&[conds, lg] : linalgGenerics) { + + //This captures the indexing map attribute 
from the linalg.generic being processed + ArrayAttr indexingMapsAttr = lg.getIndexingMaps(); + // Iterate over input arguments for (Value input : lg.getInputs()) { //Is this needed? if (conds.size() != 0) return failure(); //TODO: Implement this + //lgMap comes from offset of memref.subview, + //lgOperands comes from operands of memref.subview getLinalgArgMap(inout, lgMap, lgOperands, lgMemref); bool legal = true; From 0e8809518be31e529051b3b3462ce335b0f14f3b Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 11 Jun 2024 22:58:58 +0000 Subject: [PATCH 04/77] Partial changes from coding session to implement fusion of linalg.generic and for op --- lib/polygeist/Passes/RaiseToLinalg.cpp | 124 ++++++++++++++++++++++++- 1 file changed, 120 insertions(+), 4 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 87f8b0da87b3..1a539299ac02 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -194,6 +194,106 @@ std::pair remap_in_affine_dim(bool &legal, OpBuilder &builder, val is now prevA */ +/* + +f(%memref ) + +%memref = ... + +affine.for { + + + %inp = .. subview %memref [ ... ] + + linalg.generic %inp #map { + + } +} + + +#map2 = #map with the indexing done to %inp + +*/ + + +LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, lgOperands, lgMemref) { + + while (Operation *defOp = input.getDefiningOp()) { + + // If the input is defined outside of the loop, we are finished. + if (!loop->isAncestor(defOp)) continue; + + if (auto SV = dyn_cast(defOp)) { + + // TODO update map with the new indexing from here + + size_t numNewStartDims = 0; + size_t numNewStartSymbols = 0; + for (auto val : SV->getStarts()) { + // Only support constants, symbols, or affine apply as offsets + if (val.getDefinigOp()) { + continue; + } + auto valOp = val.getDefiningOp(); + // Defined outside loop, consider it a symbol [for now] + if (!valOp || !loop->isAncestor(valOp)) continue; + + if(auto index = dyn_cast<>(valOp)) { + + } + + //Q. If we just extract num dims and symbs- + // i. Won't we miss constant values in the affine map? + // ii. How will we know the relation between dims and syms? + // Eg- affine_map<(d0, d1)[s0] -> (d0 + 2 * d1 + s0, d1 - s0)> + //Also we need to check for unique args and only count them in numNewStartDims and Symbols. 
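+          // Illustration (an assumed example, not taken from a real test): if
+          // the generic indexes its input with (d0) -> (d0) and that input is
+          //   memref.subview %A[%off] [%n] [4], %off = affine.apply (d0) -> (2 * d0)(%iv),
+          // then indexing %A directly needs the composed map
+          //   (d0, d1) -> (2 * d1 + 4 * d0)
+          // with d1 a fresh dim standing for %iv, i.e. the dims/symbols of the
+          // apply's map must be shifted past those already collected here.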
+ if (auto apply = dyn_cast(valOp)) { + numNewStartDims += apply.getAffineMap().getNumDims(); + numNewStartSymbols += apply.getAffineMap().getNumSymbols(); + newExpr = apply.getResults(); + } + + // unsupported index to subview + return failure(); + } + size_t numNewStrideDims = 0; + size_t numNewStrideSymbols = 0; + for (auto val : SV->getStrides()) { + // Only support constants, symbols, or affine apply as offsets + if (val.getDefinigOp()) { + continue; + } + auto valOp = val.getDefiningOp(); + // Defined outside loop, consider it a symbol [for now] + if (!valOp || loop->isAncestor(defOp)) continue; + + if (auto apply = dyn_cast(val)) { + numNewStrideDims += apply.getAffineMap().getNumDims(); + numNewStrideSymbols += apply.getAffineMap().getNumSymbols(); + continue; + } + + // unsupported index to subview + return failure(); + } + + SmallVector exprs = lgMap.getAffineExprs(); + + for (auto expr : exprs) { + auto newexpr = expr.compose with the start and index above + and also take into account new dims/symbols + } + + lgMap = AffineMap::get(exprs, num total new dims, num total new symbols); + input = SV.getInput(); + + } + + return failure(); + + } + return success(); +} struct AffineForOpRaising : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -347,6 +447,7 @@ struct AffineForOpRaising : public OpRewritePattern { //This captures the indexing map attribute from the linalg.generic being processed ArrayAttr indexingMapsAttr = lg.getIndexingMaps(); + int idx = 0; // Iterate over input arguments for (Value input : lg.getInputs()) { //Is this needed? @@ -355,7 +456,12 @@ struct AffineForOpRaising : public OpRewritePattern { //TODO: Implement this //lgMap comes from offset of memref.subview, //lgOperands comes from operands of memref.subview - getLinalgArgMap(inout, lgMap, lgOperands, lgMemref); + AffineMap lgMap = indexingMapsAttr[idx]; + + auto result = getLinalgArgMap(loop, input, lgMap, lgOperands, lgMemref); + + if (!result.succeeded()) return failure(); + bool legal = true; auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), @@ -366,6 +472,7 @@ struct AffineForOpRaising : public OpRewritePattern { //TODO: need to mergre previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); + idx++; } // Iterate over output arguments @@ -373,7 +480,12 @@ struct AffineForOpRaising : public OpRewritePattern { //Is this needed? if (conds.size() != 0) return failure(); - getLinalgArgMap(output, lgMap, lgOperands, lgMemref); + AffineMap lgMap = indexingMapsAttr[idx]; + + auto result = getLinalgArgMap(loop, output, lgMap, lgOperands, lgMemref); + + if (!result.succeeded()) return failure(); + bool legal = true; auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), @@ -428,9 +540,13 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO Push all of the outputs to the linalg generics // TODO presently if linalg generic exists, assert there are no load/stores - assert((linalgGenerics.size() > 0) ? 
((loads.size() == 0 ) && (stores.size() == 0)) : 1); + if(!((linalgGenerics.size() > 0) && ((loads.size() == 0 ) && (stores.size() == 0)))) + return failure; + // TODO assert only zero or one linalg generic exists - assert(linalgGenerics.size() == 1 || linalgGenerics.size() == 0); + if(!(linalgGenerics.size() == 1 || linalgGenerics.size() == 0)) + return failure; + SmallVector iteratorTypes; // TODO if linalg generic exists, make this iterator type prepend to the existing iterators From b57c0b86d5174e3a277f611bca4777ebcdecc349 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 19 Jun 2024 00:10:35 +0000 Subject: [PATCH 05/77] Incremental changes to fuse linalg and for loop- Logic for shifted operands and map for linalg.generic --- lib/polygeist/Passes/RaiseToLinalg.cpp | 128 +++++++++++++++---------- 1 file changed, 80 insertions(+), 48 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 1a539299ac02..402432a07c00 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -215,9 +215,13 @@ affine.for { */ - -LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, lgOperands, lgMemref) { - +// Suppose we have a memref expression E=input[affine.map(operands)] +// if input = memref.subview A[starts, offsets] +// can we rewrite E as A[affine.map2(operands2)] +// We update lgMap and lgOperands in place with this coresponding map2 and operands2 +LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, SmallVector &lgOperands) { + IRBuilder builder(loop->getContext()); + while (Operation *defOp = input.getDefiningOp()) { // If the input is defined outside of the loop, we are finished. @@ -226,67 +230,90 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, l if (auto SV = dyn_cast(defOp)) { // TODO update map with the new indexing from here + + //Create affine map + // i. Track number of running dims and symbols + // ii. shift dims and symbols to generate shifted expressions. + //Extract corresponding operands + //Use affineMap::get with numOperands and numSymbols along with shifted expressions to get a map. + //Use affine map simplify to simplify this + + SmallVector startExprs; + SmallVector strideExprs; + SmallVector dimOperands; + SmallVector symOperands; + for (auto en : llvm::enumerate({SV.getOffsets(), SV.getStrides()})) { + auto &exprOutput = (en.index() == 0) ? startExprs : strideExprs; + for (auto expr : en.value()) { - size_t numNewStartDims = 0; - size_t numNewStartSymbols = 0; - for (auto val : SV->getStarts()) { // Only support constants, symbols, or affine apply as offsets - if (val.getDefinigOp()) { + if (auto cop = val.getDefiningOp()) { + exprOutput.push_back(builder.getAffineConstantExpr(cop.getValue())); continue; } + + if (auto ba = dyn_cast(val)) + if(isa(ba->getParentOp())) { + exprOutput.push_back(builder.getAffineDimExpr(dimOperands.size())); + dimOperands.push_back(ba); + continue; + } + auto valOp = val.getDefiningOp(); // Defined outside loop, consider it a symbol [for now] - if (!valOp || !loop->isAncestor(valOp)) continue; - - if(auto index = dyn_cast<>(valOp)) { - - } - - //Q. If we just extract num dims and symbs- - // i. Won't we miss constant values in the affine map? - // ii. How will we know the relation between dims and syms? - // Eg- affine_map<(d0, d1)[s0] -> (d0 + 2 * d1 + s0, d1 - s0)> - //Also we need to check for unique args and only count them in numNewStartDims and Symbols. 
- if (auto apply = dyn_cast(valOp)) { - numNewStartDims += apply.getAffineMap().getNumDims(); - numNewStartSymbols += apply.getAffineMap().getNumSymbols(); - newExpr = apply.getResults(); - } - - // unsupported index to subview - return failure(); - } - size_t numNewStrideDims = 0; - size_t numNewStrideSymbols = 0; - for (auto val : SV->getStrides()) { - // Only support constants, symbols, or affine apply as offsets - if (val.getDefinigOp()) { + if (!valOp || loop->isAncestor(defOp)) { + exprOutput.push_back(builder.getAffineSymbolExpr(symOperands.size())); + symOperands.push_back(ba); continue; } - auto valOp = val.getDefiningOp(); - // Defined outside loop, consider it a symbol [for now] - if (!valOp || loop->isAncestor(defOp)) continue; if (auto apply = dyn_cast(val)) { - numNewStrideDims += apply.getAffineMap().getNumDims(); - numNewStrideSymbols += apply.getAffineMap().getNumSymbols(); + auto map = apply.getAffineMap(); + auto newexpr = map. + .shiftDims(dimOperands.size()) + .shiftSymbols(symOperands.size()); + + for (auto expr : newexpr.getResults()) { + exprOutput.push_back(newexpr); + } + + for (size_t i=0; i inputExprs; + for (auto expr : lgMap. + .shiftDims(dimOperands.size()) + .shiftSymbols(symOperands.size()); + getResults()) { + inputExprs.push_back(newexpr); + } + for (size_t i=0; i exprs = lgMap.getAffineExprs(); - for (auto expr : exprs) { - auto newexpr = expr.compose with the start and index above - and also take into account new dims/symbols + SmallVector mergedExprs; + for (auto [start, stride, idx]&& : llvm::zip(startExprs, strideExprs, inputExprs)) { + mergedExprs.push_back(startExprs + idx * strideExpr); } - lgMap = AffineMap::get(exprs, num total new dims, num total new symbols); + lgMap = AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs); + lgOperands.clear(); + lgOperands.append(dimOperands()); + lgOperands.append(symOperands()); input = SV.getInput(); - } return failure(); @@ -457,8 +484,10 @@ struct AffineForOpRaising : public OpRewritePattern { //lgMap comes from offset of memref.subview, //lgOperands comes from operands of memref.subview AffineMap lgMap = indexingMapsAttr[idx]; - - auto result = getLinalgArgMap(loop, input, lgMap, lgOperands, lgMemref); + SmallVector lgOperands; + for (auto i=0; i { if (conds.size() != 0) return failure(); AffineMap lgMap = indexingMapsAttr[idx]; + SmallVector lgOperands; + for (auto i=0; i { if (!legal) return failure(); - //TODO: need to mergre previous indexing maps and new affine maps + //TODO: need to merge previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); } From f54c33d318ade1f7a763b1731a64e11f9c36a3d9 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 25 Jun 2024 23:26:33 +0000 Subject: [PATCH 06/77] ran clang format --- lib/polygeist/Passes/RaiseToLinalg.cpp | 917 ++++++++++++++----------- 1 file changed, 497 insertions(+), 420 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 402432a07c00..983c55218ac0 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -1,13 +1,14 @@ #include "PassDetails.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include 
"mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" @@ -15,7 +16,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "polygeist/Passes/Passes.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/AffineExpr.h" #define DEBUG_TYPE "raise-to-linalg" @@ -26,170 +26,206 @@ using namespace affine; using namespace linalg; namespace { -struct RaiseAffineToLinalg : public AffineRaiseToLinalgBase { +struct RaiseAffineToLinalg + : public AffineRaiseToLinalgBase { void runOnOperation() override; }; } // namespace -// Also want to add support for affine.for ( ) { linalg.generic } -> bigger linalg.generic -// Also probably want to try to do { linalg.generc1(); linalg.generic2(); } -> bigger linalg.generic() +// Also want to add support for affine.for ( ) { linalg.generic } -> bigger +// linalg.generic Also probably want to try to do { linalg.generc1(); +// linalg.generic2(); } -> bigger linalg.generic() /* affine.for() { affine.for() { - } + } affine.for() { } } */ struct Condition { - bool ifTrue; - AffineIfOp op; - Condition(bool ifTrue, AffineIfOp op) : ifTrue(ifTrue), op(op) {} + bool ifTrue; + AffineIfOp op; + Condition(bool ifTrue, AffineIfOp op) : ifTrue(ifTrue), op(op) {} }; bool isLinearInIndex(AffineExpr expr, size_t idx) { - if (!expr.isFunctionOfDim(idx)) { - return true; - } + if (!expr.isFunctionOfDim(idx)) { + return true; + } - if (expr.getKind() == AffineExprKind::DimId) { - return true; - } + if (expr.getKind() == AffineExprKind::DimId) { + return true; + } - if (expr.getKind() == AffineExprKind::Add) { - auto binop = expr.cast(); - return isLinearInIndex(binop.getLHS(), idx) && isLinearInIndex(binop.getRHS(), idx); - } - if (expr.getKind() == AffineExprKind::Mul) { - auto binop = expr.cast(); - return (isLinearInIndex(binop.getLHS(), idx) && !binop.getRHS().isFunctionOfDim(idx)) || - (isLinearInIndex(binop.getRHS(), idx) && !binop.getLHS().isFunctionOfDim(idx)); - } + if (expr.getKind() == AffineExprKind::Add) { + auto binop = expr.cast(); + return isLinearInIndex(binop.getLHS(), idx) && + isLinearInIndex(binop.getRHS(), idx); + } + if (expr.getKind() == AffineExprKind::Mul) { + auto binop = expr.cast(); + return (isLinearInIndex(binop.getLHS(), idx) && + !binop.getRHS().isFunctionOfDim(idx)) || + (isLinearInIndex(binop.getRHS(), idx) && + !binop.getLHS().isFunctionOfDim(idx)); + } - return false; + return false; } bool isLinearInIndex(AffineMap map, size_t idx) { - for (auto expr : map.getResults()) { - if (!isLinearInIndex(expr, idx)) - return false; - } - return true; + for (auto expr : map.getResults()) { + if (!isLinearInIndex(expr, idx)) + return false; + } + return true; } - AffineExpr shiftDimsDown1(AffineExpr expr, unsigned numDims, - unsigned offset) { - SmallVector dims; - for (unsigned idx = 0; idx < offset; ++idx) - dims.push_back(getAffineDimExpr(idx, expr.getContext())); - for (unsigned idx = offset; idx < numDims; ++idx) - dims.push_back(getAffineDimExpr(idx - 1, expr.getContext())); - return expr.replaceDimsAndSymbols(dims, {}); - } - -//This is reducing the number of input dims in expression by 1 - AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, - unsigned offset) { - assert(offset <= expr.getNumDims()); - return AffineMap::get(expr.getNumDims() - 1, expr.getNumSymbols(), - llvm::map_to_vector<4>( - expr.getResults(), - [&](AffineExpr e) { - return 
shiftDimsDown1(e, expr.getNumDims(), offset); - }), - expr.getContext()); - } - -// Given an affine map `oldmap`, memref `val`, and corresponding input values (which are a list of indicies, then symbols), -// and a loop index `ind` produce the following: -// 1. A (potentially new) memref value `newval` which does not have any dependence on `ind` -// and -// 2. an affine map `newmap` which takes a single index (`ind`) and produces indices into `newval` such that -// indexing `newval[map(ind)]` produces the same result as indexing the original map. -std::pair remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, Value val, Value idx, Value idx_size, int loopLowerBound, int loopStepSize, mlir::OperandRange vals) { - // First we need to remove any dependence on the loop index from the affine map - SmallVector vals_without_idx; - //This tracks the index corresponding to the for loop if present in load/store operands else it's -1 - ssize_t dim_idx = -1; - //To check if induction variable of for loop in an operand of this op (load/store) - for (auto &&[i, v] : llvm::enumerate(vals)) { - if (v == idx) { - // Offset we're replacing must be an index (not a symbol). - // If we guarantee to run AffineCFG first, this should always be true. - assert(i < oldmap.getNumDims()); - // There should only be one use of the index. - assert(dim_idx == -1); - dim_idx = i; - continue; - } - vals_without_idx.push_back(v); - } - - if (dim_idx != -1 && !isLinearInIndex(oldmap, dim_idx)) { - legal = false; - return {val, oldmap}; - } - +AffineExpr shiftDimsDown1(AffineExpr expr, unsigned numDims, unsigned offset) { + SmallVector dims; + for (unsigned idx = 0; idx < offset; ++idx) + dims.push_back(getAffineDimExpr(idx, expr.getContext())); + for (unsigned idx = offset; idx < numDims; ++idx) + dims.push_back(getAffineDimExpr(idx - 1, expr.getContext())); + return expr.replaceDimsAndSymbols(dims, {}); +} - // Evaluate offsets as oldmap replacing idx with 0, and evaluating at the remaining variables +// This is reducing the number of input dims in expression by 1 +AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { + assert(offset <= expr.getNumDims()); + return AffineMap::get(expr.getNumDims() - 1, expr.getNumSymbols(), + llvm::map_to_vector<4>(expr.getResults(), + [&](AffineExpr e) { + return shiftDimsDown1( + e, expr.getNumDims(), + offset); + }), + expr.getContext()); +} - //Instead of lower bound we are using 0 (assumption as the lower bound) - AffineMap offsetMap = oldmap; - if (dim_idx != -1) { - offsetMap = oldmap.replace(builder.getAffineDimExpr(dim_idx), builder.getAffineConstantExpr(loopLowerBound),offsetMap.getNumDims(), offsetMap.getNumSymbols()); - offsetMap = shiftDimsDown1(offsetMap, oldmap.getNumDims(), dim_idx); +// Given an affine map `oldmap`, memref `val`, and corresponding input values +// (which are a list of indicies, then symbols), and a loop index `ind` produce +// the following: +// 1. A (potentially new) memref value `newval` which does not have any +// dependence on `ind` +// and +// 2. an affine map `newmap` which takes a single index (`ind`) and produces +// indices into `newval` such that +// indexing `newval[map(ind)]` produces the same result as indexing the +// original map. 
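// A worked illustration (a sketch based on the @main2 test shown earlier in
// this patch, not verified pass output): for the load %18[3 * %arg4] inside
// `affine.for %arg4 = 0 to 17`, oldmap is (d0) -> (d0 * 3) with d0 = %arg4.
// Substituting the lower bound 0 gives offset 0, substituting 0 + step and
// subtracting gives stride 3, and the trip count 17 becomes the extent, so
// `newval` is roughly `memref.subview %18[0] [17] [3]` and `newmap` collapses
// to (d0) -> (d0).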
+std::pair +remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, + Value val, Value idx, Value idx_size, int loopLowerBound, + int loopStepSize, mlir::OperandRange vals) { + // First we need to remove any dependence on the loop index from the affine + // map + SmallVector vals_without_idx; + // This tracks the index corresponding to the for loop if present in + // load/store operands else it's -1 + ssize_t dim_idx = -1; + // To check if induction variable of for loop in an operand of this op + // (load/store) + for (auto &&[i, v] : llvm::enumerate(vals)) { + if (v == idx) { + // Offset we're replacing must be an index (not a symbol). + // If we guarantee to run AffineCFG first, this should always be true. + assert(i < oldmap.getNumDims()); + // There should only be one use of the index. + assert(dim_idx == -1); + dim_idx = i; + continue; } + vals_without_idx.push_back(v); + } - //Instead of using loop step we are using 1 (Assumption as the stride size) - AffineMap strideMap = oldmap; - if (dim_idx != -1) { - strideMap = oldmap.replace(builder.getAffineDimExpr(dim_idx), builder.getAffineConstantExpr(loopLowerBound + loopStepSize),strideMap.getNumDims(), strideMap.getNumSymbols()); - strideMap = shiftDimsDown1(strideMap, oldmap.getNumDims(), dim_idx); - } + if (dim_idx != -1 && !isLinearInIndex(oldmap, dim_idx)) { + legal = false; + return {val, oldmap}; + } - //Subtracting maps of stride and offset, gives you the offset value in the result of the map - { - SmallVector subtracts; - for (auto &&[lhs, rhs] : llvm::zip(strideMap.getResults(), offsetMap.getResults())) { - subtracts.push_back(lhs - rhs); - } - strideMap = AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), subtracts, builder.getContext()); - } + // Evaluate offsets as oldmap replacing idx with 0, and evaluating at the + // remaining variables + + // Instead of lower bound we are using 0 (assumption as the lower bound) + AffineMap offsetMap = oldmap; + if (dim_idx != -1) { + offsetMap = + oldmap.replace(builder.getAffineDimExpr(dim_idx), + builder.getAffineConstantExpr(loopLowerBound), + offsetMap.getNumDims(), offsetMap.getNumSymbols()); + offsetMap = shiftDimsDown1(offsetMap, oldmap.getNumDims(), dim_idx); + } - // Expression to index into the generated subview given the loop index - SmallVector loop_idxs; + // Instead of using loop step we are using 1 (Assumption as the stride size) + AffineMap strideMap = oldmap; + if (dim_idx != -1) { + strideMap = oldmap.replace( + builder.getAffineDimExpr(dim_idx), + builder.getAffineConstantExpr(loopLowerBound + loopStepSize), + strideMap.getNumDims(), strideMap.getNumSymbols()); + strideMap = shiftDimsDown1(strideMap, oldmap.getNumDims(), dim_idx); + } - // List of starting offsets into the subview - SmallVector offsets; - SmallVector sizes; - SmallVector strides; + // Subtracting maps of stride and offset, gives you the offset value in the + // result of the map + { + SmallVector subtracts; + for (auto &&[lhs, rhs] : + llvm::zip(strideMap.getResults(), offsetMap.getResults())) { + subtracts.push_back(lhs - rhs); + } + strideMap = + AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), + subtracts, builder.getContext()); + } - for (auto &&[expr, offset_expr, stride_expr] : llvm::zip(oldmap.getResults(), offsetMap.getResults(),strideMap.getResults() )) { - offsets.push_back(builder.create(val.getLoc(),AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), offset_expr, builder.getContext()), vals_without_idx)); //What is there are 
symbols in the expression? - strides.push_back(builder.create(val.getLoc(),AffineMap::get(strideMap.getNumDims(), strideMap.getNumSymbols(), stride_expr, builder.getContext()), vals_without_idx)); //What is there are symbols in the expression? - if (!expr.isFunctionOfDim(dim_idx)) { - loop_idxs.push_back(builder.getAffineConstantExpr(0)); - sizes.push_back(builder.create(val.getLoc(), 1)); - } else { - loop_idxs.push_back(builder.getAffineDimExpr(0)); - sizes.push_back(idx_size); - } + // Expression to index into the generated subview given the loop index + SmallVector loop_idxs; + + // List of starting offsets into the subview + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + + for (auto &&[expr, offset_expr, stride_expr] : + llvm::zip(oldmap.getResults(), offsetMap.getResults(), + strideMap.getResults())) { + offsets.push_back(builder.create( + val.getLoc(), + AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), + offset_expr, builder.getContext()), + vals_without_idx)); // What is there are symbols in the expression? + strides.push_back(builder.create( + val.getLoc(), + AffineMap::get(strideMap.getNumDims(), strideMap.getNumSymbols(), + stride_expr, builder.getContext()), + vals_without_idx)); // What is there are symbols in the expression? + if (!expr.isFunctionOfDim(dim_idx)) { + loop_idxs.push_back(builder.getAffineConstantExpr(0)); + sizes.push_back(builder.create(val.getLoc(), 1)); + } else { + loop_idxs.push_back(builder.getAffineDimExpr(0)); + sizes.push_back(idx_size); } + } - auto newval = builder.create(val.getLoc(), val, offsets, sizes, strides); - legal = true; - //Does this need fix? Here we are constraining to dims as 1 and symbols as 0, should it be, original - return {newval, AffineMap::get(/*dims*/1, /*symbols*/0, loop_idxs, builder.getContext())}; + auto newval = builder.create(val.getLoc(), val, offsets, + sizes, strides); + legal = true; + // Does this need fix? Here we are constraining to dims as 1 and symbols as 0, + // should it be, original + return {newval, AffineMap::get(/*dims*/ 1, /*symbols*/ 0, loop_idxs, + builder.getContext())}; } - // store A[...] // val = load A[...] -/* prevA : +/* prevA : store A val is now prevA */ @@ -215,111 +251,114 @@ affine.for { */ -// Suppose we have a memref expression E=input[affine.map(operands)] +// Suppose we have a memref expression E=input[affine.map(operands)] // if input = memref.subview A[starts, offsets] // can we rewrite E as A[affine.map2(operands2)] -// We update lgMap and lgOperands in place with this coresponding map2 and operands2 -LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, SmallVector &lgOperands) { - IRBuilder builder(loop->getContext()); - - while (Operation *defOp = input.getDefiningOp()) { - - // If the input is defined outside of the loop, we are finished. - if (!loop->isAncestor(defOp)) continue; - - if (auto SV = dyn_cast(defOp)) { - - // TODO update map with the new indexing from here - - //Create affine map - // i. Track number of running dims and symbols - // ii. shift dims and symbols to generate shifted expressions. - //Extract corresponding operands - //Use affineMap::get with numOperands and numSymbols along with shifted expressions to get a map. - //Use affine map simplify to simplify this - - SmallVector startExprs; - SmallVector strideExprs; - SmallVector dimOperands; - SmallVector symOperands; - for (auto en : llvm::enumerate({SV.getOffsets(), SV.getStrides()})) { - auto &exprOutput = (en.index() == 0) ? 
startExprs : strideExprs; - for (auto expr : en.value()) { - - // Only support constants, symbols, or affine apply as offsets - if (auto cop = val.getDefiningOp()) { - exprOutput.push_back(builder.getAffineConstantExpr(cop.getValue())); - continue; - } - - if (auto ba = dyn_cast(val)) - if(isa(ba->getParentOp())) { - exprOutput.push_back(builder.getAffineDimExpr(dimOperands.size())); - dimOperands.push_back(ba); - continue; - } - - auto valOp = val.getDefiningOp(); - // Defined outside loop, consider it a symbol [for now] - if (!valOp || loop->isAncestor(defOp)) { - exprOutput.push_back(builder.getAffineSymbolExpr(symOperands.size())); - symOperands.push_back(ba); - continue; - } - - if (auto apply = dyn_cast(val)) { - auto map = apply.getAffineMap(); - auto newexpr = map. - .shiftDims(dimOperands.size()) - .shiftSymbols(symOperands.size()); - - for (auto expr : newexpr.getResults()) { - exprOutput.push_back(newexpr); - } - - for (size_t i=0; i &lgOperands) { + OpBuilder builder(loop->getContext()); + + while (Operation *defOp = input.getDefiningOp()) { + + // If the input is defined outside of the loop, we are finished. + if (!loop->isAncestor(defOp)) + continue; + + if (auto SV = dyn_cast(defOp)) { + + // TODO update map with the new indexing from here + + // Create affine map + // i. Track number of running dims and symbols + // ii. shift dims and symbols to generate shifted expressions. + // Extract corresponding operands + // Use affineMap::get with numOperands and numSymbols along with shifted + // expressions to get a map. Use affine map simplify to simplify this + + SmallVector startExprs; + SmallVector strideExprs; + SmallVector dimOperands; + SmallVector symOperands; + for (auto en : llvm::enumerate(SV.getOffsets(), SV.getStrides())) { + auto &exprOutput = (en.index() == 0) ? startExprs : strideExprs; + for (auto expr : en.value()) { + auto val = en.value(); + // Only support constants, symbols, or affine apply as offsets + if (auto cop = val.getDefiningOp()) { + exprOutput.push_back(builder.getAffineConstantExpr(cop.getValue())); + continue; + } + + if (auto ba = dyn_cast(val)) + if (isa(ba->getParentOp())) { + exprOutput.push_back( + builder.getAffineDimExpr(dimOperands.size())); + dimOperands.push_back(ba); + continue; } - SmallVector inputExprs; - for (auto expr : lgMap. 
- .shiftDims(dimOperands.size()) - .shiftSymbols(symOperands.size()); - getResults()) { - inputExprs.push_back(newexpr); - } - for (size_t i=0; iisAncestor(defOp)) { + exprOutput.push_back( + builder.getAffineSymbolExpr(symOperands.size())); + symOperands.push_back(ba); + continue; + } + if (auto apply = dyn_cast(val)) { + auto map = apply.getAffineMap(); + auto newexpr = map..shiftDims(dimOperands.size()) + .shiftSymbols(symOperands.size()); - SmallVector mergedExprs; - for (auto [start, stride, idx]&& : llvm::zip(startExprs, strideExprs, inputExprs)) { - mergedExprs.push_back(startExprs + idx * strideExpr); + for (auto expr : newexpr.getResults()) { + exprOutput.push_back(newexpr); } - lgMap = AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs); - lgOperands.clear(); - lgOperands.append(dimOperands()); - lgOperands.append(symOperands()); - input = SV.getInput(); - } + for (size_t i = 0; i < map.getNumDims(); i++) + dimOperands.push_back(apply.getOperands()[i]); - return failure(); + for (size_t i = 0; i < map.getNumSymbols(); i++) + symOperands.push_back(apply.getOperands()[i + map.getNumDims()]); + + continue; + } + return failure(); + } + } + + SmallVector inputExprs; + for (auto expr : lgMap.shiftDims(dimOperands.size()) + .shiftSymbols(symOperands.size()); + getResults()) { + inputExprs.push_back(newexpr); + } + for (size_t i = 0; i < lgMap.getNumDims(); i++) + dimOperands.push_back(lgOperands[i]); + + for (size_t i = 0; i < lgMap.getNumSymbols(); i++) + symOperands.push_back(lgOperands[i + lgMap.getNumDims()]); + + SmallVector mergedExprs; + for (auto [start, stride, idx] && : + llvm::zip(startExprs, strideExprs, inputExprs)) { + mergedExprs.push_back(startExprs + idx * strideExpr); + } + + lgMap = + AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs); + lgOperands.clear(); + lgOperands.append(dimOperands()); + lgOperands.append(symOperands()); + input = SV.getInput(); } - return success(); + + return failure(); + } + return success(); } struct AffineForOpRaising : public OpRewritePattern { @@ -331,7 +370,7 @@ struct AffineForOpRaising : public OpRewritePattern { // Don't handle accumulations in registers for the moment, we can have // a separate pattern move them into memref's if (loop.getNumResults() != 0) { - return failure(); + return failure(); } SmallVector, AffineLoadOp>> loads; @@ -343,109 +382,120 @@ struct AffineForOpRaising : public OpRewritePattern { // affine.load, affine.store, affine.if, affine.yield // Additionally, for each load/store, remember what conditions are // required for that load or store to execute. - auto result = loop->walk([&](Operation* op) { - if (op == loop) return WalkResult::advance(); - // TODO extend this, any non-memory operation is also legal here. - // mul, add, etc (we can just check propety) - if (isa(op)) { - return WalkResult::advance(); + auto result = loop->walk([&](Operation *op) { + if (op == loop) + return WalkResult::advance(); + // TODO extend this, any non-memory operation is also legal here. 
+ // mul, add, etc (we can just check propety) + if (isa(op)) { + return WalkResult::advance(); + } + if (isa(op) || isa(op)) { + Operation *cur = op->getParentOp(); + std::vector conditions; + while (cur != loop) { + auto ifstmt = dyn_cast(cur); + if (!ifstmt) { + return WalkResult::interrupt(); + } + bool ifTrue = + ifstmt.getThenRegion().isAncestor(cur->getParentRegion()); + conditions.emplace_back(ifTrue, ifstmt); + cur = ifstmt->getParentOp(); } - if (isa(op) || isa(op)) { - Operation *cur = op->getParentOp(); - std::vector conditions; - while (cur != loop) { - auto ifstmt = dyn_cast(cur); - if (!ifstmt) { - return WalkResult::interrupt(); - } - bool ifTrue = ifstmt.getThenRegion().isAncestor(cur->getParentRegion()); - conditions.emplace_back(ifTrue, ifstmt); - cur = ifstmt->getParentOp(); - } - if (auto linalgGeneric = dyn_cast(op)) { - linalgGenerics.emplace_back(conditions, linalgGeneric); - } - else if (auto load = dyn_cast(op)) { - loads.emplace_back(conditions, load); - } else { - auto store = cast(op); - stores.emplace_back(conditions, store); - } - return WalkResult::advance(); - } - //IsReadNone takes care of apply and subview too? - if (isReadNone(op)) { - return WalkResult::advance(); + if (auto linalgGeneric = dyn_cast(op)) { + linalgGenerics.emplace_back(conditions, linalgGeneric); + } else if (auto load = dyn_cast(op)) { + loads.emplace_back(conditions, load); + } else { + auto store = cast(op); + stores.emplace_back(conditions, store); } - return WalkResult::interrupt(); + return WalkResult::advance(); + } + // IsReadNone takes care of apply and subview too? + if (isReadNone(op)) { + return WalkResult::advance(); + } + return WalkResult::interrupt(); }); - - if (result.wasInterrupted()) return failure(); + + if (result.wasInterrupted()) + return failure(); DominanceInfo DI(loop); - // Check that all of the stores do not alias the loaded values (otherwise we could get an incorrect result) - // TODO we can extend this and handle things like reductions, but we're going to start easy for now - // TODO + // Check that all of the stores do not alias the loaded values (otherwise we + // could get an incorrect result) + // TODO we can extend this and handle things like reductions, but we're + // going to start easy for now + // TODO DenseMap stores_map; for (auto &&[_, store] : stores) { - for (auto &&[_, load]: loads) { - if (mayAlias(load.getMemref(), store.getMemref())) { - // We have one exception in this case -- if the load and store are from the exact same location, it is permitted. - if (load.getMemref() == store.getMemref() && - load.getAffineMap() == store.getAffineMap() && - load.getIndices() == store.getIndices() && DI.dominates((Operation*)load,(Operation*)store)) { - //Example case where load does not dominate stores - if the load was conditional. - //Or, store followed by load? - //Q. Can't we still overlook the aliasing? - stores_map[load] = store; - continue; - } - return failure(); - } + for (auto &&[_, load] : loads) { + if (mayAlias(load.getMemref(), store.getMemref())) { + // We have one exception in this case -- if the load and store are + // from the exact same location, it is permitted. + if (load.getMemref() == store.getMemref() && + load.getAffineMap() == store.getAffineMap() && + load.getIndices() == store.getIndices() && + DI.dominates((Operation *)load, (Operation *)store)) { + // Example case where load does not dominate stores - if the load + // was conditional. Or, store followed by load? Q. Can't we still + // overlook the aliasing? 
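+            // Stated positively: a load and a store to the exact same
+            // address, with the load dominating the store, form a
+            // read-modify-write pair; recording it in stores_map lets the
+            // load be rewired to the generic's output block argument later
+            // instead of rejecting the loop for aliasing.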
+ stores_map[load] = store; + continue; + } + return failure(); } - for (auto &&[_, store2]: stores) { - if (store == store2) continue; - if (mayAlias(store.getMemref(), store2.getMemref())) { - return failure(); - } + } + for (auto &&[_, store2] : stores) { + if (store == store2) + continue; + if (mayAlias(store.getMemref(), store2.getMemref())) { + return failure(); } + } } // Check that any other loads / stores do not alias with any linalg generics - // We're going to need to upgrade the defn of mayAlias for subviews (aka mayAlias(subview, x) -> mayAlias(operand(subview), x)) + // We're going to need to upgrade the defn of mayAlias for subviews (aka + // mayAlias(subview, x) -> mayAlias(operand(subview), x)) SmallVector inputs; SmallVector affineMaps; SmallVector indexingMaps; - //if (loop.getStep() != 1) { - // return failure(); - //} + // if (loop.getStep() != 1) { + // return failure(); + // } - // our remapper currently assumes 0 start to bound. + // our remapper currently assumes 0 start to bound. if (!loop.hasConstantLowerBound() /*|| loop.getConstantLowerBound() != 0*/) { - return failure(); + return failure(); } // compute this correctly later. auto ubMap = loop.getUpperBoundMap(); auto ubOperands = loop.getUpperBoundOperands(); - if (!ubMap || ubMap.getNumResults() != 1) return failure(); + if (!ubMap || ubMap.getNumResults() != 1) + return failure(); // Retrieve the lower bound auto lbMap = loop.getLowerBoundMap(); auto lbOperands = loop.getLowerBoundOperands(); - if (!lbMap || lbMap.getNumResults() != 1) return failure(); - + if (!lbMap || lbMap.getNumResults() != 1) + return failure(); + auto ub = loop.getSingleUpperBound(); - if (!ub) return failure(); + if (!ub) + return failure(); auto lb = loop.getSingleLowerBound(); - if (!lb) return failure(); - + if (!lb) + return failure(); if (!loop.hasConstantUpperBound()) { - return failure(); + return failure(); } // Retrieve the step size @@ -453,192 +503,219 @@ struct AffineForOpRaising : public OpRewritePattern { // Get the single result expressions AffineExpr ubExpr = ubMap.getResult(0); - auto ubValue = rewriter.create(loop.getLoc(), ubMap, ubOperands); - + auto ubValue = + rewriter.create(loop.getLoc(), ubMap, ubOperands); + AffineExpr lbExpr = lbMap.getResult(0); - auto lbValue = rewriter.create(loop.getLoc(), lbMap, lbOperands); + auto lbValue = + rewriter.create(loop.getLoc(), lbMap, lbOperands); //// Ensure the bounds are constant expressions auto ubConst = ubExpr.dyn_cast(); auto lbConst = lbExpr.dyn_cast(); - if (!ubConst || !lbConst) return failure(); + if (!ubConst || !lbConst) + return failure(); // Compute the loop size - //int64_t loopSize = ubConst.getValue() - lbConst.getValue(); + // int64_t loopSize = ubConst.getValue() - lbConst.getValue(); auto loopSize = rewriter.create(loop.getLoc(), ubValue, lbValue); - - //Value loopSize = rewriter.create(loop.getLoc(), loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), *ub, *lb); - + + // Value loopSize = rewriter.create(loop.getLoc(), + // loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), + // *ub, *lb); + for (auto &&[conds, lg] : linalgGenerics) { - - //This captures the indexing map attribute from the linalg.generic being processed - ArrayAttr indexingMapsAttr = lg.getIndexingMaps(); - - int idx = 0; - // Iterate over input arguments - for (Value input : lg.getInputs()) { - //Is this needed? 
- if (conds.size() != 0) return failure(); - - //TODO: Implement this - //lgMap comes from offset of memref.subview, - //lgOperands comes from operands of memref.subview - AffineMap lgMap = indexingMapsAttr[idx]; - SmallVector lgOperands; - for (auto i=0; i lgOperands; - for (auto i=0; i lgOperands; + for (auto i = 0; i < lgMap.getNumDims(); i++) + lgOperands.push_back(builder.getAffineDim(i)); + Value lgMemref = input; + auto result = getLinalgArgMap(loop, lgMemref, lgMap, lgOperands); + + if (!result.succeeded()) + return failure(); - if (stores_map.find(load) != stores_map.end()) { - // We have a store that represents this load. - continue; - } + bool legal = true; + + auto &&[newMemref, newAffineMap] = remap_in_affine_dim( + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, + lbConst.getValue(), step, lgOperands); + + if (!legal) + return failure(); + + // TODO: need to mergre previous indexing maps and new affine maps + affineMaps.push_back(newAffineMap); + inputs.push_back(newMemref); + idx++; + } + + // Iterate over output arguments + for (Value output : lg.getOutputs()) { + // Is this needed? + if (conds.size() != 0) + return failure(); + + AffineMap lgMap = indexingMapsAttr[idx]; + SmallVector lgOperands; + for (auto i = 0; i < lgMap.getNumDims(); i++) + lgOperands.push_back(builder.getAffineDim(i)); + Value lgMemref = output; + + auto result = getLinalgArgMap(loop, lgMemref, lgMap, lgOperands); + + if (!result.succeeded()) + return failure(); bool legal = true; - - auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, load.getAffineMap(), load.getMemref(), loop.getInductionVar(), - loopSize, lbConst.getValue(), step, load.getMapOperands()); - if (!legal) return failure(); + auto &&[newMemref, newAffineMap] = remap_in_affine_dim( + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, + lbConst.getValue(), step, lgOperands); + + if (!legal) + return failure(); + // TODO: need to merge previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); + } + } + + // current spec is going to be indexed off of the loop var in isolation + for (auto &&[conds, load] : loads) { + // Only support unconditional loads for the moment + if (conds.size() != 0) + return failure(); + + if (stores_map.find(load) != stores_map.end()) { + // We have a store that represents this load. 
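+        // Its value will instead flow through the block argument created for
+        // the matching store when the generic body is populated below.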
+ continue; + } + + bool legal = true; + + auto &&[newMemref, newAffineMap] = remap_in_affine_dim( + legal, rewriter, load.getAffineMap(), load.getMemref(), + loop.getInductionVar(), loopSize, lbConst.getValue(), step, + load.getMapOperands()); + + if (!legal) + return failure(); + + affineMaps.push_back(newAffineMap); + inputs.push_back(newMemref); } - // TODO Push all of the inputs to the linalg generics (modifying maps as needed) - + // TODO Push all of the inputs to the linalg generics (modifying maps as + // needed) + SmallVector outputs; - // Store we may need to reindex into a splat potentially later, but for now we'll be lazy + // Store we may need to reindex into a splat potentially later, but for now + // we'll be lazy for (auto &&[conds, store] : stores) { - // Only support unconditional loads for the moment - if (conds.size() != 0) return failure(); + // Only support unconditional loads for the moment + if (conds.size() != 0) + return failure(); - bool legal = true; - - auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, store.getAffineMap(), store.getMemref(), loop.getInductionVar(), - loopSize, lbConst.getValue(), step, store.getMapOperands()); + bool legal = true; - if (!legal) return failure(); + auto &&[newMemref, newAffineMap] = remap_in_affine_dim( + legal, rewriter, store.getAffineMap(), store.getMemref(), + loop.getInductionVar(), loopSize, lbConst.getValue(), step, + store.getMapOperands()); - affineMaps.push_back(newAffineMap); - outputs.push_back(newMemref); + if (!legal) + return failure(); + + affineMaps.push_back(newAffineMap); + outputs.push_back(newMemref); } // TODO Push all of the outputs to the linalg generics // TODO presently if linalg generic exists, assert there are no load/stores - if(!((linalgGenerics.size() > 0) && ((loads.size() == 0 ) && (stores.size() == 0)))) - return failure; + if (!((linalgGenerics.size() > 0) && + ((loads.size() == 0) && (stores.size() == 0)))) + return failure; // TODO assert only zero or one linalg generic exists - if(!(linalgGenerics.size() == 1 || linalgGenerics.size() == 0)) - return failure; - + if (!(linalgGenerics.size() == 1 || linalgGenerics.size() == 0)) + return failure; SmallVector iteratorTypes; - // TODO if linalg generic exists, make this iterator type prepend to the existing iterators - iteratorTypes.push_back((stores_map.size() == 0) ? utils::IteratorType::parallel : utils::IteratorType::reduction); - - + // TODO if linalg generic exists, make this iterator type prepend to the + // existing iterators + iteratorTypes.push_back((stores_map.size() == 0) + ? 
utils::IteratorType::parallel + : utils::IteratorType::reduction); StringAttr empty = StringAttr::get(loop.getContext()); auto genericOp = rewriter.create( - loop.getLoc(), TypeRange(), inputs, outputs, affineMaps, iteratorTypes, - empty, - empty); + loop.getLoc(), TypeRange(), inputs, outputs, affineMaps, iteratorTypes, + empty, empty); - // TODO if doing the linalg generic case, ignore a lot of the below and instead of injecting the old body of the affine.for, move the inner linalg.generic body - // and also add a new induction variable + // TODO if doing the linalg generic case, ignore a lot of the below and + // instead of injecting the old body of the affine.for, move the inner + // linalg.generic body and also add a new induction variable auto blk = &*loop.getRegion().begin(); rewriter.setInsertionPointToStart(blk); // This index will replace the use of the affine index - auto idx = rewriter.create(loop.getLoc(), rewriter.getIndexAttr(0)); + auto idx = rewriter.create(loop.getLoc(), + rewriter.getIndexAttr(0)); rewriter.replaceAllUsesWith(loop.getInductionVar(), idx); auto &body = genericOp.getRegion(); body.takeBody(loop.getRegion()); - blk->eraseArguments(0, blk->getNumArguments()); for (auto &&[conds, load] : loads) { - if (stores_map.find(load) != stores_map.end()) { - // We have a store that represents this load. - continue; - } - auto arg = blk->addArgument(load.getType(), load.getLoc()); - rewriter.replaceOp(load, arg); - + if (stores_map.find(load) != stores_map.end()) { + // We have a store that represents this load. + continue; + } + auto arg = blk->addArgument(load.getType(), load.getLoc()); + rewriter.replaceOp(load, arg); } for (auto &&[conds, store] : stores) { - auto arg = blk->addArgument(store.getValueToStore().getType(), store.getLoc()); + auto arg = + blk->addArgument(store.getValueToStore().getType(), store.getLoc()); - SmallVector inverted; - for (auto && [map_load, map_store] : stores_map) { - if (map_store == store) { - inverted.push_back(map_load); - } - } - for (size_t i=0; i inverted; + for (auto &&[map_load, map_store] : stores_map) { + if (map_store == store) { + inverted.push_back(map_load); } + } + for (size_t i = 0; i < inverted.size(); i++) { + stores_map.erase(inverted[i]); + auto tmp = inverted[i]; + inverted[i] = nullptr; + rewriter.replaceOp(tmp, arg); + } } SmallVector toreturn; for (auto &&[conds, store] : stores) { - toreturn.push_back(store.getValueToStore()); - rewriter.eraseOp(store); + toreturn.push_back(store.getValueToStore()); + rewriter.eraseOp(store); } rewriter.eraseOp(blk->getTerminator()); From 56e2c54fc350137869a7f712f0725288c813f4ca Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 25 Jun 2024 23:58:22 +0000 Subject: [PATCH 07/77] some compile time fixes --- lib/polygeist/Passes/RaiseToLinalg.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 983c55218ac0..39a42a00b733 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -281,18 +281,20 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, SmallVector strideExprs; SmallVector dimOperands; SmallVector symOperands; - for (auto en : llvm::enumerate(SV.getOffsets(), SV.getStrides())) { - auto &exprOutput = (en.index() == 0) ? 
startExprs : strideExprs; - for (auto expr : en.value()) { - auto val = en.value(); + for (auto &&[first, second] : llvm::zip(SV.getOffsets(), SV.getStrides())) { + for (auto &&[index, val] : llvm::enumerate(SmallVector({first, second}))) { + auto &exprOutput = (index == 0) ? startExprs : strideExprs; // Only support constants, symbols, or affine apply as offsets - if (auto cop = val.getDefiningOp()) { + if (auto cop = val.getDefiningOp()) { + exprOutput.push_back(builder.getAffineConstantExpr(cop.getValue())); + continue; + } else if (auto cop = val.getDefiningOp()) { exprOutput.push_back(builder.getAffineConstantExpr(cop.getValue())); continue; } if (auto ba = dyn_cast(val)) - if (isa(ba->getParentOp())) { + if (isa(ba.getParentOp())) { exprOutput.push_back( builder.getAffineDimExpr(dimOperands.size())); dimOperands.push_back(ba); @@ -334,7 +336,7 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, for (auto expr : lgMap.shiftDims(dimOperands.size()) .shiftSymbols(symOperands.size()); getResults()) { - inputExprs.push_back(newexpr); + inputExprs.push_back(expr); } for (size_t i = 0; i < lgMap.getNumDims(); i++) dimOperands.push_back(lgOperands[i]); From e2530400f5cc96afa3248578eab8ba61099dcbc7 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 2 Jul 2024 22:56:37 +0000 Subject: [PATCH 08/77] Some compile fixes --- lib/polygeist/Passes/RaiseToLinalg.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 39a42a00b733..1c14da2fe68b 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -292,7 +292,6 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, exprOutput.push_back(builder.getAffineConstantExpr(cop.getValue())); continue; } - if (auto ba = dyn_cast(val)) if (isa(ba.getParentOp())) { exprOutput.push_back( @@ -312,11 +311,11 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, if (auto apply = dyn_cast(val)) { auto map = apply.getAffineMap(); - auto newexpr = map..shiftDims(dimOperands.size()) + auto newexpr = map.shiftDims(dimOperands.size()) .shiftSymbols(symOperands.size()); for (auto expr : newexpr.getResults()) { - exprOutput.push_back(newexpr); + exprOutput.push_back(expr); } for (size_t i = 0; i < map.getNumDims(); i++) @@ -345,9 +344,9 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, symOperands.push_back(lgOperands[i + lgMap.getNumDims()]); SmallVector mergedExprs; - for (auto [start, stride, idx] && : + for (auto && [start, stride, idx] : llvm::zip(startExprs, strideExprs, inputExprs)) { - mergedExprs.push_back(startExprs + idx * strideExpr); + mergedExprs.push_back(start + idx * stride); } lgMap = @@ -711,7 +710,6 @@ struct AffineForOpRaising : public OpRewritePattern { inverted[i] = nullptr; rewriter.replaceOp(tmp, arg); } - } SmallVector toreturn; From e99b8a58a27c6fd965d791d3bac14538083512dc Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 3 Jul 2024 00:11:41 +0000 Subject: [PATCH 09/77] Fixed all the compilation issues. 
Sample MLIR not raised --- lib/polygeist/Passes/RaiseToLinalg.cpp | 47 ++++++++++++++------------ 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 1c14da2fe68b..bb75bffd9af0 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -120,7 +120,7 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { std::pair remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, Value val, Value idx, Value idx_size, int loopLowerBound, - int loopStepSize, mlir::OperandRange vals) { + int loopStepSize, ValueRange vals) { // First we need to remove any dependence on the loop index from the affine // map SmallVector vals_without_idx; @@ -286,30 +286,33 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, auto &exprOutput = (index == 0) ? startExprs : strideExprs; // Only support constants, symbols, or affine apply as offsets if (auto cop = val.getDefiningOp()) { - exprOutput.push_back(builder.getAffineConstantExpr(cop.getValue())); + exprOutput.push_back(builder.getAffineConstantExpr(cop.value())); continue; } else if (auto cop = val.getDefiningOp()) { - exprOutput.push_back(builder.getAffineConstantExpr(cop.getValue())); + exprOutput.push_back(builder.getAffineConstantExpr(cop.value())); continue; } - if (auto ba = dyn_cast(val)) - if (isa(ba.getParentOp())) { + if (auto ba = dyn_cast(val)) { + Block *parentBlock = ba.getOwner(); + if (isa(parentBlock->getParentOp())) { exprOutput.push_back( builder.getAffineDimExpr(dimOperands.size())); dimOperands.push_back(ba); continue; + } + } auto valOp = val.getDefiningOp(); // Defined outside loop, consider it a symbol [for now] if (!valOp || loop->isAncestor(defOp)) { exprOutput.push_back( builder.getAffineSymbolExpr(symOperands.size())); - symOperands.push_back(ba); + symOperands.push_back(val); continue; } - if (auto apply = dyn_cast(val)) { + if (auto apply = dyn_cast(valOp)) { auto map = apply.getAffineMap(); auto newexpr = map.shiftDims(dimOperands.size()) .shiftSymbols(symOperands.size()); @@ -333,8 +336,7 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, SmallVector inputExprs; for (auto expr : lgMap.shiftDims(dimOperands.size()) - .shiftSymbols(symOperands.size()); - getResults()) { + .shiftSymbols(symOperands.size()).getResults()) { inputExprs.push_back(expr); } for (size_t i = 0; i < lgMap.getNumDims(); i++) @@ -350,11 +352,11 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, } lgMap = - AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs); + AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs, loop->getContext()); lgOperands.clear(); - lgOperands.append(dimOperands()); - lgOperands.append(symOperands()); - input = SV.getInput(); + lgOperands.insert(lgOperands.begin(), dimOperands.begin(), dimOperands.end()); + lgOperands.insert(lgOperands.begin(), symOperands.begin(), symOperands.end()); + input = SV.getSource(); } return failure(); @@ -541,10 +543,11 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO: Implement this // lgMap comes from offset of memref.subview, // lgOperands comes from operands of memref.subview - AffineMap lgMap = indexingMapsAttr[idx]; + AffineMap lgMap = cast(indexingMapsAttr[idx]).getAffineMap(); SmallVector lgOperands; - for (auto i = 0; i < lgMap.getNumDims(); i++) - lgOperands.push_back(builder.getAffineDim(i)); + 
lgOperands.push_back(input); + // for (auto i = 0; i < lgMap.getNumDims(); i++) + // lgOperands.push_back(lgMap.getOperands()[i]); Value lgMemref = input; auto result = getLinalgArgMap(loop, lgMemref, lgMap, lgOperands); @@ -572,10 +575,11 @@ struct AffineForOpRaising : public OpRewritePattern { if (conds.size() != 0) return failure(); - AffineMap lgMap = indexingMapsAttr[idx]; + AffineMap lgMap = cast(indexingMapsAttr[idx]).getAffineMap(); SmallVector lgOperands; - for (auto i = 0; i < lgMap.getNumDims(); i++) - lgOperands.push_back(builder.getAffineDim(i)); + lgOperands.push_back(output); + // for (auto i = 0; i < lgMap.getNumDims(); i++) + // lgOperands.push_back(lgMap.getSubMap(i)); Value lgMemref = output; auto result = getLinalgArgMap(loop, lgMemref, lgMap, lgOperands); @@ -651,11 +655,11 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO presently if linalg generic exists, assert there are no load/stores if (!((linalgGenerics.size() > 0) && ((loads.size() == 0) && (stores.size() == 0)))) - return failure; + return failure(); // TODO assert only zero or one linalg generic exists if (!(linalgGenerics.size() == 1 || linalgGenerics.size() == 0)) - return failure; + return failure(); SmallVector iteratorTypes; // TODO if linalg generic exists, make this iterator type prepend to the @@ -710,6 +714,7 @@ struct AffineForOpRaising : public OpRewritePattern { inverted[i] = nullptr; rewriter.replaceOp(tmp, arg); } + } SmallVector toreturn; From 34f595c63c2078a421708dcc7df29e688444f830 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 16 Jul 2024 00:18:30 +0000 Subject: [PATCH 10/77] Bug fixes, generating some output at getLinalgArgMap --- lib/polygeist/Passes/RaiseToLinalg.cpp | 97 ++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 7 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index bb75bffd9af0..4ab31eea86ed 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -238,17 +238,74 @@ f(%memref ) affine.for { - %inp = .. subview %memref [ ... ] linalg.generic %inp #map { + body() + } +} + + +-> + +affine.for j { + + linalg.generic %memref #map2(j) { + body() } } + + #map2 = #map with the indexing done to %inp + + + + +%memref = .. subview %memref_base [ ... ] + +linalg.generic %[[[memref]]] [[[[#map]]]]([[[[operands]]]]) { + body() +} + +-> + + +output_memref = memref_base +output_map = subvmap() + + compose +# uts are memref, map, and operands +# outputs are o +memref[map(operands)] ==== output_memref[output_map(output_operands)] + + + +bas= memref<40x40> + +B + +u + +tput_memref, output_map and output_operands +# possible intermediate is ... + +getLinalgArgMap(memref, map, operands to map [e.g. 
input symbols/dims]) + if memref is alloca/unknown/etc + return memref/map/operands + else + memref = subview memref_base[map2(operands2)] + + return memref_base and a new output_map such that + memref_base[output_map(output_operands)] === memref[map(operands)] + + + + + */ // Suppose we have a memref expression E=input[affine.map(operands)] @@ -305,13 +362,16 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, auto valOp = val.getDefiningOp(); // Defined outside loop, consider it a symbol [for now] - if (!valOp || loop->isAncestor(defOp)) { + //if (!valOp || loop->isAncestor(defOp)) { + if (valOp&&!loop->isAncestor(defOp)) { exprOutput.push_back( builder.getAffineSymbolExpr(symOperands.size())); symOperands.push_back(val); continue; } + //TODO: Maybe it's a case to add, but are we sure we need it for starts and offsets + // and not for operands if (auto apply = dyn_cast(valOp)) { auto map = apply.getAffineMap(); auto newexpr = map.shiftDims(dimOperands.size()) @@ -330,7 +390,7 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, continue; } - return failure(); + //return failure(); } } @@ -345,6 +405,7 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, for (size_t i = 0; i < lgMap.getNumSymbols(); i++) symOperands.push_back(lgOperands[i + lgMap.getNumDims()]); + SmallVector mergedExprs; for (auto && [start, stride, idx] : llvm::zip(startExprs, strideExprs, inputExprs)) { @@ -355,11 +416,12 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs, loop->getContext()); lgOperands.clear(); lgOperands.insert(lgOperands.begin(), dimOperands.begin(), dimOperands.end()); - lgOperands.insert(lgOperands.begin(), symOperands.begin(), symOperands.end()); + lgOperands.insert(lgOperands.begin()+lgOperands.size(), symOperands.begin(), symOperands.end()); input = SV.getSource(); + break; } - return failure(); + //return failure(); } return success(); } @@ -369,6 +431,8 @@ struct AffineForOpRaising : public OpRewritePattern { LogicalResult matchAndRewrite(affine::AffineForOp loop, PatternRewriter &rewriter) const final { + + auto module = loop->getParentOfType(); // Don't handle accumulations in registers for the moment, we can have // a separate pattern move them into memref's @@ -549,6 +613,12 @@ struct AffineForOpRaising : public OpRewritePattern { // for (auto i = 0; i < lgMap.getNumDims(); i++) // lgOperands.push_back(lgMap.getOperands()[i]); Value lgMemref = input; + + // At input, this contains, current input (i.e. probably a subview) + // an lgMap which is obtained from LG's indexing map for corresponding input + // lgOperands contains current input (i.e probably a subview) + + // Gives output ... auto result = getLinalgArgMap(loop, lgMemref, lgMap, lgOperands); if (!result.succeeded()) @@ -556,6 +626,19 @@ struct AffineForOpRaising : public OpRewritePattern { bool legal = true; + // Takes input's/output's, affineMap of load/store (here lgMap ?), + // induction variable corresponding to the loop + // Memref corresponding the the memory accessed (in this case subview ?) + // loopSize, lower and upper bounds + // Get operands for load/store (here ?) 
to find dependent dim + + // Gives output newMemref which is a subviewOp, + // newAffineMap which is the LG's indexing map corresponding this inp/output + + // This takes load and store maps and then creates affine.apply+subview+linalg.generic + // For this case: LG within ForOp - + // Inputs should be : load map extracted from subviewOp + // Returns LG with indexingMap and subview with affine.apply - which are correct auto &&[newMemref, newAffineMap] = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, lbConst.getValue(), step, lgOperands); @@ -653,8 +736,8 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO Push all of the outputs to the linalg generics // TODO presently if linalg generic exists, assert there are no load/stores - if (!((linalgGenerics.size() > 0) && - ((loads.size() == 0) && (stores.size() == 0)))) + if ((linalgGenerics.size() > 0) && + ((loads.size() == 0) && (stores.size() == 0))) return failure(); // TODO assert only zero or one linalg generic exists From 05bad9756d111657b72e919734426e3067edd64c Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 17 Jul 2024 00:29:06 +0000 Subject: [PATCH 11/77] Almost implementated remap in affine dim for multi idx --- lib/polygeist/Passes/RaiseToLinalg.cpp | 133 +++++++++++++++++-------- 1 file changed, 90 insertions(+), 43 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 4ab31eea86ed..545bc2131e3d 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -108,22 +108,24 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { } // Given an affine map `oldmap`, memref `val`, and corresponding input values -// (which are a list of indicies, then symbols), and a loop index `ind` produce +// (which are a list of indicies, then symbols), and a set of loop indices `indices` produce // the following: // 1. A (potentially new) memref value `newval` which does not have any -// dependence on `ind` +// dependence on `indncides` // and -// 2. an affine map `newmap` which takes a single index (`ind`) and produces +// 2. an affine map `newmap` which takes size(indices) values (`indices`) and produces // indices into `newval` such that -// indexing `newval[map(ind)]` produces the same result as indexing the +// indexing `newval[map(indices)]` produces the same result as indexing the // original map. 
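+// Illustrative example of the generalized contract (not from the tests): for
+// an access %A[%i + 32 * %j] where both %i and %j are raised loop indices,
+// `newval[newmap(%i, %j)]` should address the same underlying element as
+// %A[%i + 32 * %j], while `newval` itself no longer refers to %i or %j.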
std::pair remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, - Value val, Value idx, Value idx_size, int loopLowerBound, + Value val, SmallVectorImpl& indices, SmallVector idx_sizes, int loopLowerBound, int loopStepSize, ValueRange vals) { // First we need to remove any dependence on the loop index from the affine // map - SmallVector vals_without_idx; + SmallVector dims; + + for (auto idx : indices) { // This tracks the index corresponding to the for loop if present in // load/store operands else it's -1 ssize_t dim_idx = -1; @@ -139,40 +141,59 @@ remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, dim_idx = i; continue; } - vals_without_idx.push_back(v); } if (dim_idx != -1 && !isLinearInIndex(oldmap, dim_idx)) { legal = false; return {val, oldmap}; } + dims.push_back(dim_idx); + } - // Evaluate offsets as oldmap replacing idx with 0, and evaluating at the - // remaining variables + SmallVector vals_without_indices; + for (auto v : vals) { + if (!llvm::is_contained(indices, v)) + vals_without_indices.push_back(v); + } - // Instead of lower bound we are using 0 (assumption as the lower bound) + // Evaluate offsets as oldmap replacing all indices with 0, and evaluating at the + // remaining variables AffineMap offsetMap = oldmap; - if (dim_idx != -1) { - offsetMap = - oldmap.replace(builder.getAffineDimExpr(dim_idx), - builder.getAffineConstantExpr(loopLowerBound), - offsetMap.getNumDims(), offsetMap.getNumSymbols()); - offsetMap = shiftDimsDown1(offsetMap, oldmap.getNumDims(), dim_idx); + for (auto dim_idx : dims) { + if (dim_idx != -1) { + offsetMap = + oldmap.replace(builder.getAffineDimExpr(dim_idx), + builder.getAffineConstantExpr(loopLowerBound), + offsetMap.getNumDims(), offsetMap.getNumSymbols()); + offsetMap = shiftDimsDown1(offsetMap, oldmap.getNumDims(), dim_idx); + } } - // Instead of using loop step we are using 1 (Assumption as the stride size) - AffineMap strideMap = oldmap; - if (dim_idx != -1) { - strideMap = oldmap.replace( - builder.getAffineDimExpr(dim_idx), - builder.getAffineConstantExpr(loopLowerBound + loopStepSize), - strideMap.getNumDims(), strideMap.getNumSymbols()); - strideMap = shiftDimsDown1(strideMap, oldmap.getNumDims(), dim_idx); - } + SmallVector strideMaps; + + // For each dimension `outer_dim_idx` we want to keep, + // create a new affine map equal to the map(dim=1, other dims=0) + for (auto outer_dim_idx : dims) { + AffineMap strideMap = oldmap; + if (outer_dim_idx != -1) { + strideMap = oldmap.replace( + builder.getAffineDimExpr(outer_dim_idx), + builder.getAffineConstantExpr(loopLowerBound + loopStepSize), + strideMap.getNumDims(), strideMap.getNumSymbols()); + strideMap = shiftDimsDown1(strideMap, oldmap.getNumDims(), outer_dim_idx); + } + for (auto dim_idx : dims) { + if (dim_idx == outer_dim_idx || dim_idx == -1) continue; + + offsetMap = + oldmap.replace(builder.getAffineDimExpr(dim_idx), + builder.getAffineConstantExpr(loopLowerBound), + offsetMap.getNumDims(), offsetMap.getNumSymbols()); + offsetMap = shiftDimsDown1(offsetMap, oldmap.getNumDims(), dim_idx); + } - // Subtracting maps of stride and offset, gives you the offset value in the - // result of the map - { + // Subtracting maps of stride and offset, gives you the offset value in the + // result of the map SmallVector subtracts; for (auto &&[lhs, rhs] : llvm::zip(strideMap.getResults(), offsetMap.getResults())) { @@ -181,40 +202,61 @@ remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, strideMap = 
AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), subtracts, builder.getContext()); + strideMaps.push_back(strideMap); } - // Expression to index into the generated subview given the loop index - SmallVector loop_idxs; // List of starting offsets into the subview SmallVector offsets; - SmallVector sizes; - SmallVector strides; - - for (auto &&[expr, offset_expr, stride_expr] : - llvm::zip(oldmap.getResults(), offsetMap.getResults(), - strideMap.getResults())) { + for (auto &&[expr, offset_expr] : llvm::zip(oldmap.getResults(), offsetMap.getResults())) { offsets.push_back(builder.create( val.getLoc(), AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), offset_expr, builder.getContext()), - vals_without_idx)); // What is there are symbols in the expression? + vals_without_indices)); // What is there are symbols in the expression? + } + + SmallVector sizes; + SmallVector strides; + + // Expression to index into the generated subview given the loop index + SmallVector loop_idxs; + SmallVector sizes; + for (auto &&[dim_idx, idx_size] : llvm::zip(dims, idx_sizes)) { + if (!oldmap.isFunctionOfDim(dim_idx)) { + loop_idxs.push_back(builder.getAffineConstantExpr(0)); + sizes.push_back(builder.create(val.getLoc(), 1)); + } else { + loop_idxs.push_back(builder.getAffineConstantExpr(0)); + } + } + + for (auto &&[i, expr] : + llvm::enumerate(oldmap.getResults())) { + + AffineExpr stride_expr = nullptr; + for (auto strideMap : strideMaps) { + auto subexpr = strideMap.getResult(i); + if (stride_expr == nullptr) stride_expr = subexpr; + else stride_expr = stride_expr + subexpr; + } + strides.push_back(builder.create( val.getLoc(), - AffineMap::get(strideMap.getNumDims(), strideMap.getNumSymbols(), + AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), stride_expr, builder.getContext()), - vals_without_idx)); // What is there are symbols in the expression? + vals_without_indices)); // What is there are symbols in the expression? + + // These need to be properly computed + // This is the remainign hard part to factor if (!expr.isFunctionOfDim(dim_idx)) { - loop_idxs.push_back(builder.getAffineConstantExpr(0)); sizes.push_back(builder.create(val.getLoc(), 1)); } else { - loop_idxs.push_back(builder.getAffineDimExpr(0)); sizes.push_back(idx_size); } } - auto newval = builder.create(val.getLoc(), val, offsets, - sizes, strides); + auto newval = builder.create(val.getLoc(), val, remap, vals_without_indices, sizes); legal = true; // Does this need fix? 
Here we are constraining to dims as 1 and symbols as 0, // should it be, original @@ -374,6 +416,11 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, // and not for operands if (auto apply = dyn_cast(valOp)) { auto map = apply.getAffineMap(); + auto *scope = affine::getAffineScope(valOp)->getParentOp(); + DominanceInfo DI(scope); + auto map_operands = apply.getOperands(); + //fully2ComposeAffineMapAndOperands(builder, &map, &map_operands, DI); + // Instead of using loop step we are using 1 (Assumption as the stride size) auto newexpr = map.shiftDims(dimOperands.size()) .shiftSymbols(symOperands.size()); From 5bbf5ef2f4e5f0ce1b77f4424a0098ea0ae4a523 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 24 Jul 2024 00:56:17 +0000 Subject: [PATCH 12/77] Added submap op support and refactored the code to use submap --- lib/polygeist/Ops.cpp | 47 +++++ lib/polygeist/Passes/RaiseToLinalg.cpp | 234 +++++++++---------------- 2 files changed, 134 insertions(+), 147 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index d9a60fbcce45..7a4101f0a936 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -5880,3 +5880,50 @@ LogicalResult GetFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } + +/* +class LoadSubMap final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineLoadOp op, + PatternRewriter &rewriter) const override { + auto subMapOp = op.getMemRef().getDefiningOp(); + if (!subMapOp) return failure(); + + auto submap_map = ref.getAffineMap(); + auto submap_operands = ref.getAffineMapOperands(); + auto source_memref = ref.getMemref(); + + auto load_map = ref.getAffineMap(); + SmallVector operands0 = op.getMapOperands(); + + // %m = polygeist.submap submap_map(%submap_operands) %source_memref : memref -> memref + // %a = affine.load %m[load_map(%load_operands)] + // -> + // %a = affine.load %source_memref[load_map(submap_map(%load_operands, %submap_operands))] + + auto new_map = load_map.compose(submap_map); + auto new_operands = llvm::concat(load_operands, submap_operands) + + rewriter.replaceOpWithNewOp(op.getLoc(), sourceMemref, ); + + + + // shift one map over by the size of other # symbols/dims, replace with new affine load with composed map + return success(); + } +}; +*/ +// TODO StoreSubMap + +OpFoldResult mlir::polygeist::SubmapOp::fold(mlir::polygeist::SubmapOp::FoldAdaptor adaptor) { + // TODO if submap is identity return nothing + // if submap of submap return new submap + return nullptr; +} + +void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + //results.insert(context); +} diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 545bc2131e3d..59b443f71835 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -111,157 +111,59 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { // (which are a list of indicies, then symbols), and a set of loop indices `indices` produce // the following: // 1. A (potentially new) memref value `newval` which does not have any -// dependence on `indncides` +// dependence on `indices` // and // 2. an affine map `newmap` which takes size(indices) values (`indices`) and produces // indices into `newval` such that // indexing `newval[map(indices)]` produces the same result as indexing the // original map. 
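+// Note for this revision: instead of returning a (subview, map) pair, the
+// helper below now returns a single polygeist.submap value that carries the
+// rebased affine map; callers index it with an identity map of rank
+// firstNDims + 1 (see the uses in AffineForOpRaising further down).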
-std::pair -remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, - Value val, SmallVectorImpl& indices, SmallVector idx_sizes, int loopLowerBound, - int loopStepSize, ValueRange vals) { - // First we need to remove any dependence on the loop index from the affine - // map - SmallVector dims; - - for (auto idx : indices) { - // This tracks the index corresponding to the for loop if present in - // load/store operands else it's -1 - ssize_t dim_idx = -1; - // To check if induction variable of for loop in an operand of this op - // (load/store) - for (auto &&[i, v] : llvm::enumerate(vals)) { - if (v == idx) { - // Offset we're replacing must be an index (not a symbol). - // If we guarantee to run AffineCFG first, this should always be true. - assert(i < oldmap.getNumDims()); - // There should only be one use of the index. - assert(dim_idx == -1); - dim_idx = i; - continue; - } - } - if (dim_idx != -1 && !isLinearInIndex(oldmap, dim_idx)) { - legal = false; - return {val, oldmap}; - } - dims.push_back(dim_idx); - } +Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, + Value val, Value index, Value bound, int firstNDims, ValueRange vals) { SmallVector vals_without_indices; - for (auto v : vals) { - if (!llvm::is_contained(indices, v)) - vals_without_indices.push_back(v); - } - - // Evaluate offsets as oldmap replacing all indices with 0, and evaluating at the - // remaining variables - AffineMap offsetMap = oldmap; - for (auto dim_idx : dims) { - if (dim_idx != -1) { - offsetMap = - oldmap.replace(builder.getAffineDimExpr(dim_idx), - builder.getAffineConstantExpr(loopLowerBound), - offsetMap.getNumDims(), offsetMap.getNumSymbols()); - offsetMap = shiftDimsDown1(offsetMap, oldmap.getNumDims(), dim_idx); - } - } - - SmallVector strideMaps; - - // For each dimension `outer_dim_idx` we want to keep, - // create a new affine map equal to the map(dim=1, other dims=0) - for (auto outer_dim_idx : dims) { - AffineMap strideMap = oldmap; - if (outer_dim_idx != -1) { - strideMap = oldmap.replace( - builder.getAffineDimExpr(outer_dim_idx), - builder.getAffineConstantExpr(loopLowerBound + loopStepSize), - strideMap.getNumDims(), strideMap.getNumSymbols()); - strideMap = shiftDimsDown1(strideMap, oldmap.getNumDims(), outer_dim_idx); - } - for (auto dim_idx : dims) { - if (dim_idx == outer_dim_idx || dim_idx == -1) continue; - - offsetMap = - oldmap.replace(builder.getAffineDimExpr(dim_idx), - builder.getAffineConstantExpr(loopLowerBound), - offsetMap.getNumDims(), offsetMap.getNumSymbols()); - offsetMap = shiftDimsDown1(offsetMap, oldmap.getNumDims(), dim_idx); - } - - // Subtracting maps of stride and offset, gives you the offset value in the - // result of the map - SmallVector subtracts; - for (auto &&[lhs, rhs] : - llvm::zip(strideMap.getResults(), offsetMap.getResults())) { - subtracts.push_back(lhs - rhs); - } - strideMap = - AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), - subtracts, builder.getContext()); - strideMaps.push_back(strideMap); - } - - - // List of starting offsets into the subview - SmallVector offsets; - for (auto &&[expr, offset_expr] : llvm::zip(oldmap.getResults(), offsetMap.getResults())) { - offsets.push_back(builder.create( - val.getLoc(), - AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), - offset_expr, builder.getContext()), - vals_without_indices)); // What is there are symbols in the expression? 
+ ssize_t dimidx = -1; + for (auto [i, v] : llvm::enumerate(vals)) { + if (v != index) + vals_without_indices.push_back(v); + else + dimidx = i; } - SmallVector sizes; - SmallVector strides; - - // Expression to index into the generated subview given the loop index - SmallVector loop_idxs; - SmallVector sizes; - for (auto &&[dim_idx, idx_size] : llvm::zip(dims, idx_sizes)) { - if (!oldmap.isFunctionOfDim(dim_idx)) { - loop_idxs.push_back(builder.getAffineConstantExpr(0)); - sizes.push_back(builder.create(val.getLoc(), 1)); + SmallVector dimReplacements; + size_t validx = 0; + for (int i=0; i( - val.getLoc(), - AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), - stride_expr, builder.getContext()), - vals_without_indices)); // What is there are symbols in the expression? - - // These need to be properly computed - // This is the remainign hard part to factor - if (!expr.isFunctionOfDim(dim_idx)) { - sizes.push_back(builder.create(val.getLoc(), 1)); - } else { - sizes.push_back(idx_size); - } + SmallVector symReplacements; + for (int i=0; i idx_sizes; + for (size_t i=0; i(val.getLoc(), val, i)); + } + idx_sizes.push_back(bound); - auto newval = builder.create(val.getLoc(), val, remap, vals_without_indices, sizes); legal = true; - // Does this need fix? Here we are constraining to dims as 1 and symbols as 0, - // should it be, original - return {newval, AffineMap::get(/*dims*/ 1, /*symbols*/ 0, loop_idxs, - builder.getContext())}; + SmallVector sizes(idx_sizes.size(), -1); + for (auto sz : idx_sizes) + vals_without_indices.push_back(sz); + auto ty = MemRefType::get(sizes, cast(val.getType()).getElementType()); + return builder.create(val.getLoc(), ty, val, vals_without_indices, map2); } // store A[...] @@ -365,6 +267,31 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, if (!loop->isAncestor(defOp)) continue; + if (auto SM = dyn_cast(defOp)) { + auto submap = SM.getMap(); + + auto composeMap = submap.compose(lgMap); + + SmallVector operands0; + + // First the dims + for (size_t i = 0; i < lgMap.getNumDims(); i++) + operands0.push_back(lgOperands[i]); + + // Then the symbols of submap + for (size_t i = 0; i < submap.getNumSymbols(); i++) + operands0.push_back(SM.getSymbols()[i]); + + // Then the symbols of lgMap + for (size_t i = 0; i < lgMap.getNumSymbols(); i++) + operands0.push_back(lgOperands[i + lgMap.getNumDims()]); + + lgMap = composeMap; + lgOperands = operands0; + input = SM.getMemref(); + continue; + } + if (auto SV = dyn_cast(defOp)) { // TODO update map with the new indexing from here @@ -638,6 +565,7 @@ struct AffineForOpRaising : public OpRewritePattern { // loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), // *ub, *lb); + for (auto &&[conds, lg] : linalgGenerics) { // This captures the indexing map attribute from the linalg.generic being @@ -654,7 +582,9 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO: Implement this // lgMap comes from offset of memref.subview, // lgOperands comes from operands of memref.subview - AffineMap lgMap = cast(indexingMapsAttr[idx]).getAffineMap(); + + const AffineMap lgMap0 = cast(indexingMapsAttr[idx]).getAffineMap(); + AffineMap lgMap = lgMap0; SmallVector lgOperands; lgOperands.push_back(input); // for (auto i = 0; i < lgMap.getNumDims(); i++) @@ -672,7 +602,7 @@ struct AffineForOpRaising : public OpRewritePattern { return failure(); bool legal = true; - + // Takes input's/output's, affineMap of load/store (here lgMap ?), // induction variable corresponding to the loop // 
Memref corresponding the the memory accessed (in this case subview ?) @@ -686,13 +616,17 @@ struct AffineForOpRaising : public OpRewritePattern { // For this case: LG within ForOp - // Inputs should be : load map extracted from subviewOp // Returns LG with indexingMap and subview with affine.apply - which are correct - auto &&[newMemref, newAffineMap] = remap_in_affine_dim( + size_t firstNDims = lgMap.getResults().size(); + auto newMemref = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, - lbConst.getValue(), step, lgOperands); + firstNDims, lgOperands); + if (!legal) return failure(); + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); + // TODO: need to mergre previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); @@ -705,7 +639,8 @@ struct AffineForOpRaising : public OpRewritePattern { if (conds.size() != 0) return failure(); - AffineMap lgMap = cast(indexingMapsAttr[idx]).getAffineMap(); + const AffineMap lgMap0 = cast(indexingMapsAttr[idx]).getAffineMap(); + AffineMap lgMap = lgMap0; SmallVector lgOperands; lgOperands.push_back(output); // for (auto i = 0; i < lgMap.getNumDims(); i++) @@ -719,13 +654,14 @@ struct AffineForOpRaising : public OpRewritePattern { bool legal = true; - auto &&[newMemref, newAffineMap] = remap_in_affine_dim( - legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, - lbConst.getValue(), step, lgOperands); + size_t firstNDims = lgMap.getResults().size(); + auto newMemref = remap_in_affine_dim( + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, lgOperands); if (!legal) return failure(); + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); // TODO: need to merge previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); @@ -743,17 +679,18 @@ struct AffineForOpRaising : public OpRewritePattern { continue; } + size_t firstNDims = 0; bool legal = true; - auto &&[newMemref, newAffineMap] = remap_in_affine_dim( + auto newMemref = remap_in_affine_dim( legal, rewriter, load.getAffineMap(), load.getMemref(), - loop.getInductionVar(), loopSize, lbConst.getValue(), step, + loop.getInductionVar(), loopSize, firstNDims, load.getMapOperands()); if (!legal) return failure(); - affineMaps.push_back(newAffineMap); + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); inputs.push_back(newMemref); } // TODO Push all of the inputs to the linalg generics (modifying maps as @@ -769,14 +706,17 @@ struct AffineForOpRaising : public OpRewritePattern { bool legal = true; - auto &&[newMemref, newAffineMap] = remap_in_affine_dim( + size_t firstNDims = 0; + + auto newMemref = remap_in_affine_dim( legal, rewriter, store.getAffineMap(), store.getMemref(), - loop.getInductionVar(), loopSize, lbConst.getValue(), step, + loop.getInductionVar(), loopSize, firstNDims, store.getMapOperands()); if (!legal) return failure(); + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); affineMaps.push_back(newAffineMap); outputs.push_back(newMemref); } From 9018d9288c2aa6e8c5b2651c8fadea2d92fc555f Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 30 Jul 2024 18:16:42 +0000 Subject: [PATCH 13/77] bunch of fixes. 
Now able to generate raise linalg code --- include/polygeist/Passes/Passes.td | 1 + include/polygeist/PolygeistOps.td | 17 ++ lib/polygeist/Passes/RaiseToLinalg.cpp | 233 +++++++++++++------------ 3 files changed, 137 insertions(+), 114 deletions(-) diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 5c17a9d6dc25..fc5b36aa9caf 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -157,6 +157,7 @@ def AffineRaiseToLinalg : Pass<"raise-affine-to-linalg"> { let dependentDialects = [ "affine::AffineDialect", "linalg::LinalgDialect", + "polygeist::PolygeistDialect", ]; } diff --git a/include/polygeist/PolygeistOps.td b/include/polygeist/PolygeistOps.td index 159f6c144947..0d4b5c01727d 100644 --- a/include/polygeist/PolygeistOps.td +++ b/include/polygeist/PolygeistOps.td @@ -259,4 +259,21 @@ def TypeAlignOp : Polygeist_Op<"typeAlign", [Pure]> { let hasFolder = 1; let hasCanonicalizer = 1; } + +def SubmapOp : Polygeist_Op<"submap", [Pure, ViewLikeOpInterface]> { + let arguments = (ins Arg:$memref, + Variadic:$indices_and_sizes, + AffineMapAttr:$map + ); + let results = (outs AnyMemRef : $result); + let hasFolder = 1; + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + ::mlir::ValueRange getSymbols() { return getOperands().slice(0, getMap().getNumSymbols()); } + ::mlir::ValueRange getSizes() { return getOperands().slice(getMap().getNumSymbols(), getMap().getNumSymbols() + getType().getShape().size()); } + ::mlir::Value getViewSource() { return getMemref(); } + }]; +} + #endif // POLYGEIST_OPS diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 59b443f71835..0be52d285f29 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -119,13 +119,14 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { // original map. Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, - Value val, Value index, Value bound, int firstNDims, ValueRange vals) { - - SmallVector vals_without_indices; + Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands) { + + //Operands which don't correspond to indices + SmallVector operands_without_indices; ssize_t dimidx = -1; - for (auto [i, v] : llvm::enumerate(vals)) { + for (auto [i, v] : llvm::enumerate(oldmap_operands)) { if (v != index) - vals_without_indices.push_back(v); + operands_without_indices.push_back(v); else dimidx = i; } @@ -139,6 +140,7 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, } else if (i == dimidx) { dimReplacements.push_back(builder.getAffineDimExpr(dimReplacements.size())); } else { + // TODO: Why are we using symbol here instead of dim? 
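+      // (Hedged note) Presumably any operand other than the current induction variable is
+      // not iterated by the generic being built, so it stays a symbol of the remapped map;
+      // only the induction variable is promoted to an extra dimension.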
dimReplacements.push_back(builder.getAffineSymbolExpr(validx)); validx++; } @@ -149,21 +151,23 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, symReplacements.push_back(builder.getAffineSymbolExpr(validx)); validx++; } - assert(validx == vals_without_indices.size()); - auto map2 = oldmap.replaceDimsAndSymbols(dimReplacements, symReplacements, firstNDims+1, vals_without_indices.size()); + assert(validx == operands_without_indices.size()); + auto map2 = oldmap.replaceDimsAndSymbols(dimReplacements, symReplacements, firstNDims+1, operands_without_indices.size()); SmallVector idx_sizes; for (size_t i=0; i(val.getLoc(), val, i)); + idx_sizes.push_back(builder.create(memref_val.getLoc(), memref_val, i)); } idx_sizes.push_back(bound); legal = true; - SmallVector sizes(idx_sizes.size(), -1); + // TODO: Cannot be negative size, are we trying to initialize it with any size, or do we want to calcualte size from + // loop bounds? + SmallVector sizes(idx_sizes.size(), 1); for (auto sz : idx_sizes) - vals_without_indices.push_back(sz); - auto ty = MemRefType::get(sizes, cast(val.getType()).getElementType()); - return builder.create(val.getLoc(), ty, val, vals_without_indices, map2); + operands_without_indices.push_back(sz); + auto ty = MemRefType::get(sizes, cast(memref_val.getType()).getElementType()); + return builder.create(memref_val.getLoc(), ty, memref_val, operands_without_indices, map2); } // store A[...] @@ -292,108 +296,108 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, continue; } - if (auto SV = dyn_cast(defOp)) { - - // TODO update map with the new indexing from here - - // Create affine map - // i. Track number of running dims and symbols - // ii. shift dims and symbols to generate shifted expressions. - // Extract corresponding operands - // Use affineMap::get with numOperands and numSymbols along with shifted - // expressions to get a map. Use affine map simplify to simplify this - - SmallVector startExprs; - SmallVector strideExprs; - SmallVector dimOperands; - SmallVector symOperands; - for (auto &&[first, second] : llvm::zip(SV.getOffsets(), SV.getStrides())) { - for (auto &&[index, val] : llvm::enumerate(SmallVector({first, second}))) { - auto &exprOutput = (index == 0) ? 
startExprs : strideExprs; - // Only support constants, symbols, or affine apply as offsets - if (auto cop = val.getDefiningOp()) { - exprOutput.push_back(builder.getAffineConstantExpr(cop.value())); - continue; - } else if (auto cop = val.getDefiningOp()) { - exprOutput.push_back(builder.getAffineConstantExpr(cop.value())); - continue; - } - if (auto ba = dyn_cast(val)) { - Block *parentBlock = ba.getOwner(); - if (isa(parentBlock->getParentOp())) { - exprOutput.push_back( - builder.getAffineDimExpr(dimOperands.size())); - dimOperands.push_back(ba); - continue; - - } - } - - auto valOp = val.getDefiningOp(); - // Defined outside loop, consider it a symbol [for now] - //if (!valOp || loop->isAncestor(defOp)) { - if (valOp&&!loop->isAncestor(defOp)) { - exprOutput.push_back( - builder.getAffineSymbolExpr(symOperands.size())); - symOperands.push_back(val); - continue; - } - - //TODO: Maybe it's a case to add, but are we sure we need it for starts and offsets - // and not for operands - if (auto apply = dyn_cast(valOp)) { - auto map = apply.getAffineMap(); - auto *scope = affine::getAffineScope(valOp)->getParentOp(); - DominanceInfo DI(scope); - auto map_operands = apply.getOperands(); - //fully2ComposeAffineMapAndOperands(builder, &map, &map_operands, DI); - // Instead of using loop step we are using 1 (Assumption as the stride size) - auto newexpr = map.shiftDims(dimOperands.size()) - .shiftSymbols(symOperands.size()); - - for (auto expr : newexpr.getResults()) { - exprOutput.push_back(expr); - } - - for (size_t i = 0; i < map.getNumDims(); i++) - dimOperands.push_back(apply.getOperands()[i]); - - for (size_t i = 0; i < map.getNumSymbols(); i++) - symOperands.push_back(apply.getOperands()[i + map.getNumDims()]); - - continue; - } - - //return failure(); - } - } - - SmallVector inputExprs; - for (auto expr : lgMap.shiftDims(dimOperands.size()) - .shiftSymbols(symOperands.size()).getResults()) { - inputExprs.push_back(expr); - } - for (size_t i = 0; i < lgMap.getNumDims(); i++) - dimOperands.push_back(lgOperands[i]); - - for (size_t i = 0; i < lgMap.getNumSymbols(); i++) - symOperands.push_back(lgOperands[i + lgMap.getNumDims()]); - - - SmallVector mergedExprs; - for (auto && [start, stride, idx] : - llvm::zip(startExprs, strideExprs, inputExprs)) { - mergedExprs.push_back(start + idx * stride); - } - - lgMap = - AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs, loop->getContext()); - lgOperands.clear(); - lgOperands.insert(lgOperands.begin(), dimOperands.begin(), dimOperands.end()); - lgOperands.insert(lgOperands.begin()+lgOperands.size(), symOperands.begin(), symOperands.end()); - input = SV.getSource(); - break; - } + //if (auto SV = dyn_cast(defOp)) { + + // // TODO update map with the new indexing from here + + // // Create affine map + // // i. Track number of running dims and symbols + // // ii. shift dims and symbols to generate shifted expressions. + // // Extract corresponding operands + // // Use affineMap::get with numOperands and numSymbols along with shifted + // // expressions to get a map. Use affine map simplify to simplify this + + // SmallVector startExprs; + // SmallVector strideExprs; + // SmallVector dimOperands; + // SmallVector symOperands; + // for (auto &&[first, second] : llvm::zip(SV.getOffsets(), SV.getStrides())) { + // for (auto &&[index, val] : llvm::enumerate(SmallVector({first, second}))) { + // auto &exprOutput = (index == 0) ? 
startExprs : strideExprs; + // // Only support constants, symbols, or affine apply as offsets + // if (auto cop = val.getDefiningOp()) { + // exprOutput.push_back(builder.getAffineConstantExpr(cop.value())); + // continue; + // } else if (auto cop = val.getDefiningOp()) { + // exprOutput.push_back(builder.getAffineConstantExpr(cop.value())); + // continue; + // } + // if (auto ba = dyn_cast(val)) { + // Block *parentBlock = ba.getOwner(); + // if (isa(parentBlock->getParentOp())) { + // exprOutput.push_back( + // builder.getAffineDimExpr(dimOperands.size())); + // dimOperands.push_back(ba); + // continue; + + // } + // } + + // auto valOp = val.getDefiningOp(); + // // Defined outside loop, consider it a symbol [for now] + // //if (!valOp || loop->isAncestor(defOp)) { + // if (valOp&&!loop->isAncestor(defOp)) { + // exprOutput.push_back( + // builder.getAffineSymbolExpr(symOperands.size())); + // symOperands.push_back(val); + // continue; + // } + + // //TODO: Maybe it's a case to add, but are we sure we need it for starts and offsets + // // and not for operands + // if (auto apply = dyn_cast(valOp)) { + // auto map = apply.getAffineMap(); + // auto *scope = affine::getAffineScope(valOp)->getParentOp(); + // DominanceInfo DI(scope); + // auto map_operands = apply.getOperands(); + // //fully2ComposeAffineMapAndOperands(builder, &map, &map_operands, DI); + //// Instead of using loop step we are using 1 (Assumption as the stride size) + // auto newexpr = map.shiftDims(dimOperands.size()) + // .shiftSymbols(symOperands.size()); + + // for (auto expr : newexpr.getResults()) { + // exprOutput.push_back(expr); + // } + + // for (size_t i = 0; i < map.getNumDims(); i++) + // dimOperands.push_back(apply.getOperands()[i]); + + // for (size_t i = 0; i < map.getNumSymbols(); i++) + // symOperands.push_back(apply.getOperands()[i + map.getNumDims()]); + + // continue; + // } + + // //return failure(); + // } + // } + + // SmallVector inputExprs; + // for (auto expr : lgMap.shiftDims(dimOperands.size()) + // .shiftSymbols(symOperands.size()).getResults()) { + // inputExprs.push_back(expr); + // } + // for (size_t i = 0; i < lgMap.getNumDims(); i++) + // dimOperands.push_back(lgOperands[i]); + + // for (size_t i = 0; i < lgMap.getNumSymbols(); i++) + // symOperands.push_back(lgOperands[i + lgMap.getNumDims()]); + + + // SmallVector mergedExprs; + // for (auto && [start, stride, idx] : + // llvm::zip(startExprs, strideExprs, inputExprs)) { + // mergedExprs.push_back(start + idx * stride); + // } + + // lgMap = + // AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs, loop->getContext()); + // lgOperands.clear(); + // lgOperands.insert(lgOperands.begin(), dimOperands.begin(), dimOperands.end()); + // lgOperands.insert(lgOperands.begin()+lgOperands.size(), symOperands.begin(), symOperands.end()); + // input = SV.getSource(); + // break; + //} //return failure(); } @@ -691,6 +695,7 @@ struct AffineForOpRaising : public OpRewritePattern { return failure(); auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); + affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); } // TODO Push all of the inputs to the linalg generics (modifying maps as From ec041a0686942723f5d5019ba4a400231d9bde1b Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 31 Jul 2024 06:38:32 +0000 Subject: [PATCH 14/77] Now almost working second loop raising to linalg --- include/polygeist/PolygeistOps.td | 4 +- lib/polygeist/Passes/RaiseToLinalg.cpp | 94 +++++++++++++++++++++----- 2 files changed, 80 
insertions(+), 18 deletions(-) diff --git a/include/polygeist/PolygeistOps.td b/include/polygeist/PolygeistOps.td index 0d4b5c01727d..aeac713fdc9b 100644 --- a/include/polygeist/PolygeistOps.td +++ b/include/polygeist/PolygeistOps.td @@ -270,8 +270,8 @@ def SubmapOp : Polygeist_Op<"submap", [Pure, ViewLikeOpInterface]> { let hasCanonicalizer = 1; let extraClassDeclaration = [{ - ::mlir::ValueRange getSymbols() { return getOperands().slice(0, getMap().getNumSymbols()); } - ::mlir::ValueRange getSizes() { return getOperands().slice(getMap().getNumSymbols(), getMap().getNumSymbols() + getType().getShape().size()); } + ::mlir::ValueRange getSymbols() { return getOperands().slice(1, getMap().getNumSymbols()+1); } + ::mlir::ValueRange getSizes() { return getOperands().slice(getMap().getNumSymbols()+1, getMap().getNumSymbols() + getType().getShape().size()+1); } ::mlir::Value getViewSource() { return getMemref(); } }]; } diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 0be52d285f29..d5b8d8d1a23b 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -120,11 +120,16 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands) { - + assert(oldmap_operands.size() == oldmap.getNumSymbols() + oldmap.getNumDims()); //Operands which don't correspond to indices SmallVector operands_without_indices; ssize_t dimidx = -1; for (auto [i, v] : llvm::enumerate(oldmap_operands)) { + if (v == nullptr) { + assert(i < firstNDims); + continue; + } + assert(i >= firstNDims); if (v != index) operands_without_indices.push_back(v); else @@ -148,8 +153,30 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, SmallVector symReplacements; for (int i=0; i sizes(idx_sizes.size(), 1); + SmallVector sizes(idx_sizes.size(), mlir::ShapedType::kDynamic); for (auto sz : idx_sizes) operands_without_indices.push_back(sz); + // memref auto ty = MemRefType::get(sizes, cast(memref_val.getType()).getElementType()); return builder.create(memref_val.getLoc(), ty, memref_val, operands_without_indices, map2); } @@ -267,9 +295,10 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, while (Operation *defOp = input.getDefiningOp()) { + assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); // If the input is defined outside of the loop, we are finished. 
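+    // `break` rather than `continue`: `input` is not updated on this path, so continuing
+    // would re-test the same defining op forever.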
if (!loop->isAncestor(defOp)) - continue; + break; if (auto SM = dyn_cast(defOp)) { auto submap = SM.getMap(); @@ -293,6 +322,7 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, lgMap = composeMap; lgOperands = operands0; input = SM.getMemref(); + assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); continue; } @@ -401,6 +431,7 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, //return failure(); } + assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); return success(); } @@ -441,6 +472,7 @@ struct AffineForOpRaising : public OpRewritePattern { while (cur != loop) { auto ifstmt = dyn_cast(cur); if (!ifstmt) { + llvm::errs() << "internal cur which prevents hoising: " << *cur << "\n"; return WalkResult::interrupt(); } bool ifTrue = @@ -462,6 +494,7 @@ struct AffineForOpRaising : public OpRewritePattern { if (isReadNone(op)) { return WalkResult::advance(); } + llvm::errs() << "internal op which prevents hoising: " << *op << "\n"; return WalkResult::interrupt(); }); @@ -590,9 +623,10 @@ struct AffineForOpRaising : public OpRewritePattern { const AffineMap lgMap0 = cast(indexingMapsAttr[idx]).getAffineMap(); AffineMap lgMap = lgMap0; SmallVector lgOperands; - lgOperands.push_back(input); - // for (auto i = 0; i < lgMap.getNumDims(); i++) - // lgOperands.push_back(lgMap.getOperands()[i]); + for (int i=0; i { // lgOperands contains current input (i.e probably a subview) // Gives output ... + + assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); auto result = getLinalgArgMap(loop, lgMemref, lgMap, lgOperands); if (!result.succeeded()) @@ -623,7 +659,7 @@ struct AffineForOpRaising : public OpRewritePattern { size_t firstNDims = lgMap.getResults().size(); auto newMemref = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, - firstNDims, lgOperands); + firstNDims, ValueRange(lgOperands)); if (!legal) @@ -645,12 +681,13 @@ struct AffineForOpRaising : public OpRewritePattern { const AffineMap lgMap0 = cast(indexingMapsAttr[idx]).getAffineMap(); AffineMap lgMap = lgMap0; + SmallVector lgOperands; - lgOperands.push_back(output); - // for (auto i = 0; i < lgMap.getNumDims(); i++) - // lgOperands.push_back(lgMap.getSubMap(i)); + for (int i=0; i { size_t firstNDims = lgMap.getResults().size(); auto newMemref = remap_in_affine_dim( - legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, lgOperands); + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands)); if (!legal) return failure(); @@ -729,7 +766,7 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO presently if linalg generic exists, assert there are no load/stores if ((linalgGenerics.size() > 0) && - ((loads.size() == 0) && (stores.size() == 0))) + ((loads.size() != 0) || (stores.size() != 0))) return failure(); // TODO assert only zero or one linalg generic exists @@ -739,6 +776,13 @@ struct AffineForOpRaising : public OpRewritePattern { SmallVector iteratorTypes; // TODO if linalg generic exists, make this iterator type prepend to the // existing iterators + + if (linalgGenerics.size() == 1) { + for (auto attr : linalgGenerics[0].second.getIteratorTypesArray()) + iteratorTypes.push_back(utils::IteratorType::parallel); + } + + // TODO determine if linalg generic, whether to create parallel or reduction by looking at memory patterns of maps iteratorTypes.push_back((stores_map.size() == 0) ? 
utils::IteratorType::parallel : utils::IteratorType::reduction); @@ -772,7 +816,6 @@ struct AffineForOpRaising : public OpRewritePattern { auto arg = blk->addArgument(load.getType(), load.getLoc()); rewriter.replaceOp(load, arg); } - for (auto &&[conds, store] : stores) { auto arg = blk->addArgument(store.getValueToStore().getType(), store.getLoc()); @@ -793,6 +836,25 @@ struct AffineForOpRaising : public OpRewritePattern { SmallVector toreturn; + for (auto genPair : linalgGenerics) { + auto genOp = genPair.second; + auto &genBlock = genOp->getRegion(0).front(); + auto term = genBlock.getTerminator(); + mlir::IRMapping map; + for (auto arg : genBlock.getArguments()) { + auto arg2 = + blk->addArgument(arg.getType(), arg.getLoc()); + map.map(arg, arg2); + } + for (auto &op : genBlock.without_terminator()) { + rewriter.clone(op, map); + } + for (auto op : term->getOperands()) { + toreturn.push_back(map.lookup(op)); + } + rewriter.eraseOp(genOp); + } + for (auto &&[conds, store] : stores) { toreturn.push_back(store.getValueToStore()); rewriter.eraseOp(store); From 23138fc4c08647df19b7b5046aad45038219ddb3 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 31 Jul 2024 18:39:41 +0000 Subject: [PATCH 15/77] Fixes to correctly raise 2 level for loops to linalg.generic --- lib/polygeist/Passes/RaiseToLinalg.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index d5b8d8d1a23b..4b1b96518617 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -472,7 +472,6 @@ struct AffineForOpRaising : public OpRewritePattern { while (cur != loop) { auto ifstmt = dyn_cast(cur); if (!ifstmt) { - llvm::errs() << "internal cur which prevents hoising: " << *cur << "\n"; return WalkResult::interrupt(); } bool ifTrue = @@ -494,7 +493,6 @@ struct AffineForOpRaising : public OpRewritePattern { if (isReadNone(op)) { return WalkResult::advance(); } - llvm::errs() << "internal op which prevents hoising: " << *op << "\n"; return WalkResult::interrupt(); }); @@ -539,7 +537,7 @@ struct AffineForOpRaising : public OpRewritePattern { // We're going to need to upgrade the defn of mayAlias for subviews (aka // mayAlias(subview, x) -> mayAlias(operand(subview), x)) - SmallVector inputs; + SmallVector inputs, outputs; SmallVector affineMaps; SmallVector indexingMaps; @@ -705,7 +703,7 @@ struct AffineForOpRaising : public OpRewritePattern { auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); // TODO: need to merge previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); - inputs.push_back(newMemref); + outputs.push_back(newMemref); } } @@ -738,7 +736,7 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO Push all of the inputs to the linalg generics (modifying maps as // needed) - SmallVector outputs; + //SmallVector outputs; // Store we may need to reindex into a splat potentially later, but for now // we'll be lazy for (auto &&[conds, store] : stores) { From 5f20bd7877cd18c09037e7055cb352cceef3327e Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 31 Jul 2024 18:41:22 +0000 Subject: [PATCH 16/77] Missed file update to enable linalg dialect in polygeist --- tools/polygeist-opt/polygeist-opt.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 95fe1b1fc4a4..64a7e7a35293 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ 
b/tools/polygeist-opt/polygeist-opt.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -59,6 +60,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); From b0e96aadf6b47c87eb8a7826d4135863167f6ea3 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 6 Aug 2024 23:59:56 +0000 Subject: [PATCH 17/77] Fix for syms and dims calculation --- lib/polygeist/Passes/RaiseToLinalg.cpp | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 4b1b96518617..dfe282fb05b4 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -137,30 +137,34 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, } SmallVector dimReplacements; - size_t validx = 0; + size_t validSims = 0; + size_t validDims = 0; for (int i=0; i symReplacements; for (int i=0; i idx_sizes; From ea76f0a4cb18a509b862f1ada3de621b479855fa Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 7 Aug 2024 02:07:29 +0000 Subject: [PATCH 18/77] More tests added to cover different loop cases --- test/polygeist-opt/linalgraise.mlir | 311 ++++++++++++++++++++++++---- 1 file changed, 275 insertions(+), 36 deletions(-) diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index 27b0a843dddb..c28d8662732f 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -1,6 +1,20 @@ -//// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s -// +////// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s +//// //module { +// func.func @main0(%12 : i1, %18 : memref<32xf32> ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %19 = memref.alloca() : memref<32xf32> +// scf.if %12 { +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref<32xf32> +// affine.store %ld, %19[%arg4] : memref<32xf32> +// } +// } +// return +// } +// // func.func @main(%12 : i1, %14 : i32, %18 : memref ) { // %c0 = arith.constant 0 : index // %c4 = arith.constant 4 : index @@ -10,7 +24,7 @@ // %17 = arith.divui %16, %c4 : index // %19 = memref.alloca(%17) : memref // scf.if %12 { -// affine.for %arg4 = 0 to %17 { +// affine.for %arg4 = 0 to 17 { // %ld = affine.load %18[%arg4] : memref // affine.store %ld, %19[%arg4] : memref // } @@ -177,7 +191,7 @@ // } //} // -////reduction +////TODO: reduction //module @reduction{ // func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { // %c0 = arith.constant 0 : index @@ -198,7 +212,7 @@ // } //} // -////Conditional store-1 +////TODO: Conditional store-1 //module @cond_store_1 { // func.func @main(%12 : i1, %14 : i32, %18 : memref ) { // %c0 = arith.constant 0 : index @@ -219,7 +233,7 @@ // } //} // -////Conditional store-2 +////TODO: Conditional store-2 //module @cond_store_2{ // func.func @main(%12 : i1, %14 : i32, %18 : memref ) { // %c0 = arith.constant 0 : index @@ -267,8 +281,8 @@ // return // } //} -// -////Fors inside for + +//////Fors inside for 
//module @for_within_for{ // func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { // %c0 = arith.constant 0 : index @@ -291,6 +305,231 @@ // } //} // +////Fors inside for +//module @for_within_for_2{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3+2*%arg4] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////Fors inside for +//module @for_within_for_3{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3+2*%arg4] : memref +// %ld2 = affine.load %18[%arg3] : memref +// %ld3 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// %mul2 = arith.mulf %mul, %ld3 : f32 +// affine.store %mul2, %19[%arg4] : memref +// } +// } +// return +// } +//} + +////Fors inside for +//module @for_within_for_3{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg4+2*%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +//Fors inside for +module @for_3_levels_0{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + } + return + } +} + +////Fors inside for +//module @for_3_levels_1{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg5 = 0 to 21 { +// affine.for %arg3 = 0 to 
21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// } +// return +// } +//} + +////Fors inside for +//module @for_3_levels_1{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// affine.for %arg5 = 0 to 21 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %ld3 = affine.load %23[%arg5] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// %mul2 = arith.mulf %mul, %ld3 : f32 +// affine.store %mul2, %19[%arg4] : memref +// } +// } +// } +// return +// } +//} + +////Fors inside for +//module @for_3_levels_2{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// affine.for %arg5 = 0 to 21 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %ld3 = affine.load %20[%arg5] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// %mul2 = arith.mulf %mul, %ld3 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// } +// return +// } +//} + +//Fors inside for +//module @for_3_levels_2{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// affine.for %arg5 = 0 to 21 { +// %ld1 = affine.load %18[%arg3+%arg4] : memref +// %ld2 = affine.load %20[%arg4+%arg5] : memref +// %ld3 = affine.load %20[%arg5+%arg3] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// %mul2 = arith.mulf %mul, %ld3 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// } +// return +// } +//} + +//#map = affine_map<(d0)[s0] -> (s0)> +//#map1 = affine_map<(d0) -> (d0)> +//module @for_within_for2 { +// func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { +// %c17 = arith.constant 17 : index +// %c4 = arith.constant 4 : index +// %0 = arith.index_cast %arg1 : i32 to index +// %1 = arith.muli %0, %c4 : index +// %2 = arith.divui %1, %c4 : index +// %alloca = memref.alloca(%2) : memref +// affine.for %arg4 = 0 to 21 { +// %3 = "polygeist.submap"(%arg2, %arg4, %c17) <{map = #map}> : (memref, index, index) -> memref +// %4 = "polygeist.submap"(%arg3, %c17) <{map = #map1}> : (memref, index) -> memref +// %5 = "polygeist.submap"(%alloca, %c17) <{map = #map1}> : (memref, index) -> memref +// linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%3, %4 : memref, 
memref) outs(%5 : memref) { +// ^bb0(%in: f32, %in_0: f32, %out: f32): +// %6 = arith.mulf %in, %in_0 : f32 +// linalg.yield %6 : f32 +// } +// } +// return +// } +//} + ////Parallel fors inside for //module @parallel_fors_inside_for { // func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { @@ -401,32 +640,32 @@ // return %c0_i32 : i32 // } //} - -//conv (direct store) -module @conv_2 { - memref.global @out : memref<512x64xi32> = uninitialized - memref.global @filter : memref<4x4xi32> = uninitialized - memref.global @im : memref<515x67xi32> = uninitialized - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im : memref<515x67xi32> - %1 = memref.get_global @filter : memref<4x4xi32> - %2 = memref.get_global @out : memref<512x64xi32> - affine.for %arg0 = 0 to 512 { - affine.for %arg1 = 0 to 64 { - affine.for %arg2 = 0 to 4 { - affine.for %arg3 = 0 to 4 { - %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> - %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> - %5 = arith.muli %3, %4 : i32 - %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> - %7 = arith.addi %6, %5 : i32 - affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> - } - } - } - } - return %c0_i32 : i32 - } -} +// +////conv (direct store) +//module @conv_2 { +// memref.global @out : memref<512x64xi32> = uninitialized +// memref.global @filter : memref<4x4xi32> = uninitialized +// memref.global @im : memref<515x67xi32> = uninitialized +// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im : memref<515x67xi32> +// %1 = memref.get_global @filter : memref<4x4xi32> +// %2 = memref.get_global @out : memref<512x64xi32> +// affine.for %arg0 = 0 to 512 { +// affine.for %arg1 = 0 to 64 { +// affine.for %arg2 = 0 to 4 { +// affine.for %arg3 = 0 to 4 { +// %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> +// %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> +// %5 = arith.muli %3, %4 : i32 +// %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> +// %7 = arith.addi %6, %5 : i32 +// affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> +// } +// } +// } +// } +// return %c0_i32 : i32 +// } +//} \ No newline at end of file From 591c84ea5559854c8e69f2df9a258adbf92be059 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 7 Aug 2024 06:36:42 +0000 Subject: [PATCH 19/77] Now able to compile 3/any number of loops with parallel iter type; Added extra tests in lit test --- lib/polygeist/Passes/RaiseToLinalg.cpp | 10 +- test/polygeist-opt/linalgraise.mlir | 1098 ++++++++++++------------ 2 files changed, 575 insertions(+), 533 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index dfe282fb05b4..b6668e57ee70 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -187,13 +187,12 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, SmallVector idx_sizes; for (size_t i=0; i(memref_val.getLoc(), memref_val, i)); } idx_sizes.push_back(bound); legal = true; - // TODO: Cannot be negative size, are we trying to initialize it with any size, or do we want to calcualte size from - // loop bounds? 
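+  // The submap result shape is left fully dynamic; the memref.dim results plus the loop
+  // bound collected above are passed as the trailing size operands instead.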
SmallVector sizes(idx_sizes.size(), mlir::ShapedType::kDynamic); for (auto sz : idx_sizes) operands_without_indices.push_back(sz); @@ -658,7 +657,10 @@ struct AffineForOpRaising : public OpRewritePattern { // For this case: LG within ForOp - // Inputs should be : load map extracted from subviewOp // Returns LG with indexingMap and subview with affine.apply - which are correct - size_t firstNDims = lgMap.getResults().size(); + + //TODO: Or is it num dims? + //size_t firstNDims = lgMap.getResults().size(); + size_t firstNDims = lgMap.getNumDims(); auto newMemref = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands)); @@ -697,7 +699,7 @@ struct AffineForOpRaising : public OpRewritePattern { bool legal = true; - size_t firstNDims = lgMap.getResults().size(); + size_t firstNDims = lgMap.getNumDims(); auto newMemref = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands)); diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index c28d8662732f..627021c094a5 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -1,380 +1,419 @@ -////// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s -//// -//module { -// func.func @main0(%12 : i1, %18 : memref<32xf32> ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %19 = memref.alloca() : memref<32xf32> -// scf.if %12 { -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref<32xf32> -// affine.store %ld, %19[%arg4] : memref<32xf32> -// } -// } -// return -// } -// -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// scf.if %12 { -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// affine.store %ld, %19[%arg4] : memref -// } -// } -// return -// } -// -// -// func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// scf.if %12 { -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[3 * %arg4] : memref -// %ld2 = affine.load %18[0] : memref -// %fadd = arith.addf %ld, %ld2 : f32 -// affine.store %fadd, %19[%arg4 + 17] : memref -// } -// } -// return -// } -// -//} -// -//// CHECK: #map = affine_map<(d0) -> (d0)> -//// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { -//// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index -//// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index -//// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index -//// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index -//// CHECK-NEXT: scf.if %[[arg0]] { -//// TODO note that presently we do not ensure that the memrefs are sliced to the right size as the space requires -//// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg2 : memref) outs(%alloca : memref) { -//// 
CHECK-NEXT: ^bb0(%in: f32, %out: f32): -//// CHECK-NEXT: linalg.yield %in : f32 -//// CHECK-NEXT: } -//// CHECK-NEXT: } -//// CHECK-NEXT: } -// -////constant-access -//module @constant_access{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %ci324 = arith.constant 4.0 : f32 -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %mul = arith.mulf %ld, %ci324 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -////constant-mem-access -//module @constant_mem_access{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 4 to 17 step 2 { -// %ld = affine.load %18[3*%arg4] : memref -// %ld2 = affine.load %18[%c4] : memref -// %mul = arith.mulf %ld, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -////without-if -//module @no_if{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// affine.store %ld, %19[%arg4] : memref -// } -// return -// } -//} -// -////arith.mul -//module @arith_mul{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %mul = arith.mulf %ld, %ld : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -////arith.add -//module @arith_add{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg4] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %add = arith.addf %ld1, %ld2 : f32 -// %mul = arith.mulf %add, %add : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -////Conditional arith -//module @cond_arith{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %if = scf.if %12 -> f32 { -// %mul = 
arith.mulf %ld, %ld : f32 -// scf.yield %mul : f32 -// } else { -// scf.yield %ld : f32 -// } -// affine.store %if, %19[%arg4] : memref -// } -// return -// } -//} +//// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s // -////TODO: reduction -//module @reduction{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// %sum_0 = arith.constant 0.0 : f32 -// %red = affine.for %arg4 = 0 to 17 step 1 iter_args(%sum_iter = %sum_0) -> f32 { -// %ld1 = affine.load %18[%arg4] : memref -// %sum_next = arith.addf %sum_iter, %ld1 : f32 -// affine.yield %sum_next : f32 -// } -// affine.store %red, %19[0] : memref -// return -// } -//} -// -////TODO: Conditional store-1 -//module @cond_store_1 { -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %mul = arith.mulf %ld, %ld : f32 -// scf.if %12 { -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////TODO: Conditional store-2 -//module @cond_store_2{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// scf.if %12 { -// %mul = arith.mulf %ld, %ld : f32 -// affine.store %mul, %19[%arg4] : memref -// } else { -// affine.store %ld, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////Parallel for -//module @parallel_for{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %mul = arith.mulf %ld, %ld : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg4] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %add = arith.addf %ld1, %ld2 : f32 -// %mul = arith.mulf %add, %add : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} +module { + func.func @main0(%12 : i1, %18 : memref<32xf32> ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %19 = memref.alloca() : memref<32xf32> + scf.if %12 { + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref<32xf32> + affine.store %ld, %19[%arg4] : memref<32xf32> + } + } + return + } + + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : 
i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + scf.if %12 { + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + affine.store %ld, %19[%arg4] : memref + } + } + return + } -//////Fors inside for -//module @for_within_for{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_within_for_2{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3+2*%arg4] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_within_for_3{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3+2*%arg4] : memref -// %ld2 = affine.load %18[%arg3] : memref -// %ld3 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// %mul2 = arith.mulf %mul, %ld3 : f32 -// affine.store %mul2, %19[%arg4] : memref -// } -// } -// return -// } -//} + + func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + scf.if %12 { + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[3 * %arg4] : memref + %ld2 = affine.load %18[0] : memref + %fadd = arith.addf %ld, %ld2 : f32 + affine.store %fadd, %19[%arg4 + 17] : memref + } + } + return + } + +} + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { +// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index +// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index +// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index +// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index +// CHECK-NEXT: scf.if %[[arg0]] { +// TODO note that presently we do not ensure that the memrefs are sliced to the right size as the space 
requires +// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg2 : memref) outs(%alloca : memref) { +// CHECK-NEXT: ^bb0(%in: f32, %out: f32): +// CHECK-NEXT: linalg.yield %in : f32 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +//constant-access +module @constant_access{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %ci324 = arith.constant 4.0 : f32 + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %mul = arith.mulf %ld, %ci324 : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} + +//constant-mem-access +module @constant_mem_access{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 4 to 17 step 2 { + %ld = affine.load %18[3*%arg4] : memref + %ld2 = affine.load %18[%c4] : memref + %mul = arith.mulf %ld, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} + +//without-if +module @no_if{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + affine.store %ld, %19[%arg4] : memref + } + return + } +} + +//arith.mul +module @arith_mul{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %mul = arith.mulf %ld, %ld : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} + +//arith.add +module @arith_add{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg4] : memref + %ld2 = affine.load %20[%arg4] : memref + %add = arith.addf %ld1, %ld2 : f32 + %mul = arith.mulf %add, %add : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} + +//Conditional arith +module @cond_arith{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %if = scf.if %12 -> f32 { + %mul = arith.mulf %ld, %ld : f32 + scf.yield %mul : f32 + } else { + scf.yield %ld : f32 + } 
+ affine.store %if, %19[%arg4] : memref + } + return + } +} + +//TODO: reduction +module @reduction{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + %sum_0 = arith.constant 0.0 : f32 + %red = affine.for %arg4 = 0 to 17 step 1 iter_args(%sum_iter = %sum_0) -> f32 { + %ld1 = affine.load %18[%arg4] : memref + %sum_next = arith.addf %sum_iter, %ld1 : f32 + affine.yield %sum_next : f32 + } + affine.store %red, %19[0] : memref + return + } +} + +//TODO: Conditional store-1 +module @cond_store_1 { + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %mul = arith.mulf %ld, %ld : f32 + scf.if %12 { + affine.store %mul, %19[%arg4] : memref + } + } + return + } +} + +//TODO: Conditional store-2 +module @cond_store_2{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + scf.if %12 { + %mul = arith.mulf %ld, %ld : f32 + affine.store %mul, %19[%arg4] : memref + } else { + affine.store %ld, %19[%arg4] : memref + } + } + return + } +} + +//Parallel for +module @parallel_for{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %mul = arith.mulf %ld, %ld : f32 + affine.store %mul, %19[%arg4] : memref + } + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg4] : memref + %ld2 = affine.load %20[%arg4] : memref + %add = arith.addf %ld1, %ld2 : f32 + %mul = arith.mulf %add, %add : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} ////Fors inside for -//module @for_within_for_3{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg4+2*%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} +module @for_within_for{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 
to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + return + } +} + +//Fors inside for +module @for_within_for_2{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3+2*%arg4] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + return + } +} + +//Fors inside for +module @for_within_for_3{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3+2*%arg4] : memref + %ld2 = affine.load %18[%arg3] : memref + %ld3 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + return + } +} + +//Fors inside for +module @for_within_for_4{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg4+2*%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + return + } +} + +//Fors no-loop dependency +module @for_no_loop_dependency{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 15 { + %ld1 = affine.load %18[0] : memref + affine.store %ld1, %19[0] : memref + } + return + } +} +//Fors no-loop dependency +module @for_2_levels_no_loop_dependency{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + affine.for %arg3 = 0 to 15 { + %ld1 = affine.load %18[%arg4] : 
memref + affine.store %ld1, %19[%arg4] : memref + } + } + return + } +} //Fors inside for module @for_3_levels_0{ func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { @@ -386,13 +425,13 @@ module @for_3_levels_0{ %17 = arith.divui %16, %c4 : index %21 = arith.muli %16, %c4 : index %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { + affine.for %arg3 = 0 to 15 { affine.for %arg4 = 0 to 17 { affine.for %arg5 = 0 to 21 { %ld1 = affine.load %18[%arg3] : memref %ld2 = affine.load %20[%arg4] : memref %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref + affine.store %mul, %19[%arg5] : memref } } } @@ -400,166 +439,167 @@ module @for_3_levels_0{ } } -////Fors inside for -//module @for_3_levels_1{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg5 = 0 to 21 { -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// } -// return -// } -//} +//Fors inside for +module @for_3_levels_1{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg5 = 0 to 21 { + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + } + return + } +} -////Fors inside for -//module @for_3_levels_1{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// affine.for %arg5 = 0 to 21 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %ld3 = affine.load %23[%arg5] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// %mul2 = arith.mulf %mul, %ld3 : f32 -// affine.store %mul2, %19[%arg4] : memref -// } -// } -// } -// return -// } -//} +//Fors inside for +module @for_3_levels_2{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load 
%20[%arg4] : memref + %ld3 = affine.load %23[%arg5] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + } + return + } +} -////Fors inside for -//module @for_3_levels_2{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// affine.for %arg5 = 0 to 21 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %ld3 = affine.load %20[%arg5] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// %mul2 = arith.mulf %mul, %ld3 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// } -// return -// } -//} +//Fors inside for +module @for_3_levels_3{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %ld3 = affine.load %20[%arg5] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + } + return + } +} //Fors inside for -//module @for_3_levels_2{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// affine.for %arg5 = 0 to 21 { -// %ld1 = affine.load %18[%arg3+%arg4] : memref -// %ld2 = affine.load %20[%arg4+%arg5] : memref -// %ld3 = affine.load %20[%arg5+%arg3] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// %mul2 = arith.mulf %mul, %ld3 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// } -// return -// } -//} +module @for_3_levels_4{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3+%arg4] : memref + %ld2 = affine.load %20[%arg4+%arg5] : memref + %ld3 = affine.load %20[%arg5+%arg3] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + } + return + } +} -//#map = affine_map<(d0)[s0] -> (s0)> -//#map1 = affine_map<(d0) -> (d0)> -//module @for_within_for2 { -// func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { -// 
%c17 = arith.constant 17 : index -// %c4 = arith.constant 4 : index -// %0 = arith.index_cast %arg1 : i32 to index -// %1 = arith.muli %0, %c4 : index -// %2 = arith.divui %1, %c4 : index -// %alloca = memref.alloca(%2) : memref -// affine.for %arg4 = 0 to 21 { -// %3 = "polygeist.submap"(%arg2, %arg4, %c17) <{map = #map}> : (memref, index, index) -> memref -// %4 = "polygeist.submap"(%arg3, %c17) <{map = #map1}> : (memref, index) -> memref -// %5 = "polygeist.submap"(%alloca, %c17) <{map = #map1}> : (memref, index) -> memref -// linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { -// ^bb0(%in: f32, %in_0: f32, %out: f32): -// %6 = arith.mulf %in, %in_0 : f32 -// linalg.yield %6 : f32 -// } -// } -// return -// } -//} +//Intermediate raising +#map = affine_map<(d0)[s0] -> (s0)> +#map1 = affine_map<(d0) -> (d0)> +module @for_within_for2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.for %arg4 = 0 to 21 { + %3 = "polygeist.submap"(%arg2, %arg4, %c17) <{map = #map}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17) <{map = #map1}> : (memref, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17) <{map = #map1}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + } + return + } +} -////Parallel fors inside for -//module @parallel_fors_inside_for { -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 17 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %add = arith.addf %ld1, %ld2 : f32 -// %mul = arith.mulf %add, %add : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////matrix-mul iter arg +//Parallel fors inside for +module @parallel_fors_inside_for { + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 17 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %add = arith.addf %ld1, %ld2 : f32 + %mul = arith.mulf %add, %add : f32 + affine.store %mul, 
%19[%arg4] : memref + } + } + return + } +} + +//matrix-mul iter arg //module @matmul_1 { // memref.global @out : memref<32x8xi32> = uninitialized // memref.global @im2 : memref<8x8xi32> = uninitialized From b0108e37dba0d223863962ca11192599d695dc42 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 7 Aug 2024 07:12:45 +0000 Subject: [PATCH 20/77] Non iter-arg variant of matrix-mul and conv are now raised to linalg.generic --- test/polygeist-opt/linalgraise.mlir | 118 ++++++++++++++-------------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index 627021c094a5..069891879f30 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -532,12 +532,12 @@ module @for_3_levels_4{ affine.for %arg3 = 0 to 21 { affine.for %arg4 = 0 to 17 { affine.for %arg5 = 0 to 21 { - %ld1 = affine.load %18[%arg3+%arg4] : memref - %ld2 = affine.load %20[%arg4+%arg5] : memref - %ld3 = affine.load %20[%arg5+%arg3] : memref + %ld1 = affine.load %18[%arg3+4*%arg4+3] : memref + %ld2 = affine.load %20[7*%arg4+%arg5+2] : memref + %ld3 = affine.load %20[%arg5+2*%arg3] : memref %mul = arith.mulf %ld1, %ld2 : f32 %mul2 = arith.mulf %mul, %ld3 : f32 - affine.store %mul, %19[%arg4] : memref + affine.store %mul2, %19[%arg4] : memref } } } @@ -599,7 +599,7 @@ module @parallel_fors_inside_for { } } -//matrix-mul iter arg +////matrix-mul iter arg //module @matmul_1 { // memref.global @out : memref<32x8xi32> = uninitialized // memref.global @im2 : memref<8x8xi32> = uninitialized @@ -624,33 +624,33 @@ module @parallel_fors_inside_for { // return %c0_i32 : i32 // } //} -// -////matrix-mul alias issue -//module @matmul_2 { -// memref.global @out : memref<128x32xi32> = uninitialized -// memref.global @im2 : memref<64x32xi32> = uninitialized -// memref.global @im1 : memref<128x64xi32> = uninitialized -// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c0_i32 = arith.constant 0 : i32 -// %0 = memref.get_global @im1 : memref<128x64xi32> -// %1 = memref.get_global @im2 : memref<64x32xi32> -// %2 = memref.get_global @out : memref<128x32xi32> -// affine.for %arg0 = 0 to 128 { -// affine.for %arg1 = 0 to 32 { -// affine.for %arg2 = 0 to 64 { -// %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> -// %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> -// %5 = arith.muli %3, %4 : i32 -// %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> -// %7 = arith.addi %6, %5 : i32 -// affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> -// } -// } -// } -// return %c0_i32 : i32 -// } -//} -// + +//matrix-mul extra load-store variant +module @matmul_2 { + memref.global @out : memref<128x32xi32> = uninitialized + memref.global @im2 : memref<64x32xi32> = uninitialized + memref.global @im1 : memref<128x64xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im1 : memref<128x64xi32> + %1 = memref.get_global @im2 : memref<64x32xi32> + %2 = memref.get_global @out : memref<128x32xi32> + affine.for %arg0 = 0 to 128 { + affine.for %arg1 = 0 to 32 { + affine.for %arg2 = 0 to 64 { + %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> + %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> + } + } + } + return %c0_i32 : i32 + 
} +} + ////conv (with inner loop accumulate) ////How to deal with IR in outer loops as well? //module @conv_1{ @@ -681,31 +681,31 @@ module @parallel_fors_inside_for { // } //} // -////conv (direct store) -//module @conv_2 { -// memref.global @out : memref<512x64xi32> = uninitialized -// memref.global @filter : memref<4x4xi32> = uninitialized -// memref.global @im : memref<515x67xi32> = uninitialized -// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c0_i32 = arith.constant 0 : i32 -// %0 = memref.get_global @im : memref<515x67xi32> -// %1 = memref.get_global @filter : memref<4x4xi32> -// %2 = memref.get_global @out : memref<512x64xi32> -// affine.for %arg0 = 0 to 512 { -// affine.for %arg1 = 0 to 64 { -// affine.for %arg2 = 0 to 4 { -// affine.for %arg3 = 0 to 4 { -// %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> -// %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> -// %5 = arith.muli %3, %4 : i32 -// %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> -// %7 = arith.addi %6, %5 : i32 -// affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> -// } -// } -// } -// } -// return %c0_i32 : i32 -// } -//} +//conv (direct store) +module @conv_2 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 64 { + affine.for %arg2 = 0 to 4 { + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + } + } + } + } + return %c0_i32 : i32 + } +} \ No newline at end of file From 4362c80fd643d9d08e4cfb7a70f7c0f596708292 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 21 Aug 2024 00:35:01 +0000 Subject: [PATCH 21/77] submap canonicalizer implemented --- include/polygeist/PolygeistOps.td | 4 +- lib/polygeist/Ops.cpp | 55 +- lib/polygeist/Passes/RaiseToLinalg.cpp | 17 +- test/polygeist-opt/linalgraise.mlir | 1295 ++++++++++---------- test/polygeist-opt/submapcanonicalize.mlir | 41 + tools/polygeist-opt/polygeist-opt.cpp | 1 + 6 files changed, 754 insertions(+), 659 deletions(-) create mode 100644 test/polygeist-opt/submapcanonicalize.mlir diff --git a/include/polygeist/PolygeistOps.td b/include/polygeist/PolygeistOps.td index aeac713fdc9b..ff59deb22bbd 100644 --- a/include/polygeist/PolygeistOps.td +++ b/include/polygeist/PolygeistOps.td @@ -270,8 +270,8 @@ def SubmapOp : Polygeist_Op<"submap", [Pure, ViewLikeOpInterface]> { let hasCanonicalizer = 1; let extraClassDeclaration = [{ - ::mlir::ValueRange getSymbols() { return getOperands().slice(1, getMap().getNumSymbols()+1); } - ::mlir::ValueRange getSizes() { return getOperands().slice(getMap().getNumSymbols()+1, getMap().getNumSymbols() + getType().getShape().size()+1); } + ::mlir::ValueRange getSymbols() { return getOperands().slice(1, getMap().getNumSymbols()); } + ::mlir::ValueRange getSizes() { return getOperands().slice(getMap().getNumSymbols()+1, getType().getShape().size()); } 
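    // The accessors above assume the submap operand layout
    //   [memref, symbols..., sizes...]
    // so getSymbols() is the slice starting at operand 1 with length
    // getMap().getNumSymbols(), and getSizes() starts right after the
    // symbols with one entry per result dimension. A minimal sketch,
    // mirroring the submap used in the linalgraise tests (element type
    // and value names are illustrative assumptions only):
    //   #map = affine_map<(d0)[s0] -> (s0)>
    //   %m = "polygeist.submap"(%src, %sym0, %size0) <{map = #map}>
    //        : (memref<?xf32>, index, index) -> memref<?xf32>
    //   // getViewSource() == %src, getSymbols() == {%sym0}, getSizes() == {%size0}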
::mlir::Value getViewSource() { return getMemref(); } }]; } diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 7a4101f0a936..cb486026112b 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -5881,7 +5881,6 @@ LogicalResult GetFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } -/* class LoadSubMap final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -5891,31 +5890,53 @@ class LoadSubMap final : public OpRewritePattern { auto subMapOp = op.getMemRef().getDefiningOp(); if (!subMapOp) return failure(); - auto submap_map = ref.getAffineMap(); - auto submap_operands = ref.getAffineMapOperands(); - auto source_memref = ref.getMemref(); + auto submap_map = subMapOp.getMap(); + auto submap_operands = subMapOp.getSymbols(); + auto source_memref = subMapOp.getMemref(); - auto load_map = ref.getAffineMap(); - SmallVector operands0 = op.getMapOperands(); + auto load_map = op.getAffineMap(); + auto load_operands = op.getMapOperands(); + + auto new_map = submap_map.compose(load_map); + + SmallVector operands; + operands.append(load_operands.begin(), load_operands.begin() + load_map.getNumDims()); + operands.append(submap_operands.begin(), submap_operands.end()); + operands.append(load_operands.begin() + load_map.getNumDims(), load_operands.end()); + + rewriter.replaceOpWithNewOp(op, source_memref, new_map, operands); + return success(); + } +}; - // %m = polygeist.submap submap_map(%submap_operands) %source_memref : memref -> memref - // %a = affine.load %m[load_map(%load_operands)] - // -> - // %a = affine.load %source_memref[load_map(submap_map(%load_operands, %submap_operands))] - auto new_map = load_map.compose(submap_map); - auto new_operands = llvm::concat(load_operands, submap_operands) +class StoreSubMap final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; - rewriter.replaceOpWithNewOp(op.getLoc(), sourceMemref, ); + LogicalResult matchAndRewrite(affine::AffineStoreOp op, + PatternRewriter &rewriter) const override { + auto subMapOp = op.getMemRef().getDefiningOp(); + if (!subMapOp) return failure(); + auto submap_map = subMapOp.getMap(); + auto submap_operands = subMapOp.getSymbols(); + auto source_memref = subMapOp.getMemref(); + + auto load_map = op.getAffineMap(); + auto load_operands = op.getMapOperands(); + + auto new_map = submap_map.compose(load_map); + SmallVector operands; + operands.append(load_operands.begin(), load_operands.begin() + load_map.getNumDims()); + operands.append(submap_operands.begin(), submap_operands.end()); + operands.append(load_operands.begin() + load_map.getNumDims(), load_operands.end()); - // shift one map over by the size of other # symbols/dims, replace with new affine load with composed map + rewriter.replaceOpWithNewOp(op, op.getValue(), source_memref, new_map, operands); return success(); } }; -*/ -// TODO StoreSubMap OpFoldResult mlir::polygeist::SubmapOp::fold(mlir::polygeist::SubmapOp::FoldAdaptor adaptor) { // TODO if submap is identity return nothing @@ -5925,5 +5946,5 @@ OpFoldResult mlir::polygeist::SubmapOp::fold(mlir::polygeist::SubmapOp::FoldAdap void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - //results.insert(context); + results.insert(context); } diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index b6668e57ee70..61c589c65daf 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ 
b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -306,6 +306,8 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, if (auto SM = dyn_cast(defOp)) { auto submap = SM.getMap(); + //TODO: Do we achieve anything with this compose? + //As lgMap in our case is 1 to 1 identity map auto composeMap = submap.compose(lgMap); SmallVector operands0; @@ -462,6 +464,7 @@ struct AffineForOpRaising : public OpRewritePattern { // Additionally, for each load/store, remember what conditions are // required for that load or store to execute. auto result = loop->walk([&](Operation *op) { + llvm::outs()<< op->getName() << "\n"; if (op == loop) return WalkResult::advance(); // TODO extend this, any non-memory operation is also legal here. @@ -781,16 +784,22 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO if linalg generic exists, make this iterator type prepend to the // existing iterators + bool is_parallel = stores_map.size() == 0; + // TODO determine if linalg generic, whether to create parallel or reduction by looking at memory patterns of maps + if (linalgGenerics.size() == 1) { - for (auto attr : linalgGenerics[0].second.getIteratorTypesArray()) - iteratorTypes.push_back(utils::IteratorType::parallel); + // determine whether now we write to ourselves } - // TODO determine if linalg generic, whether to create parallel or reduction by looking at memory patterns of maps - iteratorTypes.push_back((stores_map.size() == 0) + iteratorTypes.push_back(is_parallel ? utils::IteratorType::parallel : utils::IteratorType::reduction); + if (linalgGenerics.size() == 1) { + for (auto attr : linalgGenerics[0].second.getIteratorTypesArray()) + iteratorTypes.push_back(attr); + } + StringAttr empty = StringAttr::get(loop.getContext()); auto genericOp = rewriter.create( loop.getLoc(), TypeRange(), inputs, outputs, affineMaps, iteratorTypes, diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index 069891879f30..5c77b25a10df 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -1,656 +1,656 @@ -//// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s +////// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s +//// +//module { +// func.func @main0(%12 : i1, %18 : memref<32xf32> ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %19 = memref.alloca() : memref<32xf32> +// scf.if %12 { +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref<32xf32> +// affine.store %ld, %19[%arg4] : memref<32xf32> +// } +// } +// return +// } +// +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// scf.if %12 { +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// affine.store %ld, %19[%arg4] : memref +// } +// } +// return +// } +// +// +// func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// scf.if %12 { +// affine.for %arg4 = 0 to 17 
{ +// %ld = affine.load %18[3 * %arg4] : memref +// %ld2 = affine.load %18[0] : memref +// %fadd = arith.addf %ld, %ld2 : f32 +// affine.store %fadd, %19[%arg4 + 17] : memref +// } +// } +// return +// } +// +//} +// +//// CHECK: #map = affine_map<(d0) -> (d0)> +//// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { +//// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index +//// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index +//// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index +//// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index +//// CHECK-NEXT: scf.if %[[arg0]] { +//// TODO note that presently we do not ensure that the memrefs are sliced to the right size as the space requires +//// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg2 : memref) outs(%alloca : memref) { +//// CHECK-NEXT: ^bb0(%in: f32, %out: f32): +//// CHECK-NEXT: linalg.yield %in : f32 +//// CHECK-NEXT: } +//// CHECK-NEXT: } +//// CHECK-NEXT: } +// +////constant-access +//module @constant_access{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %ci324 = arith.constant 4.0 : f32 +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ci324 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////constant-mem-access +//module @constant_mem_access{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 4 to 17 step 2 { +// %ld = affine.load %18[3*%arg4] : memref +// %ld2 = affine.load %18[%c4] : memref +// %mul = arith.mulf %ld, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////without-if +//module @no_if{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// affine.store %ld, %19[%arg4] : memref +// } +// return +// } +//} +// +////arith.mul +//module @arith_mul{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ld : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////arith.add +//module @arith_add{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = 
arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg4] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +////Conditional arith +//module @cond_arith{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %if = scf.if %12 -> f32 { +// %mul = arith.mulf %ld, %ld : f32 +// scf.yield %mul : f32 +// } else { +// scf.yield %ld : f32 +// } +// affine.store %if, %19[%arg4] : memref +// } +// return +// } +//} +// +////TODO: reduction +//module @reduction{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// %sum_0 = arith.constant 0.0 : f32 +// %red = affine.for %arg4 = 0 to 17 step 1 iter_args(%sum_iter = %sum_0) -> f32 { +// %ld1 = affine.load %18[%arg4] : memref +// %sum_next = arith.addf %sum_iter, %ld1 : f32 +// affine.yield %sum_next : f32 +// } +// affine.store %red, %19[0] : memref +// return +// } +//} +// +////TODO: Conditional store-1 +//module @cond_store_1 { +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ld : f32 +// scf.if %12 { +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////TODO: Conditional store-2 +//module @cond_store_2{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// scf.if %12 { +// %mul = arith.mulf %ld, %ld : f32 +// affine.store %mul, %19[%arg4] : memref +// } else { +// affine.store %ld, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////Parallel for +//module @parallel_for{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ld : f32 +// 
affine.store %mul, %19[%arg4] : memref +// } +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg4] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +//} +// +//////Fors inside for +//module @for_within_for{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} // -module { - func.func @main0(%12 : i1, %18 : memref<32xf32> ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %19 = memref.alloca() : memref<32xf32> - scf.if %12 { - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref<32xf32> - affine.store %ld, %19[%arg4] : memref<32xf32> - } - } - return - } - - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - scf.if %12 { - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - affine.store %ld, %19[%arg4] : memref - } - } - return - } - - - func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - scf.if %12 { - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[3 * %arg4] : memref - %ld2 = affine.load %18[0] : memref - %fadd = arith.addf %ld, %ld2 : f32 - affine.store %fadd, %19[%arg4 + 17] : memref - } - } - return - } - -} - -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { -// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index -// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index -// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index -// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index -// CHECK-NEXT: scf.if %[[arg0]] { -// TODO note that presently we do not ensure that the memrefs are sliced to the right size as the space requires -// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg2 : memref) outs(%alloca : memref) { -// CHECK-NEXT: ^bb0(%in: f32, %out: f32): -// CHECK-NEXT: linalg.yield %in : f32 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } - -//constant-access -module @constant_access{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %ci324 = arith.constant 4.0 : f32 - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = 
arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ci324 : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//constant-mem-access -module @constant_mem_access{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 4 to 17 step 2 { - %ld = affine.load %18[3*%arg4] : memref - %ld2 = affine.load %18[%c4] : memref - %mul = arith.mulf %ld, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//without-if -module @no_if{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - affine.store %ld, %19[%arg4] : memref - } - return - } -} - -//arith.mul -module @arith_mul{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ld : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//arith.add -module @arith_add{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg4] : memref - %ld2 = affine.load %20[%arg4] : memref - %add = arith.addf %ld1, %ld2 : f32 - %mul = arith.mulf %add, %add : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - -//Conditional arith -module @cond_arith{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %if = scf.if %12 -> f32 { - %mul = arith.mulf %ld, %ld : f32 - scf.yield %mul : f32 - } else { - scf.yield %ld : f32 - } - affine.store %if, %19[%arg4] : memref - } - return - } -} - -//TODO: reduction -module @reduction{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - %sum_0 = arith.constant 0.0 : f32 - %red = affine.for %arg4 = 0 to 17 step 1 iter_args(%sum_iter = %sum_0) -> f32 { - %ld1 = affine.load %18[%arg4] : memref - %sum_next = arith.addf 
%sum_iter, %ld1 : f32 - affine.yield %sum_next : f32 - } - affine.store %red, %19[0] : memref - return - } -} - -//TODO: Conditional store-1 -module @cond_store_1 { - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ld : f32 - scf.if %12 { - affine.store %mul, %19[%arg4] : memref - } - } - return - } -} - -//TODO: Conditional store-2 -module @cond_store_2{ - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - scf.if %12 { - %mul = arith.mulf %ld, %ld : f32 - affine.store %mul, %19[%arg4] : memref - } else { - affine.store %ld, %19[%arg4] : memref - } - } - return - } -} - -//Parallel for -module @parallel_for{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ld : f32 - affine.store %mul, %19[%arg4] : memref - } - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg4] : memref - %ld2 = affine.load %20[%arg4] : memref - %add = arith.addf %ld1, %ld2 : f32 - %mul = arith.mulf %add, %add : f32 - affine.store %mul, %19[%arg4] : memref - } - return - } -} - ////Fors inside for -module @for_within_for{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - } - return - } -} - -//Fors inside for -module @for_within_for_2{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3+2*%arg4] : memref - %ld2 = affine.load %20[%arg4] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - } - return - } -} - -//Fors inside for -module @for_within_for_3{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = 
arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3+2*%arg4] : memref - %ld2 = affine.load %18[%arg3] : memref - %ld3 = affine.load %20[%arg4] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - %mul2 = arith.mulf %mul, %ld3 : f32 - affine.store %mul2, %19[%arg4] : memref - } - } - return - } -} - -//Fors inside for -module @for_within_for_4{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg4+2*%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - } - return - } -} - -//Fors no-loop dependency -module @for_no_loop_dependency{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 15 { - %ld1 = affine.load %18[0] : memref - affine.store %ld1, %19[0] : memref - } - return - } -} -//Fors no-loop dependency -module @for_2_levels_no_loop_dependency{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - affine.for %arg3 = 0 to 15 { - %ld1 = affine.load %18[%arg4] : memref - affine.store %ld1, %19[%arg4] : memref - } - } - return - } -} -//Fors inside for -module @for_3_levels_0{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 15 { - affine.for %arg4 = 0 to 17 { - affine.for %arg5 = 0 to 21 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg5] : memref - } - } - } - return - } -} - -//Fors inside for -module @for_3_levels_1{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg5 = 0 to 21 { - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - %ld1 = 
affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - } - } - return - } -} - -//Fors inside for -module @for_3_levels_2{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - affine.for %arg5 = 0 to 21 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %ld3 = affine.load %23[%arg5] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - %mul2 = arith.mulf %mul, %ld3 : f32 - affine.store %mul2, %19[%arg4] : memref - } - } - } - return - } -} - -//Fors inside for -module @for_3_levels_3{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - affine.for %arg5 = 0 to 21 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %ld3 = affine.load %20[%arg5] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - %mul2 = arith.mulf %mul, %ld3 : f32 - affine.store %mul2, %19[%arg4] : memref - } - } - } - return - } -} - -//Fors inside for -module @for_3_levels_4{ - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %21 = arith.muli %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 21 { - affine.for %arg4 = 0 to 17 { - affine.for %arg5 = 0 to 21 { - %ld1 = affine.load %18[%arg3+4*%arg4+3] : memref - %ld2 = affine.load %20[7*%arg4+%arg5+2] : memref - %ld3 = affine.load %20[%arg5+2*%arg3] : memref - %mul = arith.mulf %ld1, %ld2 : f32 - %mul2 = arith.mulf %mul, %ld3 : f32 - affine.store %mul2, %19[%arg4] : memref +//module @for_within_for_2{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3+2*%arg4] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////Fors inside for +//module @for_within_for_3{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : 
index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3+2*%arg4] : memref +// %ld2 = affine.load %18[%arg3] : memref +// %ld3 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// %mul2 = arith.mulf %mul, %ld3 : f32 +// affine.store %mul2, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////Fors inside for +//module @for_within_for_4{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg4+2*%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +////Fors no-loop dependency +//module @for_no_loop_dependency{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 15 { +// %ld1 = affine.load %18[0] : memref +// affine.store %ld1, %19[0] : memref +// } +// return +// } +//} +////Fors no-loop dependency +//module @for_2_levels_no_loop_dependency{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// affine.for %arg3 = 0 to 15 { +// %ld1 = affine.load %18[%arg4] : memref +// affine.store %ld1, %19[%arg4] : memref +// } +// } +// return +// } +//} +////Fors inside for +//module @for_3_levels_0{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 15 { +// affine.for %arg4 = 0 to 17 { +// affine.for %arg5 = 0 to 21 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg5] : memref +// } +// } +// } +// return +// } +//} +// +////Fors inside for +//module @for_3_levels_1{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for 
%arg5 = 0 to 21 { +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// } +// return +// } +//} +// +////Fors inside for +//module @for_3_levels_2{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// affine.for %arg5 = 0 to 21 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %ld3 = affine.load %23[%arg5] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// %mul2 = arith.mulf %mul, %ld3 : f32 +// affine.store %mul2, %19[%arg4] : memref +// } +// } +// } +// return +// } +//} +// +////Fors inside for +//module @for_3_levels_3{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// affine.for %arg5 = 0 to 21 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %ld3 = affine.load %20[%arg5] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// %mul2 = arith.mulf %mul, %ld3 : f32 +// affine.store %mul2, %19[%arg4] : memref +// } +// } +// } +// return +// } +//} +// +////Fors inside for +//module @for_3_levels_4{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %21 = arith.muli %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 21 { +// affine.for %arg4 = 0 to 17 { +// affine.for %arg5 = 0 to 21 { +// %ld1 = affine.load %18[%arg3+4*%arg4+3] : memref +// %ld2 = affine.load %20[7*%arg4+%arg5+2] : memref +// %ld3 = affine.load %20[%arg5+2*%arg3] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// %mul2 = arith.mulf %mul, %ld3 : f32 +// affine.store %mul2, %19[%arg4] : memref +// } +// } +// } +// return +// } +//} +// +////Intermediate raising +//#map = affine_map<(d0)[s0] -> (s0)> +//#map1 = affine_map<(d0) -> (d0)> +//module @for_within_for2 { +// func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { +// %c17 = arith.constant 17 : index +// %c4 = arith.constant 4 : index +// %0 = arith.index_cast %arg1 : i32 to index +// %1 = arith.muli %0, %c4 : index +// %2 = arith.divui %1, %c4 : index +// %alloca = memref.alloca(%2) : memref +// affine.for %arg4 = 0 to 21 { +// %3 = "polygeist.submap"(%arg2, %arg4, %c17) <{map = #map}> : (memref, index, index) -> memref +// %4 = "polygeist.submap"(%arg3, %c17) <{map = #map1}> : (memref, index) -> memref +// %5 = "polygeist.submap"(%alloca, %c17) <{map = #map1}> : (memref, index) -> memref +// linalg.generic 
{indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { +// ^bb0(%in: f32, %in_0: f32, %out: f32): +// %6 = arith.mulf %in, %in_0 : f32 +// linalg.yield %6 : f32 +// } +// } +// return +// } +//} +// +////Parallel fors inside for +//module @parallel_fors_inside_for { +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 17 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +//} +// +//matrix-mul iter arg +module @matmul_1 { + memref.global @out : memref<32x8xi32> = uninitialized + memref.global @im2 : memref<8x8xi32> = uninitialized + memref.global @im1 : memref<32x8xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im1 : memref<32x8xi32> + %1 = memref.get_global @im2 : memref<8x8xi32> + %2 = memref.get_global @out : memref<32x8xi32> + affine.for %arg0 = 0 to 32 { + affine.for %arg1 = 0 to 8 { + %3 = affine.for %arg2 = 0 to 8 iter_args(%arg3 = %c0_i32) -> (i32) { + %4 = affine.load %0[%arg0, %arg2] : memref<32x8xi32> + %5 = affine.load %1[%arg2, %arg1] : memref<8x8xi32> + %6 = arith.muli %4, %5 : i32 + %7 = arith.addi %arg3, %6 : i32 + affine.yield %7 : i32 } + affine.store %3, %2[%arg0, %arg1] : memref<32x8xi32> } } - return - } -} - -//Intermediate raising -#map = affine_map<(d0)[s0] -> (s0)> -#map1 = affine_map<(d0) -> (d0)> -module @for_within_for2 { - func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { - %c17 = arith.constant 17 : index - %c4 = arith.constant 4 : index - %0 = arith.index_cast %arg1 : i32 to index - %1 = arith.muli %0, %c4 : index - %2 = arith.divui %1, %c4 : index - %alloca = memref.alloca(%2) : memref - affine.for %arg4 = 0 to 21 { - %3 = "polygeist.submap"(%arg2, %arg4, %c17) <{map = #map}> : (memref, index, index) -> memref - %4 = "polygeist.submap"(%arg3, %c17) <{map = #map1}> : (memref, index) -> memref - %5 = "polygeist.submap"(%alloca, %c17) <{map = #map1}> : (memref, index) -> memref - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %6 = arith.mulf %in, %in_0 : f32 - linalg.yield %6 : f32 - } - } - return - } -} - -//Parallel fors inside for -module @parallel_fors_inside_for { - func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 17 { - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : 
memref - %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref - } - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref - %add = arith.addf %ld1, %ld2 : f32 - %mul = arith.mulf %add, %add : f32 - affine.store %mul, %19[%arg4] : memref - } - } - return + return %c0_i32 : i32 } } -////matrix-mul iter arg -//module @matmul_1 { -// memref.global @out : memref<32x8xi32> = uninitialized -// memref.global @im2 : memref<8x8xi32> = uninitialized -// memref.global @im1 : memref<32x8xi32> = uninitialized +////matrix-mul extra load-store variant +//module @matmul_2 { +// memref.global @out : memref<128x32xi32> = uninitialized +// memref.global @im2 : memref<64x32xi32> = uninitialized +// memref.global @im1 : memref<128x64xi32> = uninitialized // func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { // %c0_i32 = arith.constant 0 : i32 -// %0 = memref.get_global @im1 : memref<32x8xi32> -// %1 = memref.get_global @im2 : memref<8x8xi32> -// %2 = memref.get_global @out : memref<32x8xi32> -// affine.for %arg0 = 0 to 32 { -// affine.for %arg1 = 0 to 8 { -// %3 = affine.for %arg2 = 0 to 8 iter_args(%arg3 = %c0_i32) -> (i32) { -// %4 = affine.load %0[%arg0, %arg2] : memref<32x8xi32> -// %5 = affine.load %1[%arg2, %arg1] : memref<8x8xi32> -// %6 = arith.muli %4, %5 : i32 -// %7 = arith.addi %arg3, %6 : i32 -// affine.yield %7 : i32 +// %0 = memref.get_global @im1 : memref<128x64xi32> +// %1 = memref.get_global @im2 : memref<64x32xi32> +// %2 = memref.get_global @out : memref<128x32xi32> +// affine.for %arg0 = 0 to 128 { +// affine.for %arg1 = 0 to 32 { +// affine.for %arg2 = 0 to 64 { +// %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> +// %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> +// %5 = arith.muli %3, %4 : i32 +// %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> +// %7 = arith.addi %6, %5 : i32 +// affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> // } -// affine.store %3, %2[%arg0, %arg1] : memref<32x8xi32> // } // } // return %c0_i32 : i32 // } //} -//matrix-mul extra load-store variant -module @matmul_2 { - memref.global @out : memref<128x32xi32> = uninitialized - memref.global @im2 : memref<64x32xi32> = uninitialized - memref.global @im1 : memref<128x64xi32> = uninitialized - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im1 : memref<128x64xi32> - %1 = memref.get_global @im2 : memref<64x32xi32> - %2 = memref.get_global @out : memref<128x32xi32> - affine.for %arg0 = 0 to 128 { - affine.for %arg1 = 0 to 32 { - affine.for %arg2 = 0 to 64 { - %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> - %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> - %5 = arith.muli %3, %4 : i32 - %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> - %7 = arith.addi %6, %5 : i32 - affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> - } - } - } - return %c0_i32 : i32 - } -} - ////conv (with inner loop accumulate) ////How to deal with IR in outer loops as well? 
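//Illustrative sketch only (an assumed target, not verified pass output): once the spatial loops
//and the filter loops are all raised, a conv with an inner accumulate should collapse into one
//linalg.generic over a four-dimensional iteration space, along these lines:
//  #in  = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
//  #flt = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
//  #out = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
//  linalg.generic {indexing_maps = [#in, #flt, #out],
//                  iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
//    ins(%im, %filter : memref<515x67xi32>, memref<4x4xi32>)
//    outs(%out : memref<512x64xi32>) {
//  ^bb0(%a: i32, %b: i32, %acc: i32):
//    %p = arith.muli %a, %b : i32
//    %s = arith.addi %acc, %p : i32
//    linalg.yield %s : i32
//  }
//(Names %im, %filter, %out and the 515x67 / 4x4 / 512x64 shapes are placeholders borrowed from
//the conv-style tests below.)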
//module @conv_1{ @@ -708,4 +708,27 @@ module @conv_2 { return %c0_i32 : i32 } } + +module @submap_test { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main(%arg0 : index, %arg1 : index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + affine.for %arg2 = 0 to 4 { + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + } + } + return %c0_i32 : i32 + } +} \ No newline at end of file diff --git a/test/polygeist-opt/submapcanonicalize.mlir b/test/polygeist-opt/submapcanonicalize.mlir new file mode 100644 index 000000000000..3e186911f677 --- /dev/null +++ b/test/polygeist-opt/submapcanonicalize.mlir @@ -0,0 +1,41 @@ +// RUN: polygeist-opt -canonicalize %s | FileCheck %s +#map = affine_map<(d0)[s0, s1] -> (d0 * s0, d0 * s1)> +module { + func.func private @use(i32) + func.func @f(%arg0: memref, %arg1 : index, %arg2 : index, %arg3 : index) { + + %submap = "polygeist.submap"(%arg0, %arg1, %arg2) <{map = #map}> : (memref, index, index) -> memref + + affine.for %arg4 = 0 to 10 { + %l = affine.load %submap[5 + %arg4 + symbol(%arg3)] : memref + func.call @use(%l) : (i32) -> () + affine.yield + } + return + } + + func.func @g(%arg0: memref, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : i32) { + %submap = "polygeist.submap"(%arg0, %arg1, %arg2) <{map = #map}> : (memref, index, index) -> memref + affine.for %arg5 = 0 to 10 { + affine.store %arg4, %submap[5 + %arg5 + symbol(%arg3)] : memref + affine.yield + } + return + } +} + + +// CHECK: func.func @f(%arg0: memref, %arg1: index, %arg2: index, %arg3: index) { +// CHECK-NEXT: affine.for %arg4 = 0 to 10 { +// CHECK-NEXT: %0 = affine.load %arg0[(%arg4 + symbol(%arg3) + 5) * symbol(%arg1), (%arg4 + symbol(%arg3) + 5) * symbol(%arg2)] : memref +// CHECK-NEXT: func.call @use(%0) : (i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + +// CHECK: func.func @g(%arg0: memref, %arg1: index, %arg2: index, %arg3: index, %arg4: i32) { +// CHECK-NEXT: affine.for %arg5 = 0 to 10 { +// CHECK-NEXT: affine.store %arg4, %arg0[(%arg5 + symbol(%arg3) + 5) * symbol(%arg1), (%arg5 + symbol(%arg3) + 5) * symbol(%arg2)] : memref +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } \ No newline at end of file diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 64a7e7a35293..7759db83c573 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -77,6 +77,7 @@ int main(int argc, char **argv) { mlir::registerLoopInvariantCodeMotionPass(); mlir::registerConvertSCFToOpenMPPass(); mlir::affine::registerAffinePasses(); + mlir::registerLinalgPasses(); registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { LLVM::LLVMFunctionType::attachInterface(*ctx); From 77c8168ceb1db37ef968fc226a660fab666fc8af Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 22 Aug 2024 17:09:38 +0000 Subject: [PATCH 22/77] Added reduction loops for linalg --- 
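The check_reduction flag added here is passed in as true for stores and for linalg.generic
outputs, and it survives only when the loop induction variable does not appear in that operand's
indexing map; in that case the dimension created for this loop is emitted as a reduction
iterator rather than a parallel one. A minimal hand-written sketch of the pattern this covers
(assumed shape, not actual pass output; names are illustrative):

    affine.for %i = 0 to 17 {            // %i never indexes the store -> reduction dimension
      %v = affine.load %in[%i] : memref<32xf32>
      %a = affine.load %acc[0] : memref<1xf32>
      %s = arith.addf %a, %v : f32
      affine.store %s, %acc[0] : memref<1xf32>
    }
    // roughly the expected raised form:
    //   #id  = affine_map<(d0) -> (d0)>
    //   #red = affine_map<(d0) -> (0)>
    //   linalg.generic {indexing_maps = [#id, #red], iterator_types = ["reduction"]}
    //     ins(%in : memref<32xf32>) outs(%acc : memref<1xf32>) { /* addf + yield */ }
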
lib/polygeist/Passes/RaiseToLinalg.cpp | 36 ++++++++++++++++---------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 61c589c65daf..03bda7dbba02 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -117,9 +117,10 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { // indices into `newval` such that // indexing `newval[map(indices)]` produces the same result as indexing the // original map. - +// check_reduction is set true, when passed from store/linalg.generic's output variable. +// And it is returned true, only if index was not encountered in oldmap operands and check_reduction was set true. Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, - Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands) { + Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands, bool &check_reduction) { assert(oldmap_operands.size() == oldmap.getNumSymbols() + oldmap.getNumDims()); //Operands which don't correspond to indices SmallVector operands_without_indices; @@ -135,7 +136,11 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, else dimidx = i; } - + if((dimidx == -1) && (check_reduction)) + check_reduction = true; + else + check_reduction = false; + SmallVector dimReplacements; size_t validSims = 0; size_t validDims = 0; @@ -457,6 +462,8 @@ struct AffineForOpRaising : public OpRewritePattern { SmallVector, AffineLoadOp>> loads; SmallVector, AffineStoreOp>> stores; SmallVector, GenericOp>> linalgGenerics; + bool check_reduction; + // TODO Also collect all the linalg generics! // Check that the only operations within the region are either: @@ -464,7 +471,6 @@ struct AffineForOpRaising : public OpRewritePattern { // Additionally, for each load/store, remember what conditions are // required for that load or store to execute. auto result = loop->walk([&](Operation *op) { - llvm::outs()<< op->getName() << "\n"; if (op == loop) return WalkResult::advance(); // TODO extend this, any non-memory operation is also legal here. @@ -664,9 +670,10 @@ struct AffineForOpRaising : public OpRewritePattern { //TODO: Or is it num dims? 
//size_t firstNDims = lgMap.getResults().size(); size_t firstNDims = lgMap.getNumDims(); + check_reduction = false; auto newMemref = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, - firstNDims, ValueRange(lgOperands)); + firstNDims, ValueRange(lgOperands), check_reduction); if (!legal) @@ -703,9 +710,9 @@ struct AffineForOpRaising : public OpRewritePattern { bool legal = true; size_t firstNDims = lgMap.getNumDims(); + check_reduction = true; auto newMemref = remap_in_affine_dim( - legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands)); - + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), check_reduction); if (!legal) return failure(); @@ -730,10 +737,11 @@ struct AffineForOpRaising : public OpRewritePattern { size_t firstNDims = 0; bool legal = true; + check_reduction = false; auto newMemref = remap_in_affine_dim( legal, rewriter, load.getAffineMap(), load.getMemref(), loop.getInductionVar(), loopSize, firstNDims, - load.getMapOperands()); + load.getMapOperands(), check_reduction); if (!legal) return failure(); @@ -757,10 +765,11 @@ struct AffineForOpRaising : public OpRewritePattern { size_t firstNDims = 0; + check_reduction = true; auto newMemref = remap_in_affine_dim( legal, rewriter, store.getAffineMap(), store.getMemref(), loop.getInductionVar(), loopSize, firstNDims, - store.getMapOperands()); + store.getMapOperands(), check_reduction); if (!legal) return failure(); @@ -784,16 +793,17 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO if linalg generic exists, make this iterator type prepend to the // existing iterators - bool is_parallel = stores_map.size() == 0; + //TODO: Just store check is not sufficient, there has to be a check for + //bool is_parallel = stores_map.size() == 0; // TODO determine if linalg generic, whether to create parallel or reduction by looking at memory patterns of maps if (linalgGenerics.size() == 1) { // determine whether now we write to ourselves } - iteratorTypes.push_back(is_parallel - ? utils::IteratorType::parallel - : utils::IteratorType::reduction); + iteratorTypes.push_back(check_reduction + ? 
utils::IteratorType::reduction + : utils::IteratorType::parallel); if (linalgGenerics.size() == 1) { for (auto attr : linalgGenerics[0].second.getIteratorTypesArray()) From 98f01194e5af8124b5f4f177c01b568e5a2ef3cb Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 27 Aug 2024 17:34:17 -0700 Subject: [PATCH 23/77] Fix for incorrect for loop dims --- lib/polygeist/Ops.cpp | 21 ++++++++++++++++++++- lib/polygeist/Passes/RaiseToLinalg.cpp | 20 ++++++++++++-------- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index cb486026112b..0f1104f237ba 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -5944,7 +5944,26 @@ OpFoldResult mlir::polygeist::SubmapOp::fold(mlir::polygeist::SubmapOp::FoldAdap return nullptr; } + +class DimSubMap final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DimOp op, + PatternRewriter &rewriter) const override { + auto subMapOp = op.getSource().getDefiningOp(); + if (!subMapOp) return failure(); + + auto idx = op.getIndex().getDefiningOp(); + if (!idx) return failure(); + + rewriter.replaceOp(op, subMapOp.getSizes()[idx.value()]); + + return success(); + } +}; + void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 03bda7dbba02..c0bd0fe7feef 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -120,7 +120,7 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { // check_reduction is set true, when passed from store/linalg.generic's output variable. // And it is returned true, only if index was not encountered in oldmap operands and check_reduction was set true. Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, - Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands, bool &check_reduction) { + Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands, Value origmemref, bool &check_reduction) { assert(oldmap_operands.size() == oldmap.getNumSymbols() + oldmap.getNumDims()); //Operands which don't correspond to indices SmallVector operands_without_indices; @@ -193,7 +193,11 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, SmallVector idx_sizes; for (size_t i=0; i(memref_val.getLoc(), memref_val, i)); + if (auto submap = origmemref.getDefiningOp()) + idx_sizes.push_back(submap.getSizes()[i]); + else + llvm_unreachable("Won't reach this case"); + //idx_sizes.push_back(builder.create(origmemref.getLoc(), origmemref, i)); } idx_sizes.push_back(bound); @@ -621,7 +625,7 @@ struct AffineForOpRaising : public OpRewritePattern { int idx = 0; // Iterate over input arguments - for (Value input : lg.getInputs()) { + for (const Value input : lg.getInputs()) { // Is this needed? 
if (conds.size() != 0) return failure(); @@ -673,7 +677,7 @@ struct AffineForOpRaising : public OpRewritePattern { check_reduction = false; auto newMemref = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, - firstNDims, ValueRange(lgOperands), check_reduction); + firstNDims, ValueRange(lgOperands), input, check_reduction); if (!legal) @@ -688,7 +692,7 @@ struct AffineForOpRaising : public OpRewritePattern { } // Iterate over output arguments - for (Value output : lg.getOutputs()) { + for (const Value output : lg.getOutputs()) { // Is this needed? if (conds.size() != 0) return failure(); @@ -712,7 +716,7 @@ struct AffineForOpRaising : public OpRewritePattern { size_t firstNDims = lgMap.getNumDims(); check_reduction = true; auto newMemref = remap_in_affine_dim( - legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), check_reduction); + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), output, check_reduction); if (!legal) return failure(); @@ -741,7 +745,7 @@ struct AffineForOpRaising : public OpRewritePattern { auto newMemref = remap_in_affine_dim( legal, rewriter, load.getAffineMap(), load.getMemref(), loop.getInductionVar(), loopSize, firstNDims, - load.getMapOperands(), check_reduction); + load.getMapOperands(), load.getMemref(), check_reduction); if (!legal) return failure(); @@ -769,7 +773,7 @@ struct AffineForOpRaising : public OpRewritePattern { auto newMemref = remap_in_affine_dim( legal, rewriter, store.getAffineMap(), store.getMemref(), loop.getInductionVar(), loopSize, firstNDims, - store.getMapOperands(), check_reduction); + store.getMapOperands(), store.getMemref(), check_reduction); if (!legal) return failure(); From 59eec0b59e02756e4e6316c33af5f6722027a0bb Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 4 Sep 2024 21:43:41 -0700 Subject: [PATCH 24/77] Linalg.generic 4 loop cases raised- todo: reduction and some if-else cases failing --- lib/polygeist/Passes/RaiseToLinalg.cpp | 86 ++++++++++++++++++++++---- 1 file changed, 75 insertions(+), 11 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index c0bd0fe7feef..85816fa71a5b 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -131,9 +131,37 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, continue; } assert(i >= firstNDims); - if (v != index) - operands_without_indices.push_back(v); - else + if (v != index) { + // Check if the symbol value is read-only or defined in a scope where it is always visible. 
+ if (auto ba = dyn_cast(v)) { + // check if it dominates the current scope + if (ba.getParentBlock()->getParent()->isAncestor(builder.getBlock()->getParent())) + operands_without_indices.push_back(v); + else { + assert(false); + legal = false; + return nullptr; + } + } else { + auto op = v.getDefiningOp(); + // check if this dominates the current scope + if (op->getParentRegion()->isAncestor(builder.getBlock()->getParent())) { + operands_without_indices.push_back(v); + } else if (isReadOnly(op)) { + // if not, check if it is readnone + // Technically this isn't quite sufficient yet, and does require that the operands to this op are also able to be hoisted, + // but for now we will assume this + auto op2 = builder.clone(*op); + operands_without_indices.push_back(op2->getResult(cast(v).getResultNumber())); + } else { + // if so clone it in the right scope + // otherwise set illegal and don't continue + assert(false); + legal = false; + return nullptr; + } + } + } else dimidx = i; } if((dimidx == -1) && (check_reduction)) @@ -203,10 +231,41 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, legal = true; SmallVector sizes(idx_sizes.size(), mlir::ShapedType::kDynamic); - for (auto sz : idx_sizes) - operands_without_indices.push_back(sz); - // memref + for (auto sz : idx_sizes) { + // Check if the symbol value is read-only or defined in a scope where it is always visible. + if (auto ba = dyn_cast(sz)) { + // check if it dominates the current scope + if (ba.getParentBlock()->getParent()->isAncestor(builder.getBlock()->getParent())) + operands_without_indices.push_back(sz); + else { + llvm::errs() << " value is a non-dominating block arg: " << sz << "\n"; + legal = false; + assert(false); + return nullptr; + } + } else { + auto op = sz.getDefiningOp(); + // check if this dominates the current scope + if (op->getParentRegion()->isAncestor(builder.getBlock()->getParent())) { + operands_without_indices.push_back(sz); + } else if (isReadOnly(op)) { + // if not, check if it is readnone + // Technically this isn't quite sufficient yet, and does require that the operands to this op are also able to be hoisted, + // but for now we will assume this + auto op2 = builder.clone(*op); + operands_without_indices.push_back(op2->getResult(cast(sz).getResultNumber())); + } else { + llvm::errs() << " op is not readonly: " << *op << "\n"; + // if so clone it in the right scope + // otherwise set illegal and don't continue + legal = false; + assert(false); + return nullptr; + } + } + } auto ty = MemRefType::get(sizes, cast(memref_val.getType()).getElementType()); + return builder.create(memref_val.getLoc(), ty, memref_val, operands_without_indices, map2); } @@ -678,8 +737,6 @@ struct AffineForOpRaising : public OpRewritePattern { auto newMemref = remap_in_affine_dim( legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), input, check_reduction); - - if (!legal) return failure(); @@ -775,8 +832,9 @@ struct AffineForOpRaising : public OpRewritePattern { loop.getInductionVar(), loopSize, firstNDims, store.getMapOperands(), store.getMemref(), check_reduction); - if (!legal) + if (!legal) { return failure(); + } auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); affineMaps.push_back(newAffineMap); @@ -786,12 +844,16 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO presently if linalg generic exists, assert there are no load/stores if ((linalgGenerics.size() > 0) && - ((loads.size() != 0) || (stores.size() != 
0))) + ((loads.size() != 0) || (stores.size() != 0))) { + assert(false); return failure(); + } // TODO assert only zero or one linalg generic exists - if (!(linalgGenerics.size() == 1 || linalgGenerics.size() == 0)) + if (!(linalgGenerics.size() == 1 || linalgGenerics.size() == 0)) { + //assert(false); return failure(); + } SmallVector iteratorTypes; // TODO if linalg generic exists, make this iterator type prepend to the @@ -879,6 +941,7 @@ struct AffineForOpRaising : public OpRewritePattern { for (auto op : term->getOperands()) { toreturn.push_back(map.lookup(op)); } + //llvm::errs() << genOp->getParentOfType() << "\n"; rewriter.eraseOp(genOp); } @@ -891,6 +954,7 @@ struct AffineForOpRaising : public OpRewritePattern { rewriter.setInsertionPointToEnd(blk); rewriter.create(loop.getLoc(), toreturn); + auto func = loop->getParentOfType(); rewriter.eraseOp(loop); // return success! return success(); From a363f1362f5e016ae2beba9e42a99be89e9e0302 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 17 Sep 2024 17:40:17 -0700 Subject: [PATCH 25/77] Adding test case for all passing raising and lowering, example case of debufferizing added which works for tiling and fusion --- .../linalg_debufferize_tile_fusion.mlir | 133 ++ test/polygeist-opt/linalgraise.mlir | 1398 +++++++++-------- 2 files changed, 854 insertions(+), 677 deletions(-) create mode 100644 test/polygeist-opt/linalg_debufferize_tile_fusion.mlir diff --git a/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir b/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir new file mode 100644 index 000000000000..fb08f31190bb --- /dev/null +++ b/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter --one-shot-bufferize="bufferize-function-boundaries" --func-bufferize --tensor-bufferize --finalizing-bufferize --convert-linalg-to-affine-loops --raise-scf-to-affine -split-input-file -verify-diagnostics | FileCheck %s +// To test bufferization : pva-opt %s -test-transform-dialect-interpreter --one-shot-bufferize="bufferize-function-boundaries test-analysis-only print-conflicts" +#map1 = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +//#trait_conv = { +// indexing_maps = [ +// affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>, +// affine_map<(d0, d1, d2, d3) -> (d2, d3)>, +// affine_map<(d0, d1, d2, d3) -> (d0, d1)> +// ], +// iterator_types = ["parallel", "parallel", "reduction", "reduction"] +//} +// +////Remember to tile basd on output +//func.func @conv(%A : tensor<130x130xf32>, %B : tensor<3x3xf32>, +// %C : tensor<128x128xf32>) -> tensor<128x128xf32> { +// %1 = linalg.generic #trait_conv +// ins(%A, %B : tensor<130x130xf32>, +// tensor<3x3xf32>) +// outs(%C : tensor<128x128xf32>) { +// ^bb0(%a: f32, %b: f32, %c: f32) : +// %d = arith.mulf %a, %b: f32 +// %e = arith.addf %c, %d: f32 +// linalg.yield %e : f32 +// } -> tensor<128x128xf32> +// return %1 : tensor<128x128xf32> +//} +memref.global @out : memref<512x64xi32> = uninitialized +memref.global @rhs : memref<64x64xi32> = uninitialized +memref.global @filter : memref<4x4xi32> = uninitialized +memref.global @im : memref<515x67xi32> = uninitialized +// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c512 = arith.constant 512 : index +// %c64 = arith.constant 64 : index +// %c4 = arith.constant 4 : index +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im : 
memref<515x67xi32> +// %1 = memref.get_global @filter : memref<4x4xi32> +// %2 = memref.get_global @out : memref<512x64xi32> +// %rhs_memref = memref.get_global @rhs : memref<64x64xi32> +// %4 = bufferization.to_tensor %0 : memref<515x67xi32> +// %5 = bufferization.to_tensor %1 : memref<4x4xi32> +// %x = tensor.empty() : tensor<512x64xi32> +// %out = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%4, %5 : tensor<515x67xi32>, tensor<4x4xi32>) outs(%x : tensor<512x64xi32>) { +// ^bb0(%in: i32, %in_0: i32, %out: i32): +// %6 = arith.muli %in, %in_0 : i32 +// %7 = arith.addi %out, %6 : i32 +// linalg.yield %7 : i32 +// } -> tensor<512x64xi32> + +// %materialize = bufferization.to_memref %out : memref<512x64xi32> +// memref.copy %materialize, %2 : memref<512x64xi32> to memref<512x64xi32> + +// %conv_out = bufferization.to_tensor %2 : memref<512x64xi32> +// %rhs = bufferization.to_tensor %rhs_memref : memref<64x64xi32> +// %y = tensor.empty() : tensor<512x64xi32> +// %matmul = linalg.matmul ins(%conv_out, %rhs: tensor<512x64xi32>, tensor<64x64xi32>) +// outs(%y: tensor<512x64xi32>) -> tensor<512x64xi32> +// %materialize2 = bufferization.to_memref %matmul : memref<512x64xi32> +// memref.copy %materialize2, %2 : memref<512x64xi32> to memref<512x64xi32> +// return %c0_i32 : i32 +// } + +func.func @main_opt() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %rhs_memref = memref.get_global @rhs : memref<64x64xi32> + %4 = bufferization.to_tensor %0 : memref<515x67xi32> + %5 = bufferization.to_tensor %1 : memref<4x4xi32> + %x = tensor.empty() : tensor<512x64xi32> + %conv_out = bufferization.to_tensor %2 : memref<512x64xi32> + %rhs = bufferization.to_tensor %rhs_memref : memref<64x64xi32> + %y = tensor.empty() : tensor<512x64xi32> + %out = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%4, %5 : tensor<515x67xi32>, tensor<4x4xi32>) outs(%x : tensor<512x64xi32>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } -> tensor<512x64xi32> + %matmul = linalg.matmul ins(%out, %rhs: tensor<512x64xi32>, tensor<64x64xi32>) + outs(%y: tensor<512x64xi32>) -> tensor<512x64xi32> + + %materialize2 = bufferization.to_memref %matmul : memref<512x64xi32> + memref.copy %materialize2, %2 : memref<512x64xi32> to memref<512x64xi32> + return %c0_i32 : i32 +} + +// transform.sequence failures(propagate) { +// ^bb0(%arg0 : !transform.any_op): +// %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op +// //Note that these represent the outer dimension first for tiling +// %1,%2,%3 = transform.structured.tile_using_for %0 [32,32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +// transform.yield +// } + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op) : + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation. 
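+  // In this test the match covers the conv linalg.generic and the linalg.matmul that consumes
+  // its result; split_handle returns the handles in payload order, so %conv is the convolution
+  // generic and %mul the matmul (matching the names). The matmul, as the consumer, is tiled
+  // with tile_using_forall into 8x32 tiles, and the convolution is then fused into the
+  // resulting scf.forall as a producer, so each tile only computes the slice of the
+  // convolution it actually needs.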
+ %generic = transform.structured.match ops{["linalg.matmul","linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %conv, %mul = transform.split_handle %generic + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It + // produces a handle to the loop generated during tiling. + %tiled_mul, %loop = + transform.structured.tile_using_forall %mul tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one by one. This requires the operation that is being fused to + // define the value used within the loop, so the order of such fusions is + // important. We could also use "transform.merge_handles" to obtain a single + // handle to all operations and give it to `fuse_into_containing_op` that + // would take care of the ordering in this case. + %conv_fused, %loop_0 = + transform.structured.fuse_into_containing_op %conv into %loop + : (!transform.any_op, !transform.any_op) + -> (!transform.any_op, !transform.any_op) + + + transform.yield +} + +// ----- \ No newline at end of file diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index 5c77b25a10df..b4bb5687ac35 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -1,604 +1,604 @@ -////// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s -//// -//module { -// func.func @main0(%12 : i1, %18 : memref<32xf32> ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %19 = memref.alloca() : memref<32xf32> -// scf.if %12 { -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref<32xf32> -// affine.store %ld, %19[%arg4] : memref<32xf32> -// } -// } -// return -// } -// -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// scf.if %12 { -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// affine.store %ld, %19[%arg4] : memref -// } -// } -// return -// } -// -// -// func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// scf.if %12 { -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[3 * %arg4] : memref -// %ld2 = affine.load %18[0] : memref -// %fadd = arith.addf %ld, %ld2 : f32 -// affine.store %fadd, %19[%arg4 + 17] : memref -// } -// } -// return -// } -// -//} -// -//// CHECK: #map = affine_map<(d0) -> (d0)> -//// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { -//// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index -//// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index -//// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index -//// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index -//// CHECK-NEXT: scf.if %[[arg0]] { -//// TODO note that presently we do not ensure that the memrefs are sliced 
to the right size as the space requires -//// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg2 : memref) outs(%alloca : memref) { -//// CHECK-NEXT: ^bb0(%in: f32, %out: f32): -//// CHECK-NEXT: linalg.yield %in : f32 -//// CHECK-NEXT: } -//// CHECK-NEXT: } -//// CHECK-NEXT: } -// -////constant-access -//module @constant_access{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %ci324 = arith.constant 4.0 : f32 -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %mul = arith.mulf %ld, %ci324 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -////constant-mem-access -//module @constant_mem_access{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 4 to 17 step 2 { -// %ld = affine.load %18[3*%arg4] : memref -// %ld2 = affine.load %18[%c4] : memref -// %mul = arith.mulf %ld, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -////without-if -//module @no_if{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// affine.store %ld, %19[%arg4] : memref -// } -// return -// } -//} -// -////arith.mul -//module @arith_mul{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %mul = arith.mulf %ld, %ld : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -////arith.add -//module @arith_add{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg4] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %add = arith.addf %ld1, %ld2 : f32 -// %mul = arith.mulf %add, %add : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -////Conditional arith -//module @cond_arith{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = 
arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %if = scf.if %12 -> f32 { -// %mul = arith.mulf %ld, %ld : f32 -// scf.yield %mul : f32 -// } else { -// scf.yield %ld : f32 -// } -// affine.store %if, %19[%arg4] : memref -// } -// return -// } -//} -// -////TODO: reduction -//module @reduction{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// %sum_0 = arith.constant 0.0 : f32 -// %red = affine.for %arg4 = 0 to 17 step 1 iter_args(%sum_iter = %sum_0) -> f32 { -// %ld1 = affine.load %18[%arg4] : memref -// %sum_next = arith.addf %sum_iter, %ld1 : f32 -// affine.yield %sum_next : f32 -// } -// affine.store %red, %19[0] : memref -// return -// } -//} -// -////TODO: Conditional store-1 -//module @cond_store_1 { -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %mul = arith.mulf %ld, %ld : f32 -// scf.if %12 { -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////TODO: Conditional store-2 -//module @cond_store_2{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref ) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// scf.if %12 { -// %mul = arith.mulf %ld, %ld : f32 -// affine.store %mul, %19[%arg4] : memref -// } else { -// affine.store %ld, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////Parallel for -//module @parallel_for{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// %ld = affine.load %18[%arg4] : memref -// %mul = arith.mulf %ld, %ld : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg4] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %add = arith.addf %ld1, %ld2 : f32 -// %mul = arith.mulf %add, %add : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// return -// } -//} -// -//////Fors inside for -//module @for_within_for{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 
{ -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_within_for_2{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3+2*%arg4] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_within_for_3{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3+2*%arg4] : memref -// %ld2 = affine.load %18[%arg3] : memref -// %ld3 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// %mul2 = arith.mulf %mul, %ld3 : f32 -// affine.store %mul2, %19[%arg4] : memref -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_within_for_4{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg4+2*%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } -//} +//// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s // -////Fors no-loop dependency -//module @for_no_loop_dependency{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 15 { -// %ld1 = affine.load %18[0] : memref -// affine.store %ld1, %19[0] : memref -// } -// return -// } -//} -////Fors no-loop dependency -//module @for_2_levels_no_loop_dependency{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = 
memref.alloca(%17) : memref -// affine.for %arg4 = 0 to 17 { -// affine.for %arg3 = 0 to 15 { -// %ld1 = affine.load %18[%arg4] : memref -// affine.store %ld1, %19[%arg4] : memref -// } +// module { +// func.func @main0(%12 : i1, %18 : memref<32xf32> ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %19 = memref.alloca() : memref<32xf32> +// scf.if %12 { +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref<32xf32> +// affine.store %ld, %19[%arg4] : memref<32xf32> +// } // } -// return -// } -//} -////Fors inside for -//module @for_3_levels_0{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 15 { -// affine.for %arg4 = 0 to 17 { -// affine.for %arg5 = 0 to 21 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg5] : memref -// } -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_3_levels_1{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg5 = 0 to 21 { -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_3_levels_2{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// affine.for %arg5 = 0 to 21 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %ld3 = affine.load %23[%arg5] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// %mul2 = arith.mulf %mul, %ld3 : f32 -// affine.store %mul2, %19[%arg4] : memref -// } -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_3_levels_3{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// affine.for %arg5 = 0 to 21 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %ld3 = affine.load 
%20[%arg5] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// %mul2 = arith.mulf %mul, %ld3 : f32 -// affine.store %mul2, %19[%arg4] : memref -// } -// } -// } -// return -// } -//} -// -////Fors inside for -//module @for_3_levels_4{ -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %21 = arith.muli %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 21 { -// affine.for %arg4 = 0 to 17 { -// affine.for %arg5 = 0 to 21 { -// %ld1 = affine.load %18[%arg3+4*%arg4+3] : memref -// %ld2 = affine.load %20[7*%arg4+%arg5+2] : memref -// %ld3 = affine.load %20[%arg5+2*%arg3] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// %mul2 = arith.mulf %mul, %ld3 : f32 -// affine.store %mul2, %19[%arg4] : memref -// } -// } -// } -// return -// } -//} -// -////Intermediate raising -//#map = affine_map<(d0)[s0] -> (s0)> -//#map1 = affine_map<(d0) -> (d0)> -//module @for_within_for2 { -// func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { -// %c17 = arith.constant 17 : index -// %c4 = arith.constant 4 : index -// %0 = arith.index_cast %arg1 : i32 to index -// %1 = arith.muli %0, %c4 : index -// %2 = arith.divui %1, %c4 : index -// %alloca = memref.alloca(%2) : memref -// affine.for %arg4 = 0 to 21 { -// %3 = "polygeist.submap"(%arg2, %arg4, %c17) <{map = #map}> : (memref, index, index) -> memref -// %4 = "polygeist.submap"(%arg3, %c17) <{map = #map1}> : (memref, index) -> memref -// %5 = "polygeist.submap"(%alloca, %c17) <{map = #map1}> : (memref, index) -> memref -// linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { -// ^bb0(%in: f32, %in_0: f32, %out: f32): -// %6 = arith.mulf %in, %in_0 : f32 -// linalg.yield %6 : f32 -// } -// } -// return -// } -//} -// -////Parallel fors inside for -//module @parallel_fors_inside_for { -// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { -// %c0 = arith.constant 0 : index -// %c4 = arith.constant 4 : index -// %c1 = arith.constant 1 : index -// %15 = arith.index_cast %14 : i32 to index -// %16 = arith.muli %15, %c4 : index -// %17 = arith.divui %16, %c4 : index -// %19 = memref.alloca(%17) : memref -// affine.for %arg3 = 0 to 17 { -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %mul = arith.mulf %ld1, %ld2 : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// affine.for %arg4 = 0 to 17 { -// %ld1 = affine.load %18[%arg3] : memref -// %ld2 = affine.load %20[%arg4] : memref -// %add = arith.addf %ld1, %ld2 : f32 -// %mul = arith.mulf %add, %add : f32 -// affine.store %mul, %19[%arg4] : memref -// } -// } -// return -// } +// return +// } + + // func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + // %c0 = arith.constant 0 : index + // %c4 = arith.constant 4 : index + // %c1 = arith.constant 1 : index + // %15 = arith.index_cast %14 : i32 to index + // %16 = arith.muli %15, %c4 : index + // %17 = arith.divui %16, %c4 : index + // %19 = memref.alloca(%17) : memref + // scf.if %12 { + // affine.for %arg4 = 0 to 17 { + // %ld = affine.load %18[%arg4] : memref + // affine.store %ld, %19[%arg4] : memref + // } + // } + // return + // } + + + // func.func @main2(%12 : i1, %14 : i32, %18 : memref ) 
{ + // %c0 = arith.constant 0 : index + // %c4 = arith.constant 4 : index + // %c1 = arith.constant 1 : index + // %15 = arith.index_cast %14 : i32 to index + // %16 = arith.muli %15, %c4 : index + // %17 = arith.divui %16, %c4 : index + // %19 = memref.alloca(%17) : memref + // scf.if %12 { + // affine.for %arg4 = 0 to 17 { + // %ld = affine.load %18[3 * %arg4] : memref + // %ld2 = affine.load %18[0] : memref + // %fadd = arith.addf %ld, %ld2 : f32 + // affine.store %fadd, %19[%arg4 + 17] : memref + // } + // } + // return + // } + //} -// + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { +// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index +// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index +// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index +// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index +// CHECK-NEXT: scf.if %[[arg0]] { +// TODO note that presently we do not ensure that the memrefs are sliced to the right size as the space requires +// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg2 : memref) outs(%alloca : memref) { +// CHECK-NEXT: ^bb0(%in: f32, %out: f32): +// CHECK-NEXT: linalg.yield %in : f32 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +//constant-access +module @constant_access{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %ci324 = arith.constant 4.0 : f32 + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %mul = arith.mulf %ld, %ci324 : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} + +//constant-mem-access +module @constant_mem_access{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 4 to 17 step 2 { + %ld = affine.load %18[3*%arg4] : memref + %ld2 = affine.load %18[%c4] : memref + %mul = arith.mulf %ld, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} + +//without-if +module @no_if{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + affine.store %ld, %19[%arg4] : memref + } + return + } +} + +//arith.mul +module @arith_mul{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %mul = arith.mulf %ld, %ld : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} + +//arith.add +module @arith_add{ + func.func 
@main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg4] : memref + %ld2 = affine.load %20[%arg4] : memref + %add = arith.addf %ld1, %ld2 : f32 + %mul = arith.mulf %add, %add : f32 + affine.store %mul, %19[%arg4] : memref + } + return + } +} + +//Conditional arith +module @cond_arith{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %if = scf.if %12 -> f32 { + %mul = arith.mulf %ld, %ld : f32 + scf.yield %mul : f32 + } else { + scf.yield %ld : f32 + } + affine.store %if, %19[%arg4] : memref + } + return + } +} + +//TODO: reduction +module @reduction{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + %sum_0 = arith.constant 0.0 : f32 + %red = affine.for %arg4 = 0 to 17 step 1 iter_args(%sum_iter = %sum_0) -> f32 { + %ld1 = affine.load %18[%arg4] : memref + %sum_next = arith.addf %sum_iter, %ld1 : f32 + affine.yield %sum_next : f32 + } + affine.store %red, %19[0] : memref + return + } +} + +//TODO: Conditional store-1 +module @cond_store_1 { + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + %mul = arith.mulf %ld, %ld : f32 + scf.if %12 { + affine.store %mul, %19[%arg4] : memref + } + } + return + } +} + +//TODO: Conditional store-2 +module @cond_store_2{ + func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + %ld = affine.load %18[%arg4] : memref + scf.if %12 { + %mul = arith.mulf %ld, %ld : f32 + affine.store %mul, %19[%arg4] : memref + } else { + affine.store %ld, %19[%arg4] : memref + } + } + return + } +} + +// //Parallel for +// module @parallel_for{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ld : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// affine.for %arg4 = 0 to 17 { +// 
%ld1 = affine.load %18[%arg4] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +// } + +//Fors inside for +module @for_within_for{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + return + } +} + +//Fors inside for +module @for_within_for_2{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3+2*%arg4] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + return + } +} + +//Fors inside for +module @for_within_for_3{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3+2*%arg4] : memref + %ld2 = affine.load %18[%arg3] : memref + %ld3 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + return + } +} + +//Fors inside for +module @for_within_for_4{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg4+2*%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + return + } +} + +//Fors no-loop dependency +module @for_no_loop_dependency{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 15 { + %ld1 = affine.load %18[0] : memref + affine.store %ld1, %19[0] : memref + } + return + } +} +//Fors no-loop dependency +module 
@for_2_levels_no_loop_dependency{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + affine.for %arg3 = 0 to 15 { + %ld1 = affine.load %18[%arg4] : memref + affine.store %ld1, %19[%arg4] : memref + } + } + return + } +} +//Fors inside for +module @for_3_levels_0{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 15 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg5] : memref + } + } + } + return + } +} + +//Fors inside for +module @for_3_levels_1{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg5 = 0 to 21 { + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + } + return + } +} + +//Fors inside for +module @for_3_levels_2{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %ld3 = affine.load %23[%arg5] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + } + return + } +} + +//Fors inside for +module @for_3_levels_3{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %ld3 = affine.load %20[%arg5] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + } + return + } +} + +//Fors inside for +module @for_3_levels_4{ 
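+  //Each load below mixes two induction variables in its subscript (%arg3+4*%arg4+3, 7*%arg4+%arg5+2, %arg5+2*%arg3), while the store indexes with %arg4 alone.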
+ func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3+4*%arg4+3] : memref + %ld2 = affine.load %20[7*%arg4+%arg5+2] : memref + %ld3 = affine.load %20[%arg5+2*%arg3] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + } + return + } +} + +//Intermediate raising +#map = affine_map<(d0)[s0] -> (s0)> +#map1 = affine_map<(d0) -> (d0)> +module @for_within_for2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.for %arg4 = 0 to 21 { + %3 = "polygeist.submap"(%arg2, %arg4, %c17) <{map = #map}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17) <{map = #map1}> : (memref, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17) <{map = #map1}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + } + return + } +} + +// //Parallel fors inside for +// module @parallel_fors_inside_for { +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 17 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +// } + //matrix-mul iter arg module @matmul_1 { memref.global @out : memref<32x8xi32> = uninitialized @@ -625,64 +625,35 @@ module @matmul_1 { } } -////matrix-mul extra load-store variant -//module @matmul_2 { -// memref.global @out : memref<128x32xi32> = uninitialized -// memref.global @im2 : memref<64x32xi32> = uninitialized -// memref.global @im1 : memref<128x64xi32> = uninitialized -// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c0_i32 = arith.constant 0 : i32 -// %0 = memref.get_global @im1 : memref<128x64xi32> -// %1 = memref.get_global @im2 : memref<64x32xi32> -// %2 = memref.get_global @out : memref<128x32xi32> -// affine.for %arg0 = 0 to 128 { -// affine.for %arg1 = 0 to 32 { -// affine.for %arg2 = 0 to 64 { -// %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> -// %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> -// %5 = arith.muli %3, %4 : i32 -// %6 = 
affine.load %2[%arg0, %arg1] : memref<128x32xi32> -// %7 = arith.addi %6, %5 : i32 -// affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> -// } -// } -// } -// return %c0_i32 : i32 -// } -//} +//matrix-mul extra load-store variant + module @matmul_2 { + memref.global @out : memref<128x32xi32> = uninitialized + memref.global @im2 : memref<64x32xi32> = uninitialized + memref.global @im1 : memref<128x64xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im1 : memref<128x64xi32> + %1 = memref.get_global @im2 : memref<64x32xi32> + %2 = memref.get_global @out : memref<128x32xi32> + affine.for %arg0 = 0 to 128 { + affine.for %arg1 = 0 to 32 { + affine.for %arg2 = 0 to 64 { + %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> + %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> + } + } + } + return %c0_i32 : i32 + } + } -////conv (with inner loop accumulate) -////How to deal with IR in outer loops as well? -//module @conv_1{ -// memref.global @out : memref<512x64xi32> = uninitialized -// memref.global @filter : memref<4x4xi32> = uninitialized -// memref.global @im : memref<515x67xi32> = uninitialized -// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c0_i32 = arith.constant 0 : i32 -// %0 = memref.get_global @im : memref<515x67xi32> -// %1 = memref.get_global @filter : memref<4x4xi32> -// %2 = memref.get_global @out : memref<512x64xi32> -// affine.for %arg0 = 0 to 512 { -// affine.for %arg1 = 0 to 64 { -// %3 = affine.for %arg2 = 0 to 4 iter_args(%arg3 = %c0_i32) -> (i32) { -// %4 = affine.for %arg4 = 0 to 4 iter_args(%arg5 = %arg3) -> (i32) { -// %5 = affine.load %0[%arg0 + %arg2, %arg1 + %arg4] : memref<515x67xi32> -// %6 = affine.load %1[%arg2, %arg4] : memref<4x4xi32> -// %7 = arith.muli %5, %6 : i32 -// %8 = arith.addi %arg5, %7 : i32 -// affine.yield %8 : i32 -// } -// affine.yield %4 : i32 -// } -// affine.store %3, %2[%arg0, %arg1] : memref<512x64xi32> -// } -// } -// return %c0_i32 : i32 -// } -//} -// -//conv (direct store) -module @conv_2 { +//conv (with inner loop accumulate) +//How to deal with IR in outer loops as well? 
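+//Here the 4x4 window is accumulated through the iter_args of the two inner loops and stored once per (%arg0, %arg1) in the outer loop body.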
+module @conv_1{ memref.global @out : memref<512x64xi32> = uninitialized memref.global @filter : memref<4x4xi32> = uninitialized memref.global @im : memref<515x67xi32> = uninitialized @@ -693,23 +664,24 @@ module @conv_2 { %2 = memref.get_global @out : memref<512x64xi32> affine.for %arg0 = 0 to 512 { affine.for %arg1 = 0 to 64 { - affine.for %arg2 = 0 to 4 { - affine.for %arg3 = 0 to 4 { - %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> - %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> - %5 = arith.muli %3, %4 : i32 - %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> - %7 = arith.addi %6, %5 : i32 - affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + %3 = affine.for %arg2 = 0 to 4 iter_args(%arg3 = %c0_i32) -> (i32) { + %4 = affine.for %arg4 = 0 to 4 iter_args(%arg5 = %arg3) -> (i32) { + %5 = affine.load %0[%arg0 + %arg2, %arg1 + %arg4] : memref<515x67xi32> + %6 = affine.load %1[%arg2, %arg4] : memref<4x4xi32> + %7 = arith.muli %5, %6 : i32 + %8 = arith.addi %arg5, %7 : i32 + affine.yield %8 : i32 } + affine.yield %4 : i32 } + affine.store %3, %2[%arg0, %arg1] : memref<512x64xi32> } } return %c0_i32 : i32 } -} +} -module @submap_test { +module @conv_1_reduction_test{ memref.global @out : memref<512x64xi32> = uninitialized memref.global @filter : memref<4x4xi32> = uninitialized memref.global @im : memref<515x67xi32> = uninitialized @@ -718,17 +690,89 @@ module @submap_test { %0 = memref.get_global @im : memref<515x67xi32> %1 = memref.get_global @filter : memref<4x4xi32> %2 = memref.get_global @out : memref<512x64xi32> - affine.for %arg2 = 0 to 4 { - affine.for %arg3 = 0 to 4 { - %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> - %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> - %5 = arith.muli %3, %4 : i32 - %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> - %7 = arith.addi %6, %5 : i32 - affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> - } - } + %3 = affine.for %arg2 = 0 to 4 iter_args(%arg3 = %c0_i32) -> (i32) { + %4 = affine.for %arg4 = 0 to 4 iter_args(%arg5 = %arg3) -> (i32) { + %5 = affine.load %0[%arg0 + %arg2, %arg1 + %arg4] : memref<515x67xi32> + %6 = affine.load %1[%arg2, %arg4] : memref<4x4xi32> + %7 = arith.muli %5, %6 : i32 + %8 = arith.addi %arg5, %7 : i32 + affine.yield %8 : i32 + } + affine.yield %4 : i32 + } + affine.store %3, %2[%arg0, %arg1] : memref<512x64xi32> return %c0_i32 : i32 } -} - \ No newline at end of file +} + +//conv (direct store) + module @conv_2 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 64 { + affine.for %arg2 = 0 to 4 { + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + } + } + } + } + return %c0_i32 : i32 + } + } + + module @conv_loop1_test { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + 
memref.global @im : memref<515x67xi32> = uninitialized + func.func @main(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + } + return %c0_i32 : i32 + } + } + + module @submap_test { + memref.global @out : memref<511x64xi32> = uninitialized + memref.global @filter : memref<5x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main(%arg0 : index, %arg1 : index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<5x4xi32> + %2 = memref.get_global @out : memref<511x64xi32> + affine.for %arg2 = 0 to 5 { + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<5x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<511x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<511x64xi32> + } + } + return %c0_i32 : i32 + } + } From 814ca51fd2f890df4bd6e35ab33e77fa7b994ce9 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 11 Oct 2024 20:24:27 -0700 Subject: [PATCH 26/77] Added pass remove iter args from scf; Added psuedo code for submap canonicalize --- include/polygeist/Passes/Passes.h | 1 + include/polygeist/Passes/Passes.td | 9 + lib/polygeist/Ops.cpp | 258 +++++++++++++ lib/polygeist/Passes/CMakeLists.txt | 1 + lib/polygeist/Passes/RaiseToLinalg.cpp | 153 +++++++- .../linalg_debufferize_tile_fusion.mlir | 38 +- test/polygeist-opt/linalgraise.mlir | 351 +++++++++++++++++- 7 files changed, 743 insertions(+), 68 deletions(-) diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index 92c5812e8c4c..96ecf5b32003 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -32,6 +32,7 @@ std::unique_ptr createOpenMPOptPass(); std::unique_ptr createCanonicalizeForPass(); std::unique_ptr createRaiseSCFToAffinePass(); std::unique_ptr createRaiseAffineToLinalgPass(); +std::unique_ptr createRemoveSCFIterArgsPass(); std::unique_ptr createCPUifyPass(StringRef method = ""); std::unique_ptr createBarrierRemovalContinuation(); std::unique_ptr detectReductionPass(); diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index fc5b36aa9caf..0d3116f82c71 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -151,6 +151,15 @@ def SCFRaiseToAffine : Pass<"raise-scf-to-affine"> { ]; } +def RemoveSCFIterArgs : Pass<"remove-scf-iter-args"> { + let summary = "Remove scf iter args"; + let constructor = "mlir::polygeist::createRemoveSCFIterArgsPass()"; + let dependentDialects = [ + "affine::AffineDialect", + "scf::SCFDialect", + ]; +} + def AffineRaiseToLinalg : Pass<"raise-affine-to-linalg"> { let summary = "Raise affine to linalg"; let constructor = "mlir::polygeist::createRaiseAffineToLinalgPass()"; diff --git 
a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 0f1104f237ba..bfe1a6eab2d7 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -5733,6 +5733,263 @@ struct MulDivMul : public OpRewritePattern { } }; +//struct SubMapOpCanonicalize : public OpRewritePattern { +// using OpRewritePattern::OpRewritePattern; +// LogicalResult matchAndRewrite(linalg::GenericOp gen, +// PatternRewriter &rewriter) const override { +// +// // Canonicalization 1 linalg.generic of map of submap. -> linalg.generic of map of submap +// //. iff the submap's affine map != identity +// //. replace inner affine map with composition +// +// +// // Canonicalizeation 3: submap which only sets bounds, of an input memref with the same bounds -> noop / cast +// +// +// // Canonicalization 1.5 (mix of 1/2) +// //. linalg_map = identity a[i,j,x,y] -> u[i+x][j+y] +// //. linalg_map = [i,j,x,y]->(i+x,j+y) a[i,j] -> u[i,j]. # but still keeping the upper loop limit +// //. 1 +// +// +// // a[i] -> x[] +// +// // a[1] -> x[] +// // a[2] -> x[] +// +// +// // a[i,j] = x[map(i,j)]. ; the subbmap op +// //a[i+x][j+y] : submap defines iteration var 0 goes from 0 ... A0. and var 1 goes from 0 ... A1 +// //b[x][y] +// //c[i+x][j+y] +// // here we have 4 iteration variables that linalg is doing i, j, x, y +// // for (i : ...) +// //. for (j : ...) +// //. for (x : ...) +// //. for (y : ...) +// // c[i+x][j+y] += a[i+x][j+y] * b[x][y] +// +// // a[i+x][j+y] +// // c[i+x][j+y] +// // for (i : ...) +// //. for (j : ...) +// //. for (x : ...) +// //. for (y : ...) +// // c[i+x][j+y] += a[i+x][j+y] +// +// +// //x[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// //b[x][y] +// //c[i+x][j+y] +// +// +// // requirement here, is that all linalg.generic loop bounds must be solvable after replacement +// // for example, this would not be permissible +// // a[i] -> x[]. ; a = submap memref -> memref<100xf32> +// // out[] +// +// // This cannot be replaced since now the linalg generic iteration variable i cannot be solved for +// +// +// +// for (auto &&[op, opmap] : gen.getInputsAndMaps()) { +// if (auto submap = op.getDefiningOp()) { +// bool solvable = false; +// +// /// Cannoicalization 2: index removal +// //. x[i, j] -> v[i]. can we get rid of j? +// //. Are input indices defined by other ops, and if so, can we simplify +// //. 1) Take all other input memrefs +// // 2) Determine all solvable indices from those input memrefs +// //. For each index which is solvable from 2) +// // if it can either be removed from the submap, or combined with another index in the submap, +// // remove it from the submap +// +// SmallVector exprs; +// for (auto [op2, map] : gen.getInputAndMaps()) { +// if (op != op2) { +// for (auto expr : map.getAffineExprs()) { +// exprs.push_back(expr); +// } +// } +// } +// for (auto [op2, map] : gen.getOutputAndMaps()) { +// if (op != op2) { +// for (auto expr : map.getAffineExprs()) { +// exprs.push_back(expr); +// } +// } +// } +// SmallSet solvable; +// linalg.determineSolvableIndices(solvable, exprs); +// +// SmallSet notsolvable = allvariables - solvable; +// +// //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// //b[x][y] +// //c[i+x][j+y] +// // Supose we're solving for a +// // Here exprs would contain all the affineexprs from b and c. 
(aka inputs - {x}) +// +// // {x, y, i+x, j+y} +// // Running a solver allows us to uniquely solve for all of, x, y, i, and j with these expressoin +// // In this case we can attempt to remove dependence on x, y, i, j +// +// // If however we had +// //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// //b[x][y] +// //c[i+x][y] +// // we would solve with {x, y, i+x, y} +// // Running a solver we would be able to sole for {x, y, i} but not solve for j +// // In this case we can attempt to remove dependence on x, y, i, but not on j +// +// // let's take easiest one where a is just broadcasting a constant to all input indices +// // a = submap (m,n) -> u[] +// // a[i+x, j+y] For all input indices which are uniquely solvable, here that is both +// //. index 0 = i + x +// //. and index 1 = j + y +// // set the input map to compose with the submap's affine map +// +// +// /// Easy special case +// if (notsolvable.size() == 0) { +// +// +// replace opmap with submap.compose(opmap) taking into account the the ConstantIntRanges +// // Easy case +// } +// +// // We now have two maps with different meanings +// // Let |N| be the number of loop variables in the linalg.generic +// // Let |M| be length(submap.getType().getShape()) +// // Let |Q| be length(submap.getInput().getType().getShape()), number of dimensions of input operand to the submap +// +// // opmap from the linalg.generic which takes linalg.generic loop indices |N| -> inputs to the submap op. |M| +// +// // submap.map. submap op. which takes input indices |M|. -> indices for the corresponing base memref |Q| +// +// // Example +// +// //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// //b[x][y] +// //c[i+x][j+y] +// +// // a = submap (w,p) -> u[c + 2 * p] +// +// // %c = myop.constant() +// // %a = submap a[w, p] -> u[%c + 2 * p] +// //. linalg.generic %a %b %c a.map (x,y,i,j) -> a[x+i,y+j] { +// // } +// +// // N = 4 = |{i,j,x,u}| +// // M = 2 = dim(a) . a is 2 dims +// // Q = 1. dim(u) +// +// SmallVector newLinalgExprs; +// SmallVector newSubmapExprs; +// +// SmallVector legalIndices; +// // We iterate for all |M| expressions of the opmap +// for (auto &&[i, linalgexpr] : llvm::enumerate(opmap.getExprs())) { +// // We must retain the indexing for variables which are functions +// // of the inputs which have a defining index. +// bool legal = true; +// for (auto var : notsolvable) { +// if (linalgexpr.isFunctionOf(var)) { +// legal = false; +// // we can pop this from the not solvable since now this index will define +// // the value of var for future iterations. +// // But doing so requires proving it is not a linear combination of previously +// // visited linalgexpr's, so we'll defer this for a later optimization +// // notsolvable.pop(var); +// } +// } +// +// if (legal) +// legalIndices.push_back(i); +// } +// +// // The non-special case version +// // j is not solvable +// //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// //b[x][y] +// //c[i+x][y] +// +// // because j is not solvable we cannot move any expressions depending on j (in this case p depends on j) +// //. and the underlying sub expressions depending j, in this case via p are: +// // a[1] = w + 4 and a[2] = w + 7 +// // define a(w,p) -> u[c + 2 * p, w + 4, w + 7] +// +// // with the general case optimization v0. 
[just moving expressions up] +// +// //a2[map(i+x, j+y), i + x + 4, i + x + 7] pass in the outermost one with correspondidng composed maps +// //b[x][y] +// //c[i+x][y] +// +// // define a2(w, p) -> u[c + 2 * p] +// +// // with the general case optimization v1. [just eliminating unnecessary indices] +// +// //a2[map(j+y), i + x + 4, i + x + 7] pass in the outermost one with correspondidng composed maps +// //b[x][y] +// //c[i+x][y] +// +// // define a2(p) -> u[c + 2 * p] +// +// // So this optimization generally moves expression from the submap into the linalg map +// // and it it also removes unnecessary indices into the submap +// +// +// // If the entire submap is legal to inline, the solution is simple, replace the linalg +// // map with itself composed with the submap, and replace the original submap with the identity op +// if (legalIndices.size() == opmap.getExprs().size()) { +// // Note, it isn't 100% as simple as below since we still need to retain any constant op's in the +// // new submap op below, since linalg.generic doesn't support constant value's for the indexing, as far +// // as I (wmoses) know? +// newLinalgExprs = opmap.compose(submap.getMap()).getExprs(); +// newSubmapExprs = Affine::getIdentityMap(submap.getOperand().getShape().size()).getExprs(); +// } else { +// SmallVector illegalIndices = allIndices - legalIndices; +// +// // We can alternatively re-index maps which are solely functions of legal indices. +// for (auto &&[i, submapexpr] : llvm::enumerate(submap.getAffineMap().getExprs())) { +// if (submapexpr is a function of any illegal indicies) { +// // we need to keep this as a submap expr (though re-indexed on the new number of exprs) +// newSubmapExprs.push_back(submapexpr.reindex()); +// } else { +// // this index can be completely solved for with other inputs, let's move the expression from +// // a submap expression into a linalg.generic map expression. 
+// newLinalgExprs.push_back(opmap.compose(submapexpr)); +// newSubmapExprs.push_back(Affine::getIdentity()); +// } +// } +// } +// +// if (solvable) { +// // replace the input to the generic with the input to the submap, and the new map +// return success(); +// } +// } +// } +// +// +// +// for (auto op : gen.getOutputs()) { +// if (auto submap = op.getDefiningOp()) { +// bool solvable = false; +// if (solvable) { +// do the thing +// // replace the input to the generic with the input to the submap, and the new map +// return success(); +// } +// } +// } +// +// +// return failure(); +// } +//}; + static llvm::cl::opt BufferElim("enable-buffer-elim", llvm::cl::init(true), llvm::cl::desc("Enable buffer elimination")); @@ -5965,5 +6222,6 @@ class DimSubMap final : public OpRewritePattern { void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { + //results.insert(context); results.insert(context); } diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index d6947a1931c5..bcc6de07193d 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms OpenMPOpt.cpp BarrierRemovalContinuation.cpp RaiseToAffine.cpp + RemoveScfIterArgs.cpp RaiseToLinalg.cpp ParallelLower.cpp TrivialUse.cpp diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 85816fa71a5b..dac831af5477 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -25,13 +25,6 @@ using namespace polygeist; using namespace affine; using namespace linalg; -namespace { -struct RaiseAffineToLinalg - : public AffineRaiseToLinalgBase { - void runOnOperation() override; -}; -} // namespace - // Also want to add support for affine.for ( ) { linalg.generic } -> bigger // linalg.generic Also probably want to try to do { linalg.generc1(); // linalg.generic2(); } -> bigger linalg.generic() @@ -961,16 +954,144 @@ struct AffineForOpRaising : public OpRewritePattern { } }; -void RaiseAffineToLinalg::runOnOperation() { - RewritePatternSet patterns(&getContext()); - // TODO add the existing canonicalization patterns - // + subview of an affine apply -> subview - patterns.insert(&getContext()); +// struct RemoveIterArgs : public OpRewritePattern { +// using OpRewritePattern::OpRewritePattern; +// LogicalResult matchAndRewrite(scf::ForOp forOp, +// PatternRewriter &rewriter) const override { +// if (!forOp.getRegion().hasOneBlock()) +// return failure(); +// unsigned numIterArgs = forOp.getNumRegionIterArgs(); +// auto loc = forOp->getLoc(); +// bool changed = false; +// llvm::SetVector removed; +// llvm::MapVector steps; +// auto yield = cast(forOp.getBody()->getTerminator()); +// for (unsigned i = 0; i < numIterArgs; i++) { +// auto ba = forOp.getRegionIterArgs()[i]; +// auto init = forOp.getInits()[i]; +// auto next = yield->getOperand(i); + +// auto increment = next.getDefiningOp(); +// if (!increment) +// continue; + +// Value step = nullptr; +// if (increment.getLhs() == ba) { +// step = increment.getRhs(); +// } else { +// step = increment.getLhs(); +// } +// if (!step) +// continue; + +// // If it dominates the loop entry +// if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion())) +// continue; + +// rewriter.setInsertionPointToStart(forOp.getBody()); +// Value iterNum = rewriter.create( +// loc, forOp.getInductionVar(), forOp.getLowerBound()); +// iterNum = 
rewriter.create(loc, iterNum, forOp.getStep()); + +// Value replacementIV = rewriter.create(loc, iterNum, step); +// replacementIV = rewriter.create(loc, replacementIV, init); + +// rewriter.replaceAllUsesWith(ba, replacementIV); + +// removed.insert(i); +// steps.insert({i, step}); +// changed = true; +// } + +// if (!changed) +// return failure(); + +// SmallVector newInits; +// for (unsigned i = 0; i < numIterArgs; i++) +// if (!removed.contains(i)) +// newInits.push_back(forOp.getInits()[i]); + +// rewriter.setInsertionPoint(forOp); +// auto newForOp = rewriter.create(loc, forOp.getLowerBound(), +// forOp.getUpperBound(), +// forOp.getStep(), newInits); +// if (!newForOp.getRegion().empty()) +// newForOp.getRegion().front().erase(); +// assert(newForOp.getRegion().empty()); +// rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), +// newForOp.getRegion().begin()); + +// SmallVector newYields; +// for (unsigned i = 0; i < numIterArgs; i++) +// if (!removed.contains(i)) +// newYields.push_back(yield->getOperand(i)); + +// rewriter.setInsertionPoint(yield); +// rewriter.replaceOpWithNewOp(yield, newYields); + +// llvm::BitVector toDelete(numIterArgs + 1); +// for (unsigned i = 0; i < numIterArgs; i++) +// if (removed.contains(i)) +// toDelete[i + 1] = true; +// newForOp.getBody()->eraseArguments(toDelete); + +// rewriter.setInsertionPoint(newForOp); +// unsigned curNewRes = 0; +// for (unsigned i = 0; i < numIterArgs; i++) { +// auto result = forOp->getResult(i); +// if (removed.contains(i)) { +// if (result.use_empty()) +// continue; + +// rewriter.setInsertionPointToStart(forOp.getBody()); +// Value iterNum = rewriter.create( +// loc, forOp.getUpperBound(), forOp.getLowerBound()); +// iterNum = +// rewriter.create(loc, iterNum, forOp.getStep()); + +// Value afterLoop = +// rewriter.create(loc, iterNum, steps[i]); +// afterLoop = +// rewriter.create(loc, afterLoop, forOp.getInits()[i]); + +// rewriter.replaceAllUsesWith(result, afterLoop); +// } else { +// rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++)); +// } +// } + +// rewriter.eraseOp(forOp); + +// return success(); +// } +// }; - GreedyRewriteConfig config; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); -} +namespace { +struct RaiseAffineToLinalg + : public AffineRaiseToLinalgBase { + + std::shared_ptr patterns; + + LogicalResult initialize(MLIRContext *context) override { + RewritePatternSet owningPatterns(context); + for (auto *dialect : context->getLoadedDialects()) + dialect->getCanonicalizationPatterns(owningPatterns); + for (RegisteredOperationName op : context->getRegisteredOperations()) + op.getCanonicalizationPatterns(owningPatterns, context); + + //owningPatterns.insert(&getContext()); + owningPatterns.insert(&getContext()); + + patterns = std::make_shared( + std::move(owningPatterns)); + return success(); + } + void runOnOperation() override { + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), *patterns, config); + } +}; +} // namespace namespace mlir { namespace polygeist { diff --git a/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir b/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir index fb08f31190bb..dbe09418ed75 100644 --- a/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir +++ b/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir @@ -3,33 +3,12 @@ #map1 = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> #map3 = affine_map<(d0, d1, d2, d3) 
-> (d0, d1)> -//#trait_conv = { -// indexing_maps = [ -// affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>, -// affine_map<(d0, d1, d2, d3) -> (d2, d3)>, -// affine_map<(d0, d1, d2, d3) -> (d0, d1)> -// ], -// iterator_types = ["parallel", "parallel", "reduction", "reduction"] -//} -// -////Remember to tile basd on output -//func.func @conv(%A : tensor<130x130xf32>, %B : tensor<3x3xf32>, -// %C : tensor<128x128xf32>) -> tensor<128x128xf32> { -// %1 = linalg.generic #trait_conv -// ins(%A, %B : tensor<130x130xf32>, -// tensor<3x3xf32>) -// outs(%C : tensor<128x128xf32>) { -// ^bb0(%a: f32, %b: f32, %c: f32) : -// %d = arith.mulf %a, %b: f32 -// %e = arith.addf %c, %d: f32 -// linalg.yield %e : f32 -// } -> tensor<128x128xf32> -// return %1 : tensor<128x128xf32> -//} + memref.global @out : memref<512x64xi32> = uninitialized memref.global @rhs : memref<64x64xi32> = uninitialized memref.global @filter : memref<4x4xi32> = uninitialized memref.global @im : memref<515x67xi32> = uninitialized +// Output after debufferization // func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { // %c512 = arith.constant 512 : index // %c64 = arith.constant 64 : index @@ -48,10 +27,10 @@ memref.global @im : memref<515x67xi32> = uninitialized // %7 = arith.addi %out, %6 : i32 // linalg.yield %7 : i32 // } -> tensor<512x64xi32> - +// // %materialize = bufferization.to_memref %out : memref<512x64xi32> // memref.copy %materialize, %2 : memref<512x64xi32> to memref<512x64xi32> - +// // %conv_out = bufferization.to_tensor %2 : memref<512x64xi32> // %rhs = bufferization.to_tensor %rhs_memref : memref<64x64xi32> // %y = tensor.empty() : tensor<512x64xi32> @@ -62,6 +41,7 @@ memref.global @im : memref<515x67xi32> = uninitialized // return %c0_i32 : i32 // } +//Output after linking kernels func.func @main_opt() -> i32 attributes {llvm.linkage = #llvm.linkage} { %c512 = arith.constant 512 : index %c64 = arith.constant 64 : index @@ -91,14 +71,6 @@ func.func @main_opt() -> i32 attributes {llvm.linkage = #llvm.linkage} return %c0_i32 : i32 } -// transform.sequence failures(propagate) { -// ^bb0(%arg0 : !transform.any_op): -// %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op -// //Note that these represent the outer dimension first for tiling -// %1,%2,%3 = transform.structured.tile_using_for %0 [32,32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) -// transform.yield -// } - transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op) : // Since the %arg2 handle is associated with both elementwise operations, diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index b4bb5687ac35..a05bd5338122 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -33,26 +33,26 @@ // } - // func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { - // %c0 = arith.constant 0 : index - // %c4 = arith.constant 4 : index - // %c1 = arith.constant 1 : index - // %15 = arith.index_cast %14 : i32 to index - // %16 = arith.muli %15, %c4 : index - // %17 = arith.divui %16, %c4 : index - // %19 = memref.alloca(%17) : memref - // scf.if %12 { - // affine.for %arg4 = 0 to 17 { - // %ld = affine.load %18[3 * %arg4] : memref - // %ld2 = affine.load %18[0] : memref - // %fadd = arith.addf %ld, %ld2 : f32 - // affine.store %fadd, %19[%arg4 + 17] : memref - // } - // } - // return - // } + // func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { + // %c0 = 
arith.constant 0 : index + // %c4 = arith.constant 4 : index + // %c1 = arith.constant 1 : index + // %15 = arith.index_cast %14 : i32 to index + // %16 = arith.muli %15, %c4 : index + // %17 = arith.divui %16, %c4 : index + // %19 = memref.alloca(%17) : memref + // scf.if %12 { + // affine.for %arg4 = 0 to 17 { + // %ld = affine.load %18[3 * %arg4] : memref + // %ld2 = affine.load %18[0] : memref + // %fadd = arith.addf %ld, %ld2 : f32 + // affine.store %fadd, %19[%arg4 + 17] : memref + // } + // } + // return + // } -//} + // } // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { @@ -212,6 +212,52 @@ module @reduction{ } } +module @reduction_transformed{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + %sum_0 = arith.constant 0.0 : f32 + %alloca = memref.alloca() : memref<1xf32> + affine.store %sum_0, %alloca[0] : memref<1xf32> + affine.for %arg4 = 0 to 17 step 1 { + %iter_arg = affine.load %alloca[0] : memref<1xf32> + %ld1 = affine.load %18[%arg4] : memref + %sum_next = arith.addf %iter_arg, %ld1 : f32 + affine.store %sum_next, %alloca[0] : memref<1xf32> + affine.yield + } + %red = affine.load %alloca[0] : memref<1xf32> + affine.store %red, %19[0] : memref + return + } +} + +module @reduction_transformed_simplified{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + %sum_0 = arith.constant 0.0 : f32 + affine.store %sum_0, %19[0] : memref + affine.for %arg4 = 0 to 17 step 1 { + %iter_arg = affine.load %19[0] : memref + %ld1 = affine.load %18[%arg4] : memref + %sum_next = arith.addf %iter_arg, %ld1 : f32 + affine.store %sum_next, %19[0] : memref + affine.yield + } + return + } +} //TODO: Conditional store-1 module @cond_store_1 { func.func @main(%12 : i1, %14 : i32, %18 : memref ) { @@ -733,6 +779,31 @@ module @conv_1_reduction_test{ } } +//box_filter (direct store) + module @box_filter { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %2 = memref.get_global @out : memref<512x64xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 64 { + affine.for %arg2 = 0 to 4 { + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %3 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + } + } + } + } + return %c0_i32 : i32 + } + } + module @conv_loop1_test { memref.global @out : memref<512x64xi32> = uninitialized memref.global @filter : memref<4x4xi32> = uninitialized @@ -776,3 +847,245 @@ module @conv_1_reduction_test{ return %c0_i32 : i32 } } + + +module @harris_score_1{ + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 
3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + affine.for %arg2 = 0 to 3 { + affine.for %arg5 = 0 to 3 { + %gx = affine.load %0[%arg0, %arg1] : memref<516x516xi32> + %gy = affine.load %1[%arg0, %arg1] : memref<516x516xi32> + %11 = affine.load %2[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %12 = affine.load %3[%arg5 + %arg2 * 3] : memref<9xi32> + %13 = arith.muli %11, %12 : i32 + %14 = arith.addi %gx, %13 : i32 + %15 = affine.load %4[%arg5 + %arg2 * 3] : memref<9xi32> + %16 = arith.muli %11, %15 : i32 + %17 = arith.addi %gy, %16 : i32 + affine.store %14, %0[%arg0, %arg1] : memref<516x516xi32> + affine.store %17, %1[%arg0, %arg1] : memref<516x516xi32> + } + } + } + } + %5 = memref.get_global @img_ixx : memref<512x512xi32> + %6 = memref.get_global @img_iyy : memref<512x512xi32> + %7 = memref.get_global @img_ixy : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + affine.for %arg2 = 0 to 5 { + affine.for %arg6 = 0 to 5 { + %ixx = affine.load %5[%arg0, %arg1] : memref<512x512xi32> + %iyy = affine.load %6[%arg0, %arg1] : memref<512x512xi32> + %ixy = affine.load %7[%arg0, %arg1] : memref<512x512xi32> + %11 = affine.load %0[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %12 = affine.load %1[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %13 = arith.muli %11, %11 : i32 + %14 = arith.addi %ixx, %13 : i32 + %15 = arith.muli %12, %12 : i32 + %16 = arith.addi %iyy, %15 : i32 + %17 = arith.muli %11, %12 : i32 + %18 = arith.addi %ixy, %17 : i32 + affine.store %14, %5[%arg0, %arg1] : memref<512x512xi32> + affine.store %16, %6[%arg0, %arg1] : memref<512x512xi32> + affine.store %18, %7[%arg0, %arg1] : memref<512x512xi32> + } + } + } + } + %8 = memref.get_global @score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %9 = affine.load %5[%arg0, %arg1] : memref<512x512xi32> + %10 = affine.load %6[%arg0, %arg1] : memref<512x512xi32> + %11 = affine.load %7[%arg0, %arg1] : memref<512x512xi32> + %12 = arith.muli %9, %10 : i32 + %13 = arith.muli %11, %11 : i32 + %14 = arith.subi %12, %13 : i32 + %15 = arith.addi %9, %10 : i32 + %16 = arith.muli %15, %c4_i32 : i32 + %17 = arith.muli %16, %15 : i32 + %18 = arith.subi %14, %17 : i32 + affine.store %18, %8[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} + +module @harris_score_2 { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + 
memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + %9:2 = affine.for %arg2 = 0 to 3 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32) -> (i32, i32) { + %10:2 = affine.for %arg5 = 0 to 3 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (i32, i32) { + %11 = affine.load %2[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %12 = affine.load %3[%arg5 + %arg2 * 3] : memref<9xi32> + %13 = arith.muli %11, %12 : i32 + %14 = arith.addi %arg7, %13 : i32 + %15 = affine.load %4[%arg5 + %arg2 * 3] : memref<9xi32> + %16 = arith.muli %11, %15 : i32 + %17 = arith.addi %arg6, %16 : i32 + affine.yield %17, %14 : i32, i32 + } + affine.yield %10#0, %10#1 : i32, i32 + } + affine.store %9#1, %0[%arg0, %arg1] : memref<516x516xi32> + affine.store %9#0, %1[%arg0, %arg1] : memref<516x516xi32> + } + } + %5 = memref.get_global @img_ixx : memref<512x512xi32> + %6 = memref.get_global @img_iyy : memref<512x512xi32> + %7 = memref.get_global @img_ixy : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %9:3 = affine.for %arg2 = 0 to 5 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32) -> (i32, i32, i32) { + %10:3 = affine.for %arg6 = 0 to 5 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (i32, i32, i32) { + %11 = affine.load %0[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %12 = affine.load %1[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %13 = arith.muli %11, %11 : i32 + %14 = arith.addi %arg9, %13 : i32 + %15 = arith.muli %12, %12 : i32 + %16 = arith.addi %arg8, %15 : i32 + %17 = arith.muli %11, %12 : i32 + %18 = arith.addi %arg7, %17 : i32 + affine.yield %18, %16, %14 : i32, i32, i32 + } + affine.yield %10#0, %10#1, %10#2 : i32, i32, i32 + } + affine.store %9#2, %5[%arg0, %arg1] : memref<512x512xi32> + affine.store %9#1, %6[%arg0, %arg1] : memref<512x512xi32> + affine.store %9#0, %7[%arg0, %arg1] : memref<512x512xi32> + } + } + %8 = memref.get_global @score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %9 = affine.load %5[%arg0, %arg1] : memref<512x512xi32> + %10 = affine.load %6[%arg0, %arg1] : memref<512x512xi32> + %11 = affine.load %7[%arg0, %arg1] : memref<512x512xi32> + %12 = arith.muli %9, %10 : i32 + %13 = arith.muli %11, %11 : i32 + %14 = arith.subi %12, %13 : i32 + %15 = arith.addi %9, %10 : i32 + %16 = arith.muli %15, %c4_i32 : i32 + %17 = arith.muli %16, %15 : i32 + %18 = arith.subi %14, %17 : i32 + affine.store %18, %8[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} + +module @harris_score_local { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : 
memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @coeffs_x : memref<9xi32> + %1 = memref.get_global @coeffs_y : memref<9xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + affine.for %arg2 = 0 to 3 { + affine.for %arg5 = 0 to 3 { + %gx = affine.load %alloca_3[%arg0, %arg1] : memref<516x516xi32> + %gy = affine.load %alloca_2[%arg0, %arg1] : memref<516x516xi32> + %5 = affine.load %alloca_4[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %6 = affine.load %0[%arg5 + %arg2 * 3] : memref<9xi32> + %7 = arith.muli %5, %6 : i32 + %8 = arith.addi %gx, %7 : i32 + %9 = affine.load %1[%arg5 + %arg2 * 3] : memref<9xi32> + %10 = arith.muli %5, %9 : i32 + %11 = arith.addi %gy, %10 : i32 + affine.store %8, %alloca_3[%arg0, %arg1] : memref<516x516xi32> + affine.store %11, %alloca_2[%arg0, %arg1] : memref<516x516xi32> + } + } + } + } + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %3:3 = affine.for %arg2 = 0 to 5 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32) -> (i32, i32, i32) { + %4:3 = affine.for %arg6 = 0 to 5 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (i32, i32, i32) { + %ixx = affine.load %alloca_1[%arg0, %arg1] : memref<512x512xi32> + %iyy = affine.load %alloca_0[%arg0, %arg1] : memref<512x512xi32> + %ixy = affine.load %alloca[%arg0, %arg1] : memref<512x512xi32> + %5 = affine.load %alloca_3[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %6 = affine.load %alloca_2[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %7 = arith.muli %5, %5 : i32 + %8 = arith.addi %arg9, %7 : i32 + %9 = arith.muli %6, %6 : i32 + %10 = arith.addi %arg8, %9 : i32 + %11 = arith.muli %5, %6 : i32 + %12 = arith.addi %arg7, %11 : i32 + affine.yield %12, %10, %8 : i32, i32, i32 + } + affine.yield %4#0, %4#1, %4#2 : i32, i32, i32 + } + affine.store %3#2, %alloca_1[%arg0, %arg1] : memref<512x512xi32> + affine.store %3#1, %alloca_0[%arg0, %arg1] : memref<512x512xi32> + affine.store %3#0, %alloca[%arg0, %arg1] : memref<512x512xi32> + } + } + %2 = memref.get_global @score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %3 = affine.load %alloca_1[%arg0, %arg1] : memref<512x512xi32> + %4 = affine.load %alloca_0[%arg0, %arg1] : memref<512x512xi32> + %5 = affine.load %alloca[%arg0, %arg1] : memref<512x512xi32> + %6 = arith.muli %3, %4 : i32 + %7 = arith.muli %5, %5 : i32 + %8 = arith.subi %6, %7 : i32 + %9 = arith.addi %3, %4 : i32 + %10 = arith.muli %9, %c4_i32 : i32 + %11 = arith.muli %10, %9 : i32 + %12 = arith.subi %8, %11 : i32 + affine.store %12, %2[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} From 701f25a51b3f30f798e01e506dfd568ae9cfe78e Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Sat, 12 Oct 2024 15:41:27 -0700 Subject: [PATCH 27/77] Added removal of iter_args for affine loops --- include/polygeist/Passes/Passes.h | 2 +- include/polygeist/Passes/Passes.td | 4 +- lib/polygeist/Passes/CMakeLists.txt | 2 +- lib/polygeist/Passes/RemoveIterArgs.cpp | 288 ++++++++++++++++++++++++ 4 files changed, 292 insertions(+), 4 
deletions(-) create mode 100644 lib/polygeist/Passes/RemoveIterArgs.cpp diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index 96ecf5b32003..ad7e2fc14fc2 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -32,7 +32,7 @@ std::unique_ptr createOpenMPOptPass(); std::unique_ptr createCanonicalizeForPass(); std::unique_ptr createRaiseSCFToAffinePass(); std::unique_ptr createRaiseAffineToLinalgPass(); -std::unique_ptr createRemoveSCFIterArgsPass(); +std::unique_ptr createRemoveIterArgsPass(); std::unique_ptr createCPUifyPass(StringRef method = ""); std::unique_ptr createBarrierRemovalContinuation(); std::unique_ptr detectReductionPass(); diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 0d3116f82c71..7d5f2315f4ce 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -151,9 +151,9 @@ def SCFRaiseToAffine : Pass<"raise-scf-to-affine"> { ]; } -def RemoveSCFIterArgs : Pass<"remove-scf-iter-args"> { +def RemoveIterArgs : Pass<"remove-iter-args"> { let summary = "Remove scf iter args"; - let constructor = "mlir::polygeist::createRemoveSCFIterArgsPass()"; + let constructor = "mlir::polygeist::createRemoveIterArgsPass()"; let dependentDialects = [ "affine::AffineDialect", "scf::SCFDialect", diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index bcc6de07193d..f98813fb15b5 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -11,7 +11,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms OpenMPOpt.cpp BarrierRemovalContinuation.cpp RaiseToAffine.cpp - RemoveScfIterArgs.cpp + RemoveIterArgs.cpp RaiseToLinalg.cpp ParallelLower.cpp TrivialUse.cpp diff --git a/lib/polygeist/Passes/RemoveIterArgs.cpp b/lib/polygeist/Passes/RemoveIterArgs.cpp new file mode 100644 index 000000000000..b3b0ac7302a4 --- /dev/null +++ b/lib/polygeist/Passes/RemoveIterArgs.cpp @@ -0,0 +1,288 @@ +#include "PassDetails.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "polygeist/Passes/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "remove-scf-iter-args" + +using namespace mlir; +using namespace mlir::arith; +using namespace polygeist; +using namespace scf; +using namespace affine; + +struct RemoveSCFIterArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + + ModuleOp module = forOp->getParentOfType(); + if (!forOp.getRegion().hasOneBlock()) + return failure(); + unsigned numIterArgs = forOp.getNumRegionIterArgs(); + auto loc = forOp->getLoc(); + bool changed = false; + llvm::SetVector removed; + llvm::MapVector steps; + auto yieldOp = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < numIterArgs; i++) { + auto ba = forOp.getRegionIterArgs()[i]; + auto init = forOp.getInits()[i]; + auto lastOp = 
yieldOp->getOperand(i); + + //General Case(TODO): + //ALGo: + // 1. Create an alloca(stack) variable + // How to know it's dims? It should be based on number of reduction loops + // 2. Initialize it with init value just outside the for loop if init value is non-zero + // 3. memref.load that value in the for loop + // 4. Replace all the uses of the iter_arg with the loaded value + // 5. Add a memref.store for the value to be yielded + // 6. Replace all uses of for-loops yielded value with a single inserted memref.load + //Special case: + //ALGo: + //Optimize away memref.store and memref.load, if the only users of memref.load are memref.store (can use affine-scalrep pass for that ? No it does store to load forwarding) + //What we need is forwarding of local store to final store and deleting the intermediate alloca created. This is only possible if the user of alloca is a storeOp. + // 1. Identify the single store of the for loop result + // 2. Initialize it with iter arg init, outside the for loop. (TODO) + // 3. Do a load from the memref + // 4. move the store to memref inside the loop. + + auto result = forOp.getResult(i); + if(result.hasOneUse()) { + auto storeOp = dyn_cast(*result.getUsers().begin()); + if(storeOp) + { + { + rewriter.setInsertionPointToStart(forOp.getBody()); + auto memrefLoad = rewriter.create( + forOp.getLoc(), storeOp.getMemref(), storeOp.getIndices()); + rewriter.replaceAllUsesWith(ba, memrefLoad.getResult()); + } + { + rewriter.setInsertionPoint(yieldOp); + rewriter.create(forOp.getLoc(), lastOp, storeOp.getMemref(), + storeOp.getIndices()); + storeOp.erase(); + } + } + else{ + return failure(); + } + } + //else{ + // alloca = rewriter.create( + // forOp.getLoc(), MemRefType::get(ArrayRef(), forOp.getType()), + // ValueRange()); + // //Skipping init for now + + + // auto memrefLoad = rewriter.create( + // forOp.getLoc(), alloca.getMemref(), op.getIndices()); + // rewriter.replaceOp(op, memrefLoad.getResult()); + + // rewriter.create(forOp.getLoc(), lastOp, alloca, + // forOp.getBody()->getArguments()); + + // rewriter.replaceAllUsesWith(result,) + //} + + rewriter.setInsertionPointToStart(forOp.getBody()); + //rewriter.replaceAllUsesWith(ba, replacementIV); + changed = true; + } + + if (!changed) + return failure(); + + rewriter.setInsertionPoint(forOp); + auto newForOp = rewriter.create(loc, forOp.getLowerBound(), + forOp.getUpperBound(), + forOp.getStep()); + if (!newForOp.getRegion().empty()) + newForOp.getRegion().front().erase(); + assert(newForOp.getRegion().empty()); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().begin()); + + //Delete region args + llvm::BitVector toDelete(numIterArgs + 1); + for (unsigned i = 0; i < numIterArgs; i++) + toDelete[i + 1] = true; + newForOp.getBody()->eraseArguments(toDelete); + + SmallVector newYields; + { + ValueRange empty; + rewriter.setInsertionPoint(yieldOp); + auto newYieldOp = rewriter.create(loc); + //rewriter.replaceOpWithNewOp(yieldOp, newYieldOp); + rewriter.eraseOp(yieldOp); + } + + rewriter.setInsertionPoint(newForOp); + rewriter.eraseOp(forOp); + + return success(); + } +}; + +struct RemoveAffineIterArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(affine::AffineForOp forOp, + PatternRewriter &rewriter) const override { + + ModuleOp module = forOp->getParentOfType(); + if (!forOp.getRegion().hasOneBlock()) + return failure(); + unsigned numIterArgs = forOp.getNumRegionIterArgs(); + auto loc = forOp->getLoc(); 
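+    // A minimal before/after sketch of the special case this pattern handles
+    // (memref and value names are illustrative only, not taken from any test):
+    //
+    //   %sum = affine.for %i = 0 to 17 iter_args(%acc = %init) -> (i32) {
+    //     %v = affine.load %A[%i] : memref<17xi32>
+    //     %n = arith.addi %acc, %v : i32
+    //     affine.yield %n : i32
+    //   }
+    //   affine.store %sum, %out[0] : memref<1xi32>
+    //
+    // becomes (initializing %out[0] with %init is still a TODO, see the notes below):
+    //
+    //   affine.for %i = 0 to 17 {
+    //     %acc = affine.load %out[0] : memref<1xi32>
+    //     %v = affine.load %A[%i] : memref<17xi32>
+    //     %n = arith.addi %acc, %v : i32
+    //     affine.store %n, %out[0] : memref<1xi32>
+    //   }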
+ bool changed = false; + llvm::SetVector removed; + llvm::MapVector steps; + auto yieldOp = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < numIterArgs; i++) { + auto ba = forOp.getRegionIterArgs()[i]; + auto init = forOp.getInits()[i]; + auto lastOp = yieldOp->getOperand(i); + + //General Case(TODO): + //ALGo: + // 1. Create an alloca(stack) variable + // How to know it's dims? It should be based on number of reduction loops + // 2. Initialize it with init value just outside the for loop if init value is non-zero + // 3. memref.load that value in the for loop + // 4. Replace all the uses of the iter_arg with the loaded value + // 5. Add a memref.store for the value to be yielded + // 6. Replace all uses of for-loops yielded value with a single inserted memref.load + //Special case: + //ALGo: + //Optimize away memref.store and memref.load, if the only users of memref.load are memref.store (can use affine-scalrep pass for that ? No it does store to load forwarding) + //What we need is forwarding of local store to final store and deleting the intermediate alloca created. This is only possible if the user of alloca is a storeOp. + // 1. Identify the single store of the for loop result + // 2. Initialize it with iter arg init, outside the for loop. (TODO) + // 3. Do a load from the memref + // 4. move the store to memref inside the loop. + + auto result = forOp.getResult(i); + if(result.hasOneUse()) { + auto storeOp = dyn_cast(*result.getUsers().begin()); + if(storeOp) + { + { + rewriter.setInsertionPointToStart(forOp.getBody()); + auto memrefLoad = rewriter.create( + forOp.getLoc(), storeOp.getMemref(), storeOp.getMap(), storeOp.getMapOperands()); + rewriter.replaceAllUsesWith(ba, memrefLoad.getResult()); + } + { + rewriter.setInsertionPoint(yieldOp); + rewriter.create(forOp.getLoc(), lastOp, storeOp.getMemref(), + storeOp.getMap(), storeOp.getMapOperands()); + storeOp.erase(); + } + } + else{ + return failure(); + } + } + //else{ + // alloca = rewriter.create( + // forOp.getLoc(), MemRefType::get(ArrayRef(), forOp.getType()), + // ValueRange()); + // //Skipping init for now + + + // auto memrefLoad = rewriter.create( + // forOp.getLoc(), alloca.getMemref(), op.getIndices()); + // rewriter.replaceOp(op, memrefLoad.getResult()); + + // rewriter.create(forOp.getLoc(), lastOp, alloca, + // forOp.getBody()->getArguments()); + + // rewriter.replaceAllUsesWith(result,) + //} + + rewriter.setInsertionPointToStart(forOp.getBody()); + //rewriter.replaceAllUsesWith(ba, replacementIV); + changed = true; + } + + if (!changed) + return failure(); + + rewriter.setInsertionPoint(forOp); + auto newForOp = rewriter.create(loc, forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), + forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), + forOp.getStep()); + + if (!newForOp.getRegion().empty()) + newForOp.getRegion().front().erase(); + assert(newForOp.getRegion().empty()); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().begin()); + + //Delete region args + llvm::BitVector toDelete(numIterArgs + 1); + for (unsigned i = 0; i < numIterArgs; i++) + toDelete[i + 1] = true; + newForOp.getBody()->eraseArguments(toDelete); + + SmallVector newYields; + { + ValueRange empty; + rewriter.setInsertionPoint(yieldOp); + auto newYieldOp = rewriter.create(loc); + //rewriter.replaceOpWithNewOp(yieldOp, newYieldOp); + rewriter.eraseOp(yieldOp); + } + + rewriter.setInsertionPoint(newForOp); + rewriter.eraseOp(forOp); + + return success(); + } +}; + +namespace 
{ +struct RemoveIterArgs + : public RemoveIterArgsBase { + + void runOnOperation() override { + GreedyRewriteConfig config; + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + patterns.insert(patterns.getContext()); + patterns.insert(patterns.getContext()); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + signalPassFailure(); + return; + } + } +}; +} // namespace + +namespace mlir { +namespace polygeist { +std::unique_ptr createRemoveIterArgsPass() { + return std::make_unique(); +} +} // namespace polygeist +} // namespace mlir \ No newline at end of file From d285fb5e41d7e4c878784943384034f8a97b8f12 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Sat, 12 Oct 2024 16:34:01 -0700 Subject: [PATCH 28/77] Temporary reverted pass registeration as the code was failing --- lib/polygeist/Passes/RaiseToLinalg.cpp | 160 +++++-------------------- 1 file changed, 32 insertions(+), 128 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index dac831af5477..46021a556717 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -954,145 +954,49 @@ struct AffineForOpRaising : public OpRewritePattern { } }; -// struct RemoveIterArgs : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(scf::ForOp forOp, -// PatternRewriter &rewriter) const override { -// if (!forOp.getRegion().hasOneBlock()) -// return failure(); -// unsigned numIterArgs = forOp.getNumRegionIterArgs(); -// auto loc = forOp->getLoc(); -// bool changed = false; -// llvm::SetVector removed; -// llvm::MapVector steps; -// auto yield = cast(forOp.getBody()->getTerminator()); -// for (unsigned i = 0; i < numIterArgs; i++) { -// auto ba = forOp.getRegionIterArgs()[i]; -// auto init = forOp.getInits()[i]; -// auto next = yield->getOperand(i); - -// auto increment = next.getDefiningOp(); -// if (!increment) -// continue; - -// Value step = nullptr; -// if (increment.getLhs() == ba) { -// step = increment.getRhs(); -// } else { -// step = increment.getLhs(); -// } -// if (!step) -// continue; - -// // If it dominates the loop entry -// if (!step.getParentRegion()->isProperAncestor(&forOp.getRegion())) -// continue; - -// rewriter.setInsertionPointToStart(forOp.getBody()); -// Value iterNum = rewriter.create( -// loc, forOp.getInductionVar(), forOp.getLowerBound()); -// iterNum = rewriter.create(loc, iterNum, forOp.getStep()); - -// Value replacementIV = rewriter.create(loc, iterNum, step); -// replacementIV = rewriter.create(loc, replacementIV, init); - -// rewriter.replaceAllUsesWith(ba, replacementIV); - -// removed.insert(i); -// steps.insert({i, step}); -// changed = true; -// } - -// if (!changed) -// return failure(); - -// SmallVector newInits; -// for (unsigned i = 0; i < numIterArgs; i++) -// if (!removed.contains(i)) -// newInits.push_back(forOp.getInits()[i]); - -// rewriter.setInsertionPoint(forOp); -// auto newForOp = rewriter.create(loc, forOp.getLowerBound(), -// forOp.getUpperBound(), -// forOp.getStep(), newInits); -// if (!newForOp.getRegion().empty()) -// newForOp.getRegion().front().erase(); -// assert(newForOp.getRegion().empty()); -// rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), -// newForOp.getRegion().begin()); - -// SmallVector newYields; -// for (unsigned i = 0; i < numIterArgs; i++) -// if (!removed.contains(i)) -// 
newYields.push_back(yield->getOperand(i)); - -// rewriter.setInsertionPoint(yield); -// rewriter.replaceOpWithNewOp(yield, newYields); - -// llvm::BitVector toDelete(numIterArgs + 1); -// for (unsigned i = 0; i < numIterArgs; i++) -// if (removed.contains(i)) -// toDelete[i + 1] = true; -// newForOp.getBody()->eraseArguments(toDelete); - -// rewriter.setInsertionPoint(newForOp); -// unsigned curNewRes = 0; -// for (unsigned i = 0; i < numIterArgs; i++) { -// auto result = forOp->getResult(i); -// if (removed.contains(i)) { -// if (result.use_empty()) -// continue; - -// rewriter.setInsertionPointToStart(forOp.getBody()); -// Value iterNum = rewriter.create( -// loc, forOp.getUpperBound(), forOp.getLowerBound()); -// iterNum = -// rewriter.create(loc, iterNum, forOp.getStep()); - -// Value afterLoop = -// rewriter.create(loc, iterNum, steps[i]); -// afterLoop = -// rewriter.create(loc, afterLoop, forOp.getInits()[i]); - -// rewriter.replaceAllUsesWith(result, afterLoop); -// } else { -// rewriter.replaceAllUsesWith(result, newForOp->getResult(curNewRes++)); -// } -// } - -// rewriter.eraseOp(forOp); +// namespace { +// struct RaiseAffineToLinalg +// : public AffineRaiseToLinalgBase { +// std::shared_ptr patterns; + +// LogicalResult initialize(MLIRContext *context) override { +// RewritePatternSet owningPatterns(context); +// for (auto *dialect : context->getLoadedDialects()) +// dialect->getCanonicalizationPatterns(owningPatterns); +// for (RegisteredOperationName op : context->getRegisteredOperations()) +// op.getCanonicalizationPatterns(owningPatterns, context); + +// owningPatterns.insert(&getContext()); + +// patterns = std::make_shared( +// std::move(owningPatterns)); // return success(); // } +// void runOnOperation() override { +// GreedyRewriteConfig config; +// (void)applyPatternsAndFoldGreedily(getOperation(), *patterns, config); +// } // }; +// } // namespace namespace { struct RaiseAffineToLinalg : public AffineRaiseToLinalgBase { - - std::shared_ptr patterns; - - LogicalResult initialize(MLIRContext *context) override { - RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) - dialect->getCanonicalizationPatterns(owningPatterns); - for (RegisteredOperationName op : context->getRegisteredOperations()) - op.getCanonicalizationPatterns(owningPatterns, context); - - //owningPatterns.insert(&getContext()); - owningPatterns.insert(&getContext()); - - patterns = std::make_shared( - std::move(owningPatterns)); - return success(); - } - void runOnOperation() override { - GreedyRewriteConfig config; - (void)applyPatternsAndFoldGreedily(getOperation(), *patterns, config); - } + void runOnOperation() override; }; } // namespace +void RaiseAffineToLinalg::runOnOperation() { + RewritePatternSet patterns(&getContext()); + // TODO add the existing canonicalization patterns + // + subview of an affine apply -> subview + patterns.insert(&getContext()); + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); +} + namespace mlir { namespace polygeist { std::unique_ptr createRaiseAffineToLinalgPass() { From c40e7a94b80b2b394844e764ea9a66e9e6ef17f3 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 15 Oct 2024 16:44:34 -0700 Subject: [PATCH 29/77] WIP commit --- lib/polygeist/Ops.cpp | 183 ++++++++++++++++++++++++++++-------------- 1 file changed, 122 insertions(+), 61 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index bfe1a6eab2d7..3d83ebf30afa 100644 --- 
a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -5733,31 +5733,92 @@ struct MulDivMul : public OpRewritePattern { } }; -//struct SubMapOpCanonicalize : public OpRewritePattern { +struct SubMapOpCanonicalize : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SubMapOp op, + PatternRewriter &rewriter) const override { + /// if submap %x is identity map and has the same size as the static size of %x + ///. replace submap with memref.cast of memref<4x5xf32> to memref + /// %x = ... : memref<4x5xf32> + // %y = polygeist.submap %x(#identity_map, %constant_4, %constant_5) : memref<4x5xf32> -> memref + // + //. becomes + // + /// %x = ... : memref<4x5xf32> + // %y = memref.cast %x : memref<4x5xf32> -> memref + // + AffineMap submap_map = subMapOp.getMap(); + auto submap_operands = subMapOp.getSymbols(); + auto source_memref = subMapOp.getMemref(); + bool isIdentity = submap_map.isIdentity() + bool isInputSameDim = llvm::all_of(llvm::zip(submap_operands, cast(source_memref.getType()).getSizes()), [&](auto pair) { + return pair.first == pair.second; + }); + if (isIdentity && isInputSameD) + m { + ::zip( e + ubma_operands, source_memref.getSizes())) if () { +e rewriter.replaceOpWithNewOp(op, op.getType(), op.getMemref()); + return success(); + } + + /// if we have a submap o +} f + a sub %y = polygeist.submap (%x, ...)map we can just replace with a si ngle s + u pol yge ist. submap (%u, ...) + // %y = polygeist.submap (%x, ...) + // + + // becomes + // + // %y = polygeist.submap (%u, ...) + // + if (aut o sapOp = op.getMemR:SubMapOp>() + ) { + auto load_map = op.getAffineMap(); + auto submap_map = subMapOp.getAffineMap();; + auto new_map = submap_map.compose(load_map); + + SmallVector operands; + operands.append(subMapOp.getSymbols().begin(), subMapOp.getSymbols().end()); + operands.append(op.getSymbols().begin(), op.getSymbols().end()); + + operands.append(op.getSizes().begin(), op.getSizes().end()); + + rewriter.replaceOpWithNewOp(op, op.getType(), new_map, operands); + return succcess(); + } + + return failure(); + } +}; + + +// struct LinalgOfSubmap : public OpRewritePattern { // using OpRewritePattern::OpRewritePattern; // LogicalResult matchAndRewrite(linalg::GenericOp gen, // PatternRewriter &rewriter) const override { -// + // // Canonicalization 1 linalg.generic of map of submap. -> linalg.generic of map of submap // //. iff the submap's affine map != identity // //. replace inner affine map with composition -// -// + + // // Canonicalizeation 3: submap which only sets bounds, of an input memref with the same bounds -> noop / cast -// -// + + // // Canonicalization 1.5 (mix of 1/2) // //. linalg_map = identity a[i,j,x,y] -> u[i+x][j+y] // //. linalg_map = [i,j,x,y]->(i+x,j+y) a[i,j] -> u[i,j]. # but still keeping the upper loop limit // //. 1 -// -// + + // // a[i] -> x[] -// + // // a[1] -> x[] // // a[2] -> x[] -// -// + + // // a[i,j] = x[map(i,j)]. ; the subbmap op // //a[i+x][j+y] : submap defines iteration var 0 goes from 0 ... A0. and var 1 goes from 0 ... A1 // //b[x][y] @@ -5768,7 +5829,7 @@ struct MulDivMul : public OpRewritePattern { // //. for (x : ...) // //. for (y : ...) // // c[i+x][j+y] += a[i+x][j+y] * b[x][y] -// + // // a[i+x][j+y] // // c[i+x][j+y] // // for (i : ...) @@ -5776,26 +5837,26 @@ struct MulDivMul : public OpRewritePattern { // //. for (x : ...) // //. for (y : ...) 
// // c[i+x][j+y] += a[i+x][j+y] -// -// + + // //x[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps // //b[x][y] // //c[i+x][j+y] -// -// + + // // requirement here, is that all linalg.generic loop bounds must be solvable after replacement // // for example, this would not be permissible // // a[i] -> x[]. ; a = submap memref -> memref<100xf32> // // out[] -// + // // This cannot be replaced since now the linalg generic iteration variable i cannot be solved for -// -// -// + + + // for (auto &&[op, opmap] : gen.getInputsAndMaps()) { // if (auto submap = op.getDefiningOp()) { // bool solvable = false; -// + // /// Cannoicalization 2: index removal // //. x[i, j] -> v[i]. can we get rid of j? // //. Are input indices defined by other ops, and if so, can we simplify @@ -5804,7 +5865,7 @@ struct MulDivMul : public OpRewritePattern { // //. For each index which is solvable from 2) // // if it can either be removed from the submap, or combined with another index in the submap, // // remove it from the submap -// + // SmallVector exprs; // for (auto [op2, map] : gen.getInputAndMaps()) { // if (op != op2) { @@ -5822,19 +5883,19 @@ struct MulDivMul : public OpRewritePattern { // } // SmallSet solvable; // linalg.determineSolvableIndices(solvable, exprs); -// + // SmallSet notsolvable = allvariables - solvable; -// + // //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps // //b[x][y] // //c[i+x][j+y] // // Supose we're solving for a // // Here exprs would contain all the affineexprs from b and c. (aka inputs - {x}) -// + // // {x, y, i+x, j+y} // // Running a solver allows us to uniquely solve for all of, x, y, i, and j with these expressoin // // In this case we can attempt to remove dependence on x, y, i, j -// + // // If however we had // //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps // //b[x][y] @@ -5842,52 +5903,52 @@ struct MulDivMul : public OpRewritePattern { // // we would solve with {x, y, i+x, y} // // Running a solver we would be able to sole for {x, y, i} but not solve for j // // In this case we can attempt to remove dependence on x, y, i, but not on j -// + // // let's take easiest one where a is just broadcasting a constant to all input indices // // a = submap (m,n) -> u[] // // a[i+x, j+y] For all input indices which are uniquely solvable, here that is both // //. index 0 = i + x // //. and index 1 = j + y // // set the input map to compose with the submap's affine map -// -// + + // /// Easy special case // if (notsolvable.size() == 0) { -// -// + + // replace opmap with submap.compose(opmap) taking into account the the ConstantIntRanges // // Easy case // } -// + // // We now have two maps with different meanings // // Let |N| be the number of loop variables in the linalg.generic // // Let |M| be length(submap.getType().getShape()) // // Let |Q| be length(submap.getInput().getType().getShape()), number of dimensions of input operand to the submap -// + // // opmap from the linalg.generic which takes linalg.generic loop indices |N| -> inputs to the submap op. |M| -// + // // submap.map. submap op. which takes input indices |M|. -> indices for the corresponing base memref |Q| -// + // // Example -// + // //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps // //b[x][y] // //c[i+x][j+y] -// + // // a = submap (w,p) -> u[c + 2 * p] -// + // // %c = myop.constant() // // %a = submap a[w, p] -> u[%c + 2 * p] // //. 
linalg.generic %a %b %c a.map (x,y,i,j) -> a[x+i,y+j] { // // } -// + // // N = 4 = |{i,j,x,u}| // // M = 2 = dim(a) . a is 2 dims // // Q = 1. dim(u) -// + // SmallVector newLinalgExprs; // SmallVector newSubmapExprs; -// + // SmallVector legalIndices; // // We iterate for all |M| expressions of the opmap // for (auto &&[i, linalgexpr] : llvm::enumerate(opmap.getExprs())) { @@ -5904,42 +5965,42 @@ struct MulDivMul : public OpRewritePattern { // // notsolvable.pop(var); // } // } -// + // if (legal) // legalIndices.push_back(i); // } -// + // // The non-special case version // // j is not solvable // //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps // //b[x][y] // //c[i+x][y] -// + // // because j is not solvable we cannot move any expressions depending on j (in this case p depends on j) // //. and the underlying sub expressions depending j, in this case via p are: // // a[1] = w + 4 and a[2] = w + 7 // // define a(w,p) -> u[c + 2 * p, w + 4, w + 7] -// + // // with the general case optimization v0. [just moving expressions up] -// + // //a2[map(i+x, j+y), i + x + 4, i + x + 7] pass in the outermost one with correspondidng composed maps // //b[x][y] // //c[i+x][y] -// + // // define a2(w, p) -> u[c + 2 * p] -// + // // with the general case optimization v1. [just eliminating unnecessary indices] -// + // //a2[map(j+y), i + x + 4, i + x + 7] pass in the outermost one with correspondidng composed maps // //b[x][y] // //c[i+x][y] -// + // // define a2(p) -> u[c + 2 * p] -// + // // So this optimization generally moves expression from the submap into the linalg map // // and it it also removes unnecessary indices into the submap -// -// + + // // If the entire submap is legal to inline, the solution is simple, replace the linalg // // map with itself composed with the submap, and replace the original submap with the identity op // if (legalIndices.size() == opmap.getExprs().size()) { @@ -5950,7 +6011,7 @@ struct MulDivMul : public OpRewritePattern { // newSubmapExprs = Affine::getIdentityMap(submap.getOperand().getShape().size()).getExprs(); // } else { // SmallVector illegalIndices = allIndices - legalIndices; -// + // // We can alternatively re-index maps which are solely functions of legal indices. 
// for (auto &&[i, submapexpr] : llvm::enumerate(submap.getAffineMap().getExprs())) { // if (submapexpr is a function of any illegal indicies) { @@ -5964,16 +6025,16 @@ struct MulDivMul : public OpRewritePattern { // } // } // } -// + // if (solvable) { // // replace the input to the generic with the input to the submap, and the new map // return success(); // } // } // } -// -// -// + + + // for (auto op : gen.getOutputs()) { // if (auto submap = op.getDefiningOp()) { // bool solvable = false; @@ -5984,11 +6045,11 @@ struct MulDivMul : public OpRewritePattern { // } // } // } -// -// + + // return failure(); // } -//}; +// }; static llvm::cl::opt BufferElim("enable-buffer-elim", llvm::cl::init(true), From 788a3c4426b6ab4aafec27f83c1fa5fb002473ab Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 18 Oct 2024 09:55:01 -0700 Subject: [PATCH 30/77] Added submap of submap canonicalizer with test- failing --- lib/polygeist/Ops.cpp | 62 ++++++++-------------- test/polygeist-opt/submapcanonicalize.mlir | 34 +++++++++++- 2 files changed, 55 insertions(+), 41 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 3d83ebf30afa..fa7cb6c283e9 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -5735,7 +5735,7 @@ struct MulDivMul : public OpRewritePattern { struct SubMapOpCanonicalize : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(SubMapOp op, + LogicalResult matchAndRewrite(SubmapOp op, PatternRewriter &rewriter) const override { /// if submap %x is identity map and has the same size as the static size of %x ///. replace submap with memref.cast of memref<4x5xf32> to memref @@ -5747,48 +5747,32 @@ struct SubMapOpCanonicalize : public OpRewritePattern { /// %x = ... : memref<4x5xf32> // %y = memref.cast %x : memref<4x5xf32> -> memref // - AffineMap submap_map = subMapOp.getMap(); - auto submap_operands = subMapOp.getSymbols(); - auto source_memref = subMapOp.getMemref(); - bool isIdentity = submap_map.isIdentity() - bool isInputSameDim = llvm::all_of(llvm::zip(submap_operands, cast(source_memref.getType()).getSizes()), [&](auto pair) { - return pair.first == pair.second; + auto source_memref = op.getMemref(); + bool isIdentity = op.getMap().isIdentity(); + bool isInputSameDim = llvm::all_of(llvm::zip_equal(op.getSizes(), cast(source_memref.getType()).getShape()), [&](auto pair) { + if (std::get<1>(pair) == -1) + return false; + APInt matched; + if (matchPattern(std::get<0>(pair), m_ConstantInt(&matched))) { + return std::get<1>(pair) == matched; + } + return false; }); - if (isIdentity && isInputSameD) - m { - ::zip( e - ubma_operands, source_memref.getSizes())) if () { -e rewriter.replaceOpWithNewOp(op, op.getType(), op.getMemref()); - return success(); - } - - /// if we have a submap o -} f - a sub %y = polygeist.submap (%x, ...)map we can just replace with a si ngle s - u pol yge ist. submap (%u, ...) - // %y = polygeist.submap (%x, ...) - // - - // becomes - // - // %y = polygeist.submap (%u, ...) 
- // - if (aut o sapOp = op.getMemR:SubMapOp>() - ) { - auto load_map = op.getAffineMap(); - auto submap_map = subMapOp.getAffineMap();; + if (isIdentity && isInputSameDim) { + rewriter.replaceOpWithNewOp(op, op.getType(), op.getMemref()); + return success(); + } + if (auto sapOp = source_memref.getDefiningOp()) { + auto load_map = op.getMap(); + auto submap_map = sapOp.getMap(); auto new_map = submap_map.compose(load_map); - SmallVector operands; - operands.append(subMapOp.getSymbols().begin(), subMapOp.getSymbols().end()); operands.append(op.getSymbols().begin(), op.getSymbols().end()); - + operands.append(op.getSymbols().begin(), op.getSymbols().end()); operands.append(op.getSizes().begin(), op.getSizes().end()); - - rewriter.replaceOpWithNewOp(op, op.getType(), new_map, operands); - return succcess(); + rewriter.replaceOpWithNewOp(op, op.getType(), sapOp.getMemref(), operands, new_map); + return success(); } - return failure(); } }; @@ -6283,6 +6267,6 @@ class DimSubMap final : public OpRewritePattern { void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - //results.insert(context); - results.insert(context); + results.insert(context); + //results.insert(context); } diff --git a/test/polygeist-opt/submapcanonicalize.mlir b/test/polygeist-opt/submapcanonicalize.mlir index 3e186911f677..21f3e72fb5a1 100644 --- a/test/polygeist-opt/submapcanonicalize.mlir +++ b/test/polygeist-opt/submapcanonicalize.mlir @@ -1,6 +1,6 @@ // RUN: polygeist-opt -canonicalize %s | FileCheck %s #map = affine_map<(d0)[s0, s1] -> (d0 * s0, d0 * s1)> -module { +module @submap_to_load__store{ func.func private @use(i32) func.func @f(%arg0: memref, %arg1 : index, %arg2 : index, %arg3 : index) { @@ -38,4 +38,34 @@ module { // CHECK-NEXT: affine.store %arg4, %arg0[(%arg5 + symbol(%arg3) + 5) * symbol(%arg1), (%arg5 + symbol(%arg3) + 5) * symbol(%arg2)] : memref // CHECK-NEXT: } // CHECK-NEXT: return -// CHECK-NEXT: } \ No newline at end of file +// CHECK-NEXT: } + +#map19 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)> +#map20 = affine_map<(d0, d1, d2, d3) -> (d1, d0)> +#map21 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +#map22 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +module @conv_1 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = "polygeist.submap"(%0, %c4, %c4, %c64, %c512) <{map = #map19}> : (memref<515x67xi32>, index, index, index, index) -> memref<4x4x64x512xi32> + %ssmap = "polygeist.submap"(%3, %c4, %c4, %c64, %c512) <{map = #map22}> : (memref<4x4x64x512xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c4, %c4, %c64, %c512) <{map = #map20}> : (memref<4x4xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c4, %c4, %c64, %c512) <{map = #map21}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%ssmap, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: 
i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } +} \ No newline at end of file From 82652168a43ae32761eca9b46848eb2e716ec3da Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 25 Oct 2024 00:08:43 -0700 Subject: [PATCH 31/77] Added canonicalization for linalg with submap and test cases --- lib/polygeist/Ops.cpp | 99 ++- test/polygeist-opt/raised_with_submap.mlir | 883 +++++++++++++++++++++ 2 files changed, 980 insertions(+), 2 deletions(-) create mode 100644 test/polygeist-opt/raised_with_submap.mlir diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index fa7cb6c283e9..c694c8520ef4 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -22,6 +22,8 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/AffineMap.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -39,7 +41,6 @@ using namespace mlir; using namespace polygeist; using namespace mlir::arith; - llvm::cl::opt BarrierOpt("barrier-opt", llvm::cl::init(true), llvm::cl::desc("Optimize barriers")); @@ -5778,6 +5779,99 @@ struct SubMapOpCanonicalize : public OpRewritePattern { }; + struct LinalgOfSubmap : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + //Check body content + auto module = genericOp->getParentOfType(); + Region &genericBody = genericOp.getRegion(); + Block &entryBlock = genericBody.front(); + ValueRange blockArgs = entryBlock.getArguments(); + auto inputs = genericOp.getInputs(); + auto outputs = genericOp.getOutputs(); + SmallVector listOfAllocas; + SmallVector listOfNewMaps; + SmallVector listOfNewInputs, listOfNewOutputs; + //auto mapAttrsArr = genericOp.getIndexingMaps(); + //for(auto mapAttr: mapAttrsArr) { + // AffineMap map = mapAttr.cast().getValue(); + // if(map == convMap[0] && !mapped[0]) { + // } + //} + for(auto inp: inputs) { + if(auto blkArg = dyn_cast(inp)) { + listOfNewInputs.push_back(inp); + } + else if(auto subMap = dyn_cast(inp.getDefiningOp())) { + auto source_memref = subMap.getMemref(); + //if (auto blockArg = dyn_cast_or_null(op)) { + //if(auto source_alloca = dyn_cast(source_memref.getDefiningOp())) + //{ + auto map = subMap.getMap(); + listOfNewMaps.push_back(map); + listOfNewInputs.push_back(source_memref); + //} + //else { + // assert(false && "Only expect allocaOp as source for submap canonicalization right now"); + // return failure(); + //} + } + else { + listOfNewInputs.push_back(inp); + } + } + + for(auto out: outputs) { + if(auto blkArg = dyn_cast(out)) { + listOfNewOutputs.push_back(out); + } + else if(auto subMap = dyn_cast(out.getDefiningOp())) { + auto source_memref = subMap.getMemref(); + auto map = subMap.getMap(); + listOfNewMaps.push_back(map); + listOfNewOutputs.push_back(source_memref); + } + else { + listOfNewOutputs.push_back(out); + } + } + ArrayRef maps(listOfNewMaps); + //No submap ops detected + if(maps.size() == 0) + return failure(); + //If inverse permutation exists, then we can canonicalize the linalg of submap to linalg + //TODO: Fails for: + // 1. Maps with symbols + // 2. 
Maps with non + if(inversePermutation(concatAffineMaps(maps))) { + StringAttr empty = StringAttr::get(genericOp.getContext()); + auto newGenericOp = rewriter.create(genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, listOfNewMaps, genericOp.getIteratorTypesArray(), + empty, empty); + rewriter.inlineRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), newGenericOp.getRegion().end()); + + //auto &block = newGenericOp.getRegion().front(); + //block.addArguments(newGenericOp.getOperandTypes(), SmallVector(newGenericOp.getNumOperands(), genericOp.getLoc())); + + rewriter.replaceOp(genericOp, newGenericOp.getResults()); + return success(); + } + //for(iterate over inputs) + //{ + // gather maps + // gather submaps + // Gather affine maps from submaps + // Check over 2 iterations if all the indexes can be solved. + // Use the same logic as linalg.generic to do this. + // if success in getting vars + // replace affine map from submap to linalg.generic + // replace input memref as direct input to linalg.generic + //} + //assert(false && "inversePermutation doesn't exists for the given linalg generic"); + return failure(); + } + }; + // struct LinalgOfSubmap : public OpRewritePattern { // using OpRewritePattern::OpRewritePattern; // LogicalResult matchAndRewrite(linalg::GenericOp gen, @@ -6267,6 +6361,7 @@ class DimSubMap final : public OpRewritePattern { void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); + //results.insert(context); + results.insert(context); //results.insert(context); } diff --git a/test/polygeist-opt/raised_with_submap.mlir b/test/polygeist-opt/raised_with_submap.mlir new file mode 100644 index 000000000000..069e861445b1 --- /dev/null +++ b/test/polygeist-opt/raised_with_submap.mlir @@ -0,0 +1,883 @@ +#map = affine_map<(d0) -> (d0)> +#map1 = affine_map<(d0) -> (d0 * 3)> +#map2 = affine_map<(d0)[s0] -> (s0)> +#map3 = affine_map<(d0) -> (0)> +#map4 = affine_map<(d0, d1) -> (d1)> +#map5 = affine_map<(d0, d1) -> (d0)> +#map6 = affine_map<(d0, d1) -> (d0, d1)> +#map7 = affine_map<(d0, d1) -> (d0 * 2 + d1)> +#map8 = affine_map<(d0, d1) -> (d0 + d1 * 2)> +#map9 = affine_map<(d0, d1, d2) -> (d2)> +#map10 = affine_map<(d0, d1, d2) -> (d1)> +#map11 = affine_map<(d0, d1, d2) -> (d0)> +#map12 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map13 = affine_map<(d0, d1, d2) -> (d1 * 4 + d2 + 3)> +#map14 = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + 2)> +#map15 = affine_map<(d0, d1, d2) -> (d0 + d2 * 2)> +#map16 = affine_map<(d0, d1, d2) -> (d2, d0)> +#map17 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map18 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map19 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)> +#map20 = affine_map<(d0, d1, d2, d3) -> (d1, d0)> +#map21 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +#map22 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map23 = affine_map<(d0, d1)[s0, s1] -> (d1 + s0, d0 + s1)> +#map24 = affine_map<(d0, d1) -> (d1, d0)> +#map25 = affine_map<(d0, d1)[s0, s1] -> (s0, s1)> +#map26 = affine_map<(d0)[s0, s1, s2] -> (s0 + s1, d0 + s2)> +#map27 = affine_map<(d0)[s0] -> (s0, d0)> +#map28 = affine_map<(d0)[s0, s1] -> (s0, s1)> +#map29 = affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 3)> +module { + module @constant_access { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 4.000000e+00 : f32 + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : 
index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = arith.mulf %in, %cst : f32 + linalg.yield %5 : f32 + } + return + } + } +// module @constant_mem_access { +// func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { +// %c13 = arith.constant 13 : index +// %c4 = arith.constant 4 : index +// %0 = arith.index_cast %arg1 : i32 to index +// %1 = arith.muli %0, %c4 : index +// %2 = arith.divui %1, %c4 : index +// %alloca = memref.alloca(%2) : memref +// %3 = "polygeist.submap"(%arg2, %c13) <{map = #map1}> : (memref, index) -> memref +// %4 = "polygeist.submap"(%arg2, %c4, %c13) <{map = #map2}> : (memref, index, index) -> memref +// %5 = "polygeist.submap"(%alloca, %c13) <{map = #map}> : (memref, index) -> memref +// linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { +// ^bb0(%in: f32, %in_0: f32, %out: f32): +// %6 = arith.mulf %in, %in_0 : f32 +// linalg.yield %6 : f32 +// } +// return +// } +// } + module @no_if { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + return + } + } + module @arith_mul { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = arith.mulf %in, %in : f32 + linalg.yield %5 : f32 + } + return + } + } + module @arith_add { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17) <{map = #map}> : (memref, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.addf %in, %in_0 : f32 + %7 = arith.mulf %6, %6 : f32 + linalg.yield %7 : f32 + } + return + } + } + module @cond_arith { + 
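+    // The generic body below keeps an scf.if that selects between %in * %in and
+    // %in, exercising raising of a conditional computation inside the loop body.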
func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = scf.if %arg0 -> (f32) { + %6 = arith.mulf %in, %in : f32 + scf.yield %6 : f32 + } else { + scf.yield %in : f32 + } + linalg.yield %5 : f32 + } + return + } + } + module @reduction { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map3}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %out, %in : f32 + linalg.yield %5 : f32 + } + return + } + } + module @reduction_transformed { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %cst = arith.constant 0.000000e+00 : f32 + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %alloca_0 = memref.alloca() : memref<1xf32> + affine.store %cst, %alloca_0[0] : memref<1xf32> + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca_0, %c17) <{map = #map3}> : (memref<1xf32>, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %6 = arith.addf %out, %in : f32 + linalg.yield %6 : f32 + } + %5 = affine.load %alloca_0[0] : memref<1xf32> + affine.store %5, %alloca[0] : memref + return + } + } + module @reduction_transformed_simplified { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %cst = arith.constant 0.000000e+00 : f32 + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.store %cst, %alloca[0] : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map3}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %out, %in : f32 + linalg.yield %5 : f32 + } + return + } + } + module @cond_store_1 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.for %arg3 = 0 to 17 { + %3 = affine.load 
%arg2[%arg3] : memref + %4 = arith.mulf %3, %3 : f32 + scf.if %arg0 { + affine.store %4, %alloca[%arg3] : memref + } + } + return + } + } + module @cond_store_2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.for %arg3 = 0 to 17 { + %3 = affine.load %arg2[%arg3] : memref + scf.if %arg0 { + %4 = arith.mulf %3, %3 : f32 + affine.store %4, %alloca[%arg3] : memref + } else { + affine.store %3, %alloca[%arg3] : memref + } + } + return + } + } + module @for_within_for { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map4}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_within_for_2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map7}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_within_for_3 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map7}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map4}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %6 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4, %5 : memref, memref, memref) outs(%6 : memref) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %7 = arith.mulf %in, %in_0 : f32 + %8 = arith.mulf %7, %in_1 : f32 + linalg.yield %8 : f32 + } + 
return + } + } + module @for_within_for_4 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map8}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_no_loop_dependency { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c15 = arith.constant 15 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c15) <{map = #map3}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c15) <{map = #map3}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + return + } + } + module @for_2_levels_no_loop_dependency { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c17 = arith.constant 17 : index + %c15 = arith.constant 15 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c15, %c17) <{map = #map4}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%alloca, %c15, %c17) <{map = #map4}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + return + } + } + module @for_3_levels_0 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c15 = arith.constant 15 : index + %c17 = arith.constant 17 : index + %c21 = arith.constant 21 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c21, %c17, %c15) <{map = #map9}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c21, %c17, %c15) <{map = #map10}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c21, %c17, %c15) <{map = #map11}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12], iterator_types = ["reduction", "reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_3_levels_1 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: 
memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21, %c21) <{map = #map11}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21, %c21) <{map = #map11}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12], iterator_types = ["reduction", "reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_3_levels_2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c21, %c17, %c21) <{map = #map9}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%arg4, %c21, %c17, %c21) <{map = #map11}> : (memref, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12, #map12], iterator_types = ["reduction", "parallel", "reduction"]} ins(%3, %4, %5 : memref, memref, memref) outs(%6 : memref) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %7 = arith.mulf %in, %in_0 : f32 + %8 = arith.mulf %7, %in_1 : f32 + linalg.yield %8 : f32 + } + return + } + } + module @for_3_levels_3 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c21, %c17, %c21) <{map = #map9}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map11}> : (memref, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12, #map12], iterator_types = ["reduction", "parallel", "reduction"]} ins(%3, %4, %5 : memref, memref, memref) outs(%6 : memref) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %7 = arith.mulf %in, %in_0 : f32 + %8 = arith.mulf %7, %in_1 : f32 + linalg.yield %8 : f32 + } + return + } + } + module @for_3_levels_4 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : 
memref + %3 = "polygeist.submap"(%arg2, %c21, %c17, %c21) <{map = #map13}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map14}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map15}> : (memref, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12, #map12], iterator_types = ["reduction", "parallel", "reduction"]} ins(%3, %4, %5 : memref, memref, memref) outs(%6 : memref) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %7 = arith.mulf %in, %in_0 : f32 + %8 = arith.mulf %7, %in_1 : f32 + linalg.yield %8 : f32 + } + return + } + } + module @for_within_for2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map4}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @matmul_1 { + memref.global @out : memref<32x8xi32> = uninitialized + memref.global @im2 : memref<8x8xi32> = uninitialized + memref.global @im1 : memref<32x8xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im1 : memref<32x8xi32> + %1 = memref.get_global @im2 : memref<8x8xi32> + %2 = memref.get_global @out : memref<32x8xi32> + %3 = "polygeist.submap"(%0, %c8, %c8, %c32) <{map = #map16}> : (memref<32x8xi32>, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c8, %c8, %c32) <{map = #map17}> : (memref<8x8xi32>, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c8, %c8, %c32) <{map = #map18}> : (memref<32x8xi32>, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @matmul_2 { + memref.global @out : memref<128x32xi32> = uninitialized + memref.global @im2 : memref<64x32xi32> = uninitialized + memref.global @im1 : memref<128x64xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im1 : memref<128x64xi32> + %1 = memref.get_global @im2 : memref<64x32xi32> + %2 = memref.get_global @out : memref<128x32xi32> + %3 = "polygeist.submap"(%0, %c64, %c32, %c128) <{map = #map16}> : (memref<128x64xi32>, 
index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c64, %c32, %c128) <{map = #map17}> : (memref<64x32xi32>, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c64, %c32, %c128) <{map = #map18}> : (memref<128x32xi32>, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @conv_1 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = "polygeist.submap"(%0, %c4, %c4, %c64, %c512) <{map = #map19}> : (memref<515x67xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c4, %c4, %c64, %c512) <{map = #map20}> : (memref<4x4xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c4, %c4, %c64, %c512) <{map = #map21}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @conv_1_reduction_test { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main(%arg0: index, %arg1: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = "polygeist.submap"(%0, %arg0, %arg1, %c4, %c4) <{map = #map23}> : (memref<515x67xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c4, %c4) <{map = #map24}> : (memref<4x4xi32>, index, index) -> memref + %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4, %c4) <{map = #map25}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @conv_2 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> 
+ %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = "polygeist.submap"(%0, %c4, %c4, %c64, %c512) <{map = #map19}> : (memref<515x67xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c4, %c4, %c64, %c512) <{map = #map20}> : (memref<4x4xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c4, %c4, %c64, %c512) <{map = #map21}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @box_filter { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @out : memref<512x64xi32> + %2 = "polygeist.submap"(%0, %c4, %c4, %c64, %c512) <{map = #map19}> : (memref<515x67xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%1, %c4, %c4, %c64, %c512) <{map = #map21}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2 : memref) outs(%3 : memref) { + ^bb0(%in: i32, %out: i32): + %4 = arith.addi %out, %in : i32 + linalg.yield %4 : i32 + } + return %c0_i32 : i32 + } + } + module @conv_loop1_test { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main(%arg0: index, %arg1: index, %arg2: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = "polygeist.submap"(%0, %arg0, %arg2, %arg1, %c4) <{map = #map26}> : (memref<515x67xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %arg2, %c4) <{map = #map27}> : (memref<4x4xi32>, index, index) -> memref + %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4) <{map = #map28}> : (memref<512x64xi32>, index, index, index) -> memref + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @submap_test { + memref.global @out : memref<511x64xi32> = uninitialized + memref.global @filter : memref<5x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main(%arg0: index, %arg1: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global 
@filter : memref<5x4xi32> + %2 = memref.get_global @out : memref<511x64xi32> + %3 = "polygeist.submap"(%0, %arg0, %arg1, %c4, %c5) <{map = #map23}> : (memref<515x67xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c4, %c5) <{map = #map24}> : (memref<5x4xi32>, index, index) -> memref + %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4, %c5) <{map = #map25}> : (memref<511x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @harris_score_1 { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + %5 = "polygeist.submap"(%2, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%3, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %7 = "polygeist.submap"(%4, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%5, %6, %7 : memref, memref, memref) outs(%8, %9 : memref, memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32, %out_2: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.addi %out, %23 : i32 + %25 = arith.muli %in, %in_1 : i32 + %26 = arith.addi %out_2, %25 : i32 + linalg.yield %24, %26 : i32, i32 + } + %10 = memref.get_global @img_ixx : memref<512x512xi32> + %11 = memref.get_global @img_iyy : memref<512x512xi32> + %12 = memref.get_global @img_ixy : memref<512x512xi32> + %13 = "polygeist.submap"(%0, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %14 = "polygeist.submap"(%1, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %15 = "polygeist.submap"(%10, %c5, %c5, %c512, %c512) 
<{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %16 = "polygeist.submap"(%11, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %17 = "polygeist.submap"(%12, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%13, %14 : memref, memref) outs(%15, %16, %17 : memref, memref, memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32, %out_1: i32, %out_2: i32): + %23 = arith.muli %in, %in : i32 + %24 = arith.addi %out, %23 : i32 + %25 = arith.muli %in_0, %in_0 : i32 + %26 = arith.addi %out_1, %25 : i32 + %27 = arith.muli %in, %in_0 : i32 + %28 = arith.addi %out_2, %27 : i32 + linalg.yield %24, %26, %28 : i32, i32, i32 + } + %18 = memref.get_global @score : memref<512x512xi32> + %19 = "polygeist.submap"(%10, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %20 = "polygeist.submap"(%11, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %21 = "polygeist.submap"(%12, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %22 = "polygeist.submap"(%18, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%19, %20, %21 : memref, memref, memref) outs(%22 : memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.muli %in_1, %in_1 : i32 + %25 = arith.subi %23, %24 : i32 + %26 = arith.addi %in, %in_0 : i32 + %27 = arith.muli %26, %c4_i32 : i32 + %28 = arith.muli %27, %26 : i32 + %29 = arith.subi %25, %28 : i32 + linalg.yield %29 : i32 + } + return %c0_i32 : i32 + } + } + module @harris_score_2 { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + %5 = "polygeist.submap"(%2, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%3, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %7 = "polygeist.submap"(%4, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%1, %c3, %c3, %c516, 
%c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%5, %6, %7 : memref, memref, memref) outs(%8, %9 : memref, memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32, %out_2: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.addi %out_2, %23 : i32 + %25 = arith.muli %in, %in_1 : i32 + %26 = arith.addi %out, %25 : i32 + linalg.yield %26, %24 : i32, i32 + } + %10 = memref.get_global @img_ixx : memref<512x512xi32> + %11 = memref.get_global @img_iyy : memref<512x512xi32> + %12 = memref.get_global @img_ixy : memref<512x512xi32> + %13 = "polygeist.submap"(%0, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %14 = "polygeist.submap"(%1, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %15 = "polygeist.submap"(%12, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %16 = "polygeist.submap"(%11, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %17 = "polygeist.submap"(%10, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%13, %14 : memref, memref) outs(%15, %16, %17 : memref, memref, memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32, %out_1: i32, %out_2: i32): + %23 = arith.muli %in, %in : i32 + %24 = arith.addi %out_2, %23 : i32 + %25 = arith.muli %in_0, %in_0 : i32 + %26 = arith.addi %out_1, %25 : i32 + %27 = arith.muli %in, %in_0 : i32 + %28 = arith.addi %out, %27 : i32 + linalg.yield %28, %26, %24 : i32, i32, i32 + } + %18 = memref.get_global @score : memref<512x512xi32> + %19 = "polygeist.submap"(%10, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %20 = "polygeist.submap"(%11, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %21 = "polygeist.submap"(%12, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %22 = "polygeist.submap"(%18, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%19, %20, %21 : memref, memref, memref) outs(%22 : memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.muli %in_1, %in_1 : i32 + %25 = arith.subi %23, %24 : i32 + %26 = arith.addi %in, %in_0 : i32 + %27 = arith.muli %26, %c4_i32 : i32 + %28 = arith.muli %27, %26 : i32 + %29 = arith.subi %25, %28 : i32 + linalg.yield %29 : i32 + } + return %c0_i32 : i32 + } + } + module @harris_score_local { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : 
index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @coeffs_x : memref<9xi32> + %1 = memref.get_global @coeffs_y : memref<9xi32> + %2 = "polygeist.submap"(%alloca_4, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca_3, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca_2, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2, %3, %4 : memref, memref, memref) outs(%5, %6 : memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.addi %out, %17 : i32 + %19 = arith.muli %in, %in_6 : i32 + %20 = arith.addi %out_7, %19 : i32 + linalg.yield %18, %20 : i32, i32 + } + %7 = "polygeist.submap"(%alloca_3, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%alloca_2, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%alloca, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %10 = "polygeist.submap"(%alloca_0, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %11 = "polygeist.submap"(%alloca_1, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%7, %8 : memref, memref) outs(%9, %10, %11 : memref, memref, memref) { + ^bb0(%in: i32, %in_5: i32, %out: i32, %out_6: i32, %out_7: i32): + %17 = arith.muli %in, %in : i32 + %18 = arith.addi %out_7, %17 : i32 + %19 = arith.muli %in_5, %in_5 : i32 + %20 = arith.addi %out_6, %19 : i32 + %21 = arith.muli %in, %in_5 : i32 + %22 = arith.addi %out, %21 : i32 + linalg.yield %22, %20, %18 : i32, i32, i32 + } + %12 = memref.get_global @score : memref<512x512xi32> + %13 = "polygeist.submap"(%alloca_1, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %14 = "polygeist.submap"(%alloca_0, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %15 = "polygeist.submap"(%alloca, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %16 = "polygeist.submap"(%12, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%13, %14, %15 
: memref, memref, memref) outs(%16 : memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.muli %in_6, %in_6 : i32 + %19 = arith.subi %17, %18 : i32 + %20 = arith.addi %in, %in_5 : i32 + %21 = arith.muli %20, %c4_i32 : i32 + %22 = arith.muli %21, %20 : i32 + %23 = arith.subi %19, %22 : i32 + linalg.yield %23 : i32 + } + return %c0_i32 : i32 + } + } +} + From 532773a2c7ecb3b084bce73ac876a64dbedb2553 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 25 Oct 2024 00:34:41 -0700 Subject: [PATCH 32/77] Added modified 2d kernel for harris score- raised successfully to linalg on memref --- test/polygeist-opt/raised_with_submap.mlir | 275 ++++++++++++++++----- 1 file changed, 208 insertions(+), 67 deletions(-) diff --git a/test/polygeist-opt/raised_with_submap.mlir b/test/polygeist-opt/raised_with_submap.mlir index 069e861445b1..9e70e07e9bcc 100644 --- a/test/polygeist-opt/raised_with_submap.mlir +++ b/test/polygeist-opt/raised_with_submap.mlir @@ -557,28 +557,28 @@ module { return %c0_i32 : i32 } } - module @conv_1_reduction_test { - memref.global @out : memref<512x64xi32> = uninitialized - memref.global @filter : memref<4x4xi32> = uninitialized - memref.global @im : memref<515x67xi32> = uninitialized - func.func @main(%arg0: index, %arg1: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c4 = arith.constant 4 : index - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im : memref<515x67xi32> - %1 = memref.get_global @filter : memref<4x4xi32> - %2 = memref.get_global @out : memref<512x64xi32> - %3 = "polygeist.submap"(%0, %arg0, %arg1, %c4, %c4) <{map = #map23}> : (memref<515x67xi32>, index, index, index, index) -> memref - %4 = "polygeist.submap"(%1, %c4, %c4) <{map = #map24}> : (memref<4x4xi32>, index, index) -> memref - %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4, %c4) <{map = #map25}> : (memref<512x64xi32>, index, index, index, index) -> memref - linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { - ^bb0(%in: i32, %in_0: i32, %out: i32): - %6 = arith.muli %in, %in_0 : i32 - %7 = arith.addi %out, %6 : i32 - linalg.yield %7 : i32 - } - return %c0_i32 : i32 - } - } + // module @conv_1_reduction_test { + // memref.global @out : memref<512x64xi32> = uninitialized + // memref.global @filter : memref<4x4xi32> = uninitialized + // memref.global @im : memref<515x67xi32> = uninitialized + // func.func @main(%arg0: index, %arg1: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + // %c4 = arith.constant 4 : index + // %c0_i32 = arith.constant 0 : i32 + // %0 = memref.get_global @im : memref<515x67xi32> + // %1 = memref.get_global @filter : memref<4x4xi32> + // %2 = memref.get_global @out : memref<512x64xi32> + // %3 = "polygeist.submap"(%0, %arg0, %arg1, %c4, %c4) <{map = #map23}> : (memref<515x67xi32>, index, index, index, index) -> memref + // %4 = "polygeist.submap"(%1, %c4, %c4) <{map = #map24}> : (memref<4x4xi32>, index, index) -> memref + // %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4, %c4) <{map = #map25}> : (memref<512x64xi32>, index, index, index, index) -> memref + // linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + // ^bb0(%in: i32, %in_0: i32, %out: i32): + // %6 = arith.muli %in, %in_0 : i32 + // %7 = arith.addi %out, %6 : i32 + // linalg.yield %7 : i32 + // } + // return %c0_i32 : i32 + // } + // } 
module @conv_2 { memref.global @out : memref<512x64xi32> = uninitialized memref.global @filter : memref<4x4xi32> = uninitialized @@ -624,51 +624,51 @@ module { return %c0_i32 : i32 } } - module @conv_loop1_test { - memref.global @out : memref<512x64xi32> = uninitialized - memref.global @filter : memref<4x4xi32> = uninitialized - memref.global @im : memref<515x67xi32> = uninitialized - func.func @main(%arg0: index, %arg1: index, %arg2: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c4 = arith.constant 4 : index - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im : memref<515x67xi32> - %1 = memref.get_global @filter : memref<4x4xi32> - %2 = memref.get_global @out : memref<512x64xi32> - %3 = "polygeist.submap"(%0, %arg0, %arg2, %arg1, %c4) <{map = #map26}> : (memref<515x67xi32>, index, index, index, index) -> memref - %4 = "polygeist.submap"(%1, %arg2, %c4) <{map = #map27}> : (memref<4x4xi32>, index, index) -> memref - %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4) <{map = #map28}> : (memref<512x64xi32>, index, index, index) -> memref - linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { - ^bb0(%in: i32, %in_0: i32, %out: i32): - %6 = arith.muli %in, %in_0 : i32 - %7 = arith.addi %out, %6 : i32 - linalg.yield %7 : i32 - } - return %c0_i32 : i32 - } - } - module @submap_test { - memref.global @out : memref<511x64xi32> = uninitialized - memref.global @filter : memref<5x4xi32> = uninitialized - memref.global @im : memref<515x67xi32> = uninitialized - func.func @main(%arg0: index, %arg1: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c5 = arith.constant 5 : index - %c4 = arith.constant 4 : index - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im : memref<515x67xi32> - %1 = memref.get_global @filter : memref<5x4xi32> - %2 = memref.get_global @out : memref<511x64xi32> - %3 = "polygeist.submap"(%0, %arg0, %arg1, %c4, %c5) <{map = #map23}> : (memref<515x67xi32>, index, index, index, index) -> memref - %4 = "polygeist.submap"(%1, %c4, %c5) <{map = #map24}> : (memref<5x4xi32>, index, index) -> memref - %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4, %c5) <{map = #map25}> : (memref<511x64xi32>, index, index, index, index) -> memref - linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { - ^bb0(%in: i32, %in_0: i32, %out: i32): - %6 = arith.muli %in, %in_0 : i32 - %7 = arith.addi %out, %6 : i32 - linalg.yield %7 : i32 - } - return %c0_i32 : i32 - } - } +// module @conv_loop1_test { +// memref.global @out : memref<512x64xi32> = uninitialized +// memref.global @filter : memref<4x4xi32> = uninitialized +// memref.global @im : memref<515x67xi32> = uninitialized +// func.func @main(%arg0: index, %arg1: index, %arg2: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c4 = arith.constant 4 : index +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im : memref<515x67xi32> +// %1 = memref.get_global @filter : memref<4x4xi32> +// %2 = memref.get_global @out : memref<512x64xi32> +// %3 = "polygeist.submap"(%0, %arg0, %arg2, %arg1, %c4) <{map = #map26}> : (memref<515x67xi32>, index, index, index, index) -> memref +// %4 = "polygeist.submap"(%1, %arg2, %c4) <{map = #map27}> : (memref<4x4xi32>, index, index) -> memref +// %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4) <{map = #map28}> : (memref<512x64xi32>, index, index, index) -> memref +// linalg.generic 
{indexing_maps = [#map, #map, #map], iterator_types = ["reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { +// ^bb0(%in: i32, %in_0: i32, %out: i32): +// %6 = arith.muli %in, %in_0 : i32 +// %7 = arith.addi %out, %6 : i32 +// linalg.yield %7 : i32 +// } +// return %c0_i32 : i32 +// } +// } +// module @submap_test { +// memref.global @out : memref<511x64xi32> = uninitialized +// memref.global @filter : memref<5x4xi32> = uninitialized +// memref.global @im : memref<515x67xi32> = uninitialized +// func.func @main(%arg0: index, %arg1: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c5 = arith.constant 5 : index +// %c4 = arith.constant 4 : index +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im : memref<515x67xi32> +// %1 = memref.get_global @filter : memref<5x4xi32> +// %2 = memref.get_global @out : memref<511x64xi32> +// %3 = "polygeist.submap"(%0, %arg0, %arg1, %c4, %c5) <{map = #map23}> : (memref<515x67xi32>, index, index, index, index) -> memref +// %4 = "polygeist.submap"(%1, %c4, %c5) <{map = #map24}> : (memref<5x4xi32>, index, index) -> memref +// %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4, %c5) <{map = #map25}> : (memref<511x64xi32>, index, index, index, index) -> memref +// linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { +// ^bb0(%in: i32, %in_0: i32, %out: i32): +// %6 = arith.muli %in, %in_0 : i32 +// %7 = arith.addi %out, %6 : i32 +// linalg.yield %7 : i32 +// } +// return %c0_i32 : i32 +// } +// } module @harris_score_1 { memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> @@ -881,3 +881,144 @@ module { } } +module @harris_score_2d_kernel { + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + %2 = "polygeist.submap"(%alloca_4, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca_2, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca_3, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, 
#map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2, %3, %4 : memref, memref, memref) outs(%5, %6 : memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.addi %out_7, %17 : i32 + %19 = arith.muli %in, %in_6 : i32 + %20 = arith.addi %out, %19 : i32 + linalg.yield %20, %18 : i32, i32 + } + %7 = "polygeist.submap"(%alloca_3, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%alloca_2, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%alloca, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %10 = "polygeist.submap"(%alloca_0, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %11 = "polygeist.submap"(%alloca_1, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%7, %8 : memref, memref) outs(%9, %10, %11 : memref, memref, memref) { + ^bb0(%in: i32, %in_5: i32, %out: i32, %out_6: i32, %out_7: i32): + %17 = arith.muli %in, %in : i32 + %18 = arith.addi %out_7, %17 : i32 + %19 = arith.muli %in_5, %in_5 : i32 + %20 = arith.addi %out_6, %19 : i32 + %21 = arith.muli %in, %in_5 : i32 + %22 = arith.addi %out, %21 : i32 + linalg.yield %22, %20, %18 : i32, i32, i32 + } + %12 = memref.get_global @score : memref<512x512xi32> + %13 = "polygeist.submap"(%alloca_1, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %14 = "polygeist.submap"(%alloca_0, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %15 = "polygeist.submap"(%alloca, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %16 = "polygeist.submap"(%12, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%13, %14, %15 : memref, memref, memref) outs(%16 : memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.muli %in_6, %in_6 : i32 + %19 = arith.subi %17, %18 : i32 + %20 = arith.addi %in, %in_5 : i32 + %21 = arith.muli %20, %c4_i32 : i32 + %22 = arith.muli %21, %20 : i32 + %23 = arith.subi %19, %22 : i32 + linalg.yield %23 : i32 + } + return %c0_i32 : i32 + } +} + +module @harris_score_gradient_1d_kernel { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = 
arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + %5 = "polygeist.submap"(%2, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%3, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %7 = "polygeist.submap"(%4, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%5, %6, %7 : memref, memref, memref) outs(%8, %9 : memref, memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32, %out_2: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.addi %out, %23 : i32 + %25 = arith.muli %in, %in_1 : i32 + %26 = arith.addi %out_2, %25 : i32 + linalg.yield %24, %26 : i32, i32 + } + return %c0_i32 : i32 + } +} + +module @harris_score_gradient_2d_kernel { + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + %2 = "polygeist.submap"(%alloca_4, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca_2, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca_3, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2, %3, %4 : memref, memref, memref) outs(%5, %6 : memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.addi 
%out_7, %17 : i32 + %19 = arith.muli %in, %in_6 : i32 + %20 = arith.addi %out, %19 : i32 + linalg.yield %20, %18 : i32, i32 + } + return %c0_i32 : i32 + } +} \ No newline at end of file From e2b4b2dac9a7b6a15b025393fd0d2ba06d80cd6b Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 25 Oct 2024 00:42:40 -0700 Subject: [PATCH 33/77] Added harris score kernel with gradient kernel- just to be able to raise to linalg --- test/polygeist-opt/linalgraise.mlir | 156 +++++++++++++++++++++ test/polygeist-opt/raised_with_submap.mlir | 75 +++++++++- 2 files changed, 230 insertions(+), 1 deletion(-) diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index a05bd5338122..0d6b0dd61fc0 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -1089,3 +1089,159 @@ module @harris_score_local { return %c0_i32 : i32 } } + +module @harris_score_2d_kernel { + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + %3:2 = affine.for %arg2 = 0 to 3 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32) -> (i32, i32) { + %4:2 = affine.for %arg5 = 0 to 3 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (i32, i32) { + %5 = affine.load %alloca_4[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %6 = affine.load %0[%arg2, %arg5] : memref<3x3xi32> + %7 = arith.muli %5, %6 : i32 + %8 = arith.addi %arg7, %7 : i32 + %9 = affine.load %1[%arg2, %arg5] : memref<3x3xi32> + %10 = arith.muli %5, %9 : i32 + %11 = arith.addi %arg6, %10 : i32 + affine.yield %11, %8 : i32, i32 + } + affine.yield %4#0, %4#1 : i32, i32 + } + affine.store %3#1, %alloca_3[%arg0, %arg1] : memref<516x516xi32> + affine.store %3#0, %alloca_2[%arg0, %arg1] : memref<516x516xi32> + } + } + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %3:3 = affine.for %arg2 = 0 to 5 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32) -> (i32, i32, i32) { + %4:3 = affine.for %arg6 = 0 to 5 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (i32, i32, i32) { + %5 = affine.load %alloca_3[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %6 = affine.load %alloca_2[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %7 = arith.muli %5, %5 : i32 + %8 = arith.addi %arg9, %7 : i32 + %9 = arith.muli %6, %6 : i32 + %10 = arith.addi %arg8, %9 : i32 + %11 = arith.muli %5, %6 : i32 + %12 = arith.addi %arg7, %11 : i32 + affine.yield %12, %10, %8 : i32, i32, i32 + } + affine.yield %4#0, %4#1, %4#2 : i32, i32, i32 + } + affine.store %3#2, %alloca_1[%arg0, %arg1] : memref<512x512xi32> + affine.store %3#1, %alloca_0[%arg0, %arg1] : memref<512x512xi32> + affine.store %3#0, %alloca[%arg0, %arg1] : memref<512x512xi32> + } + } + %2 = memref.get_global @score : 
memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %3 = affine.load %alloca_1[%arg0, %arg1] : memref<512x512xi32> + %4 = affine.load %alloca_0[%arg0, %arg1] : memref<512x512xi32> + %5 = affine.load %alloca[%arg0, %arg1] : memref<512x512xi32> + %6 = arith.muli %3, %4 : i32 + %7 = arith.muli %5, %5 : i32 + %8 = arith.subi %6, %7 : i32 + %9 = arith.addi %3, %4 : i32 + %10 = arith.muli %9, %c4_i32 : i32 + %11 = arith.muli %10, %9 : i32 + %12 = arith.subi %8, %11 : i32 + affine.store %12, %2[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} + +module @harris_score_with_gradient_extra_kernel { + memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + %4:2 = affine.for %arg2 = 0 to 3 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32) -> (i32, i32) { + %5:2 = affine.for %arg5 = 0 to 3 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (i32, i32) { + %6 = affine.load %alloca_4[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %7 = affine.load %0[%arg2, %arg5] : memref<3x3xi32> + %8 = arith.muli %6, %7 : i32 + %9 = arith.addi %arg7, %8 : i32 + %10 = affine.load %1[%arg2, %arg5] : memref<3x3xi32> + %11 = arith.muli %6, %10 : i32 + %12 = arith.addi %arg6, %11 : i32 + affine.yield %12, %9 : i32, i32 + } + affine.yield %5#0, %5#1 : i32, i32 + } + affine.store %4#1, %alloca_3[%arg0, %arg1] : memref<516x516xi32> + affine.store %4#0, %alloca_2[%arg0, %arg1] : memref<516x516xi32> + } + } + %2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %4:3 = affine.for %arg2 = 0 to 5 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32) -> (i32, i32, i32) { + %5:3 = affine.for %arg6 = 0 to 5 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (i32, i32, i32) { + %6 = affine.load %alloca_3[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %7 = affine.load %alloca_2[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %8 = arith.muli %6, %6 : i32 + %9 = affine.load %2[%arg2, %arg6] : memref<5x5xi32> + %10 = arith.muli %8, %9 : i32 + %11 = arith.addi %arg9, %10 : i32 + %12 = arith.muli %7, %7 : i32 + %13 = arith.muli %12, %9 : i32 + %14 = arith.addi %arg8, %13 : i32 + %15 = arith.muli %6, %7 : i32 + %16 = arith.muli %15, %9 : i32 + %17 = arith.addi %arg7, %16 : i32 + affine.yield %17, %14, %11 : i32, i32, i32 + } + affine.yield %5#0, %5#1, %5#2 : i32, i32, i32 + } + affine.store %4#2, %alloca_1[%arg0, %arg1] : memref<512x512xi32> + affine.store %4#1, %alloca_0[%arg0, %arg1] : memref<512x512xi32> + affine.store %4#0, %alloca[%arg0, %arg1] : memref<512x512xi32> + } + } + %3 = memref.get_global 
@score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %4 = affine.load %alloca_1[%arg0, %arg1] : memref<512x512xi32> + %5 = affine.load %alloca_0[%arg0, %arg1] : memref<512x512xi32> + %6 = affine.load %alloca[%arg0, %arg1] : memref<512x512xi32> + %7 = arith.muli %4, %5 : i32 + %8 = arith.muli %6, %6 : i32 + %9 = arith.subi %7, %8 : i32 + %10 = arith.addi %4, %5 : i32 + %11 = arith.muli %10, %c4_i32 : i32 + %12 = arith.muli %11, %10 : i32 + %13 = arith.subi %9, %12 : i32 + affine.store %13, %3[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} \ No newline at end of file diff --git a/test/polygeist-opt/raised_with_submap.mlir b/test/polygeist-opt/raised_with_submap.mlir index 9e70e07e9bcc..f126b738d0f1 100644 --- a/test/polygeist-opt/raised_with_submap.mlir +++ b/test/polygeist-opt/raised_with_submap.mlir @@ -1021,4 +1021,77 @@ module @harris_score_gradient_2d_kernel { } return %c0_i32 : i32 } -} \ No newline at end of file +} + +module @harris_score_with_gradient_extra_kernel { + memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + %2 = "polygeist.submap"(%alloca_4, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca_2, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca_3, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2, %3, %4 : memref, memref, memref) outs(%5, %6 : memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %19 = arith.muli %in, %in_5 : i32 + %20 = arith.addi %out_7, %19 : i32 + %21 = arith.muli %in, %in_6 : i32 + %22 = arith.addi %out, %21 : i32 + linalg.yield %22, %20 : i32, i32 + } + %7 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + %8 = "polygeist.submap"(%alloca_3, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%alloca_2, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, 
index, index, index) -> memref + %10 = "polygeist.submap"(%7, %c5, %c5, %c512, %c512) <{map = #map20}> : (memref<5x5xi32>, index, index, index, index) -> memref + %11 = "polygeist.submap"(%alloca, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %12 = "polygeist.submap"(%alloca_0, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %13 = "polygeist.submap"(%alloca_1, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%8, %9, %10 : memref, memref, memref) outs(%11, %12, %13 : memref, memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): + %19 = arith.muli %in, %in : i32 + %20 = arith.muli %19, %in_6 : i32 + %21 = arith.addi %out_8, %20 : i32 + %22 = arith.muli %in_5, %in_5 : i32 + %23 = arith.muli %22, %in_6 : i32 + %24 = arith.addi %out_7, %23 : i32 + %25 = arith.muli %in, %in_5 : i32 + %26 = arith.muli %25, %in_6 : i32 + %27 = arith.addi %out, %26 : i32 + linalg.yield %27, %24, %21 : i32, i32, i32 + } + %14 = memref.get_global @score : memref<512x512xi32> + %15 = "polygeist.submap"(%alloca_1, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %16 = "polygeist.submap"(%alloca_0, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %17 = "polygeist.submap"(%alloca, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %18 = "polygeist.submap"(%14, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%15, %16, %17 : memref, memref, memref) outs(%18 : memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %19 = arith.muli %in, %in_5 : i32 + %20 = arith.muli %in_6, %in_6 : i32 + %21 = arith.subi %19, %20 : i32 + %22 = arith.addi %in, %in_5 : i32 + %23 = arith.muli %22, %c4_i32 : i32 + %24 = arith.muli %23, %22 : i32 + %25 = arith.subi %21, %24 : i32 + linalg.yield %25 : i32 + } + return %c0_i32 : i32 + } +} From f2ab09e0018a0c42a5d1c7fffd93507de1feafea Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Mon, 13 Jan 2025 11:09:45 -0800 Subject: [PATCH 34/77] Initial working implementation of debufferize flow for linalg with examples --- debufferize.mlir | 39 ++++ include/polygeist/Passes/Passes.h | 9 + include/polygeist/Passes/Passes.td | 11 + lib/polygeist/Ops.cpp | 2 +- lib/polygeist/Passes/CMakeLists.txt | 1 + lib/polygeist/Passes/LinalgDebufferize.cpp | 224 +++++++++++++++++++++ 6 files changed, 285 insertions(+), 1 deletion(-) create mode 100644 debufferize.mlir create mode 100644 lib/polygeist/Passes/LinalgDebufferize.cpp diff --git a/debufferize.mlir b/debufferize.mlir new file mode 100644 index 000000000000..3e310644f4bc --- /dev/null +++ b/debufferize.mlir @@ -0,0 +1,39 @@ +//polygeist-opt --linalg-debufferize debufferize.mlir + +#map16 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map17 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)> +#map18 = affine_map<(d0, d1, d2, d3) -> (d1, d0)> +#map19 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + + module @in_place_add{ + func.func @in_place_add(%value: f32) { + %c0 = arith.constant 0 : index + %buffer = memref.alloca() : memref<128xf32> + linalg.generic { + 
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%buffer : memref<128xf32>)
+      outs(%buffer : memref<128xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %sum = arith.addf %in, %value : f32
+      linalg.yield %sum : f32
+    }
+    return
+  }
+ }
+
+module @conv_2 {
+  func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} {
+    %c0_i32 = arith.constant 0 : i32
+    %0 = memref.alloca() : memref<515x67xi32>
+    %1 = memref.alloca() : memref<4x4xi32>
+    %2 = memref.alloca() : memref<512x64xi32>
+    linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) {
+    ^bb0(%in: i32, %in_0: i32, %out: i32):
+      %3 = arith.muli %in, %in_0 : i32
+      %4 = arith.addi %out, %3 : i32
+      linalg.yield %4 : i32
+    }
+    return %c0_i32 : i32
+  }
+}
\ No newline at end of file
diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h
index ad7e2fc14fc2..7a95484a2fdb 100644
--- a/include/polygeist/Passes/Passes.h
+++ b/include/polygeist/Passes/Passes.h
@@ -32,6 +32,7 @@ std::unique_ptr createOpenMPOptPass();
 std::unique_ptr createCanonicalizeForPass();
 std::unique_ptr createRaiseSCFToAffinePass();
 std::unique_ptr createRaiseAffineToLinalgPass();
+std::unique_ptr createLinalgDebufferizePass();
 std::unique_ptr createRemoveIterArgsPass();
 std::unique_ptr createCPUifyPass(StringRef method = "");
 std::unique_ptr createBarrierRemovalContinuation();
@@ -129,6 +130,14 @@ namespace linalg {
 class LinalgDialect;
 }
 
+namespace bufferization {
+class BufferizationDialect;
+}
+
+namespace tensor {
+class TensorDialect;
+}
+
 namespace LLVM {
 class LLVMDialect;
 }
diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td
index 7d5f2315f4ce..5b8251c616b8 100644
--- a/include/polygeist/Passes/Passes.td
+++ b/include/polygeist/Passes/Passes.td
@@ -160,6 +160,17 @@ def RemoveIterArgs : Pass<"remove-iter-args"> {
   ];
 }
 
+def LinalgDebufferize : Pass<"linalg-debufferize"> {
+  let summary = "Debufferize linalg generic ops operating on memrefs";
+  let constructor = "mlir::polygeist::createLinalgDebufferizePass()";
+  let dependentDialects = [
+    "affine::AffineDialect",
+    "linalg::LinalgDialect",
+    "bufferization::BufferizationDialect",
+    "polygeist::PolygeistDialect",
+  ];
+}
+
 def AffineRaiseToLinalg : Pass<"raise-affine-to-linalg"> {
   let summary = "Raise affine to linalg";
   let constructor = "mlir::polygeist::createRaiseAffineToLinalgPass()";
diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp
index c694c8520ef4..4010e58330cb 100644
--- a/lib/polygeist/Ops.cpp
+++ b/lib/polygeist/Ops.cpp
@@ -5843,7 +5843,7 @@ struct SubMapOpCanonicalize : public OpRewritePattern {
     //If inverse permutation exists, then we can canonicalize the linalg of submap to linalg
     //TODO: Fails for:
     // 1. Maps with symbols
-    // 2. Maps with non
+    // 2. 
Maps which are not resolvable 1 to 1 with memref for all dims if(inversePermutation(concatAffineMaps(maps))) { StringAttr empty = StringAttr::get(genericOp.getContext()); auto newGenericOp = rewriter.create(genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, listOfNewMaps, genericOp.getIteratorTypesArray(), diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index f98813fb15b5..ae74300af7a1 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms RaiseToAffine.cpp RemoveIterArgs.cpp RaiseToLinalg.cpp + LinalgDebufferize.cpp ParallelLower.cpp TrivialUse.cpp ConvertPolygeistToLLVM.cpp diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp new file mode 100644 index 000000000000..c5e04a67af5b --- /dev/null +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -0,0 +1,224 @@ +#include "PassDetails.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "polygeist/Passes/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "linalg-debufferize" + +using namespace mlir; +using namespace mlir::arith; +using namespace polygeist; +using namespace affine; +using namespace linalg; +using namespace tensor; +using namespace bufferization; + + + +//module @harris_score_with_gradient_extra_kernel { +// memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> +// memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> +// memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> +// memref.global @score : memref<512x512xi32> = uninitialized +// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c4_i32 = arith.constant 4 : i32 +// %c0_i32 = arith.constant 0 : i32 +// %alloca = memref.alloca() : memref<512x512xi32> +// %alloca_0 = memref.alloca() : memref<512x512xi32> +// %alloca_1 = memref.alloca() : memref<512x512xi32> +// %alloca_2 = memref.alloca() : memref<516x516xi32> +// %alloca_3 = memref.alloca() : memref<516x516xi32> +// %alloca_4 = memref.alloca() : memref<518x518xi32> +// %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> +// %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> +// %2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> +// // 2nd variant +// // %0 = memref.alloca() : memref<3x3xi32> +// // %1 = memref.alloca() : memref<3x3xi32> +// // %2 = memref.alloca() : memref<5x5xi32> +// linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { +// ^bb0(%in: i32, %in_5: i32, %in_6: i32, 
%out: i32, %out_7: i32): +// %4 = arith.muli %in, %in_5 : i32 +// %5 = arith.addi %out_7, %4 : i32 +// %6 = arith.muli %in, %in_6 : i32 +// %7 = arith.addi %out, %6 : i32 +// linalg.yield %7, %5 : i32, i32 +// } +// linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { +// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): +// %4 = arith.muli %in, %in : i32 +// %5 = arith.muli %4, %in_6 : i32 +// %6 = arith.addi %out_8, %5 : i32 +// %7 = arith.muli %in_5, %in_5 : i32 +// %8 = arith.muli %7, %in_6 : i32 +// %9 = arith.addi %out_7, %8 : i32 +// %10 = arith.muli %in, %in_5 : i32 +// %11 = arith.muli %10, %in_6 : i32 +// %12 = arith.addi %out, %11 : i32 +// linalg.yield %12, %9, %6 : i32, i32, i32 +// } +// %3 = memref.get_global @score : memref<512x512xi32> +// linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%3 : memref<512x512xi32>) { +// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): +// %4 = arith.muli %in, %in_5 : i32 +// %5 = arith.muli %in_6, %in_6 : i32 +// %6 = arith.subi %4, %5 : i32 +// %7 = arith.addi %in, %in_5 : i32 +// %8 = arith.muli %7, %c4_i32 : i32 +// %9 = arith.muli %8, %7 : i32 +// %10 = arith.subi %6, %9 : i32 +// linalg.yield %10 : i32 +// } +// return %c0_i32 : i32 +// } +// } +struct LinalgDebufferization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::FuncOp funcOp, + PatternRewriter &rewriter) const final { + + auto module = funcOp->getParentOfType(); + + SmallVector opsToDelete; + llvm::SmallPtrSet opsToDeleteSet; + //Tracks both old linalg.generics and linalg.generics with repeated values in ins and outs + llvm::SmallPtrSet processedGenericOps; + + LogicalResult passResult = success(); + funcOp.walk([&](mlir::memref::AllocaOp allocaOp) -> WalkResult { + auto module = allocaOp->getParentOfType(); + rewriter.setInsertionPointAfter(allocaOp); + auto tensorType = RankedTensorType::get(allocaOp.getType().getShape(), allocaOp.getType().getElementType()); + + //Check to see if only linalg.generic are users of the allocaOp for now. + //TODO: Extend this + if(!llvm::all_of(allocaOp->getUsers(),[](Operation *op) { + return isa(op); + })){ + passResult = failure(); + return WalkResult::interrupt(); + } + + //auto emptyTensor = rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), allocaOp.getType().getElementType()); + auto toTensorOp = rewriter.create( + allocaOp.getLoc(), + tensorType, + allocaOp); + Value currentTensor = toTensorOp; + + //Check if allocaOp is an output in current genericOp + for (auto user : allocaOp->getUsers()) { + if (auto genericOp = dyn_cast(user)) { + + //auto genericOp = cast(user); + if(processedGenericOps.count(genericOp) > 0) + continue; + rewriter.setInsertionPointAfter(genericOp); + + SmallVector newInputs; + SmallVector newOutputs; + SmallVector resultTypes; + //Create a new linalg.generic in Destination Style Passing format + + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + for(auto input : genericOp.getInputs()){ + newInputs.push_back(input == allocaOp ? 
currentTensor : input); + } + + //ArrayRef resultTypes; + int newCurrentTensorIndex = -1; + int index = 0; + for(auto output : genericOp.getOutputs()){ + newOutputs.push_back(output == allocaOp ? currentTensor : output); + resultTypes.push_back(currentTensor.getType()); + if(output == allocaOp) { + newCurrentTensorIndex = index; + } + index++; + } + + StringAttr empty = StringAttr::get(genericOp.getContext()); + ArrayRef resultTypesRef(resultTypes); + auto newGenericOp = rewriter.create(genericOp.getLoc(), resultTypesRef, newInputs, newOutputs, + genericOp.getIndexingMaps(), genericOp.getIteratorTypes(), empty, empty); + + Region &opRegion = newGenericOp.getRegion(); + rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), newGenericOp.getRegion().end()); + + //Replace all uses of original generic op with the new one + int idxOldGeneric=0; + int idxNewGeneric=0; + for (unsigned i = 0; i < genericOp->getNumResults(); ++i) { + if(i == newCurrentTensorIndex) { + idxNewGeneric++; + } + genericOp->getResult(i).replaceAllUsesWith(newGenericOp->getResult(i)); + idxOldGeneric++; + idxNewGeneric++; + } + + //Delete the original genericOp + opsToDelete.push_back(genericOp.getOperation()); + if(newCurrentTensorIndex != -1) + currentTensor = newGenericOp.getResult(newCurrentTensorIndex); + + processedGenericOps.insert(genericOp.getOperation()); + } + } + + auto toMemrefOp = rewriter.create( + allocaOp.getLoc(), + allocaOp.getType(), + currentTensor); + rewriter.create(allocaOp.getLoc(), toMemrefOp, allocaOp); + //opsToDelete.push_back(allocaOp.getOperation()); + return WalkResult::advance(); + }); + for (Operation *op : opsToDelete) { + op->erase(); + } + opsToDelete.clear(); + + return passResult; + } +}; + +namespace { +struct LinalgDebufferize + : public LinalgDebufferizeBase { + void runOnOperation() override; +}; +} // namespace + +void LinalgDebufferize::runOnOperation() { + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); +} + +namespace mlir { +namespace polygeist { +std::unique_ptr createLinalgDebufferizePass() { + return std::make_unique(); +} +} // namespace polygeist +} // namespace mlir From 234238166ceeecc928825696aa46c2a159d35253 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Mon, 13 Jan 2025 13:25:03 -0800 Subject: [PATCH 35/77] Added more complex case to show debufferization ; Fixed bugs in debufferization; Current debufferization works on all memref.alloca() --- debufferize.mlir | 116 +++++++++++++++------ lib/polygeist/Passes/LinalgDebufferize.cpp | 102 +++++++----------- 2 files changed, 123 insertions(+), 95 deletions(-) diff --git a/debufferize.mlir b/debufferize.mlir index 3e310644f4bc..96f278038f9f 100644 --- a/debufferize.mlir +++ b/debufferize.mlir @@ -4,36 +4,94 @@ #map17 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)> #map18 = affine_map<(d0, d1, d2, d3) -> (d1, d0)> #map19 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +#map22 = affine_map<(d0, d1) -> (d1, d0)> - module @in_place_add{ - func.func @in_place_add(%value: f32) { - %c0 = arith.constant 0 : index - %buffer = memref.alloca() : memref<128xf32> - linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"] - } ins(%buffer : memref<128xf32>) - outs(%buffer : memref<128xf32>) { - ^bb0(%in: f32, %out: f32): - %sum = arith.addf %in, %value : f32 - linalg.yield %sum : f32 - } - return - } - } 
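+// Rough, hand-written sketch of what linalg-debufferize is expected to produce for
+// the in_place_add kernel below (illustrative only; the exact cast ops and value
+// names are assumptions, this is not actual pass output):
+//   %buffer = memref.alloca() : memref<128xf32>
+//   %t = bufferization.to_tensor %buffer : memref<128xf32>
+//   %r = linalg.generic {
+//          indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+//          iterator_types = ["parallel"]
+//        } ins(%t : tensor<128xf32>) outs(%t : tensor<128xf32>) {
+//        ^bb0(%in: f32, %out: f32):
+//          %sum = arith.addf %in, %value : f32
+//          linalg.yield %sum : f32
+//        } -> tensor<128xf32>
+//   %m = bufferization.to_memref %r : memref<128xf32>
+//   memref.copy %m, %buffer : memref<128xf32> to memref<128xf32>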
+module @in_place_add{ + func.func @in_place_add(%value: f32) { + %c0 = arith.constant 0 : index + %buffer = memref.alloca() : memref<128xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } +} module @conv_2 { - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - %0 = memref.alloca() : memref<515x67xi32> - %1 = memref.alloca() : memref<4x4xi32> - %2 = memref.alloca() : memref<512x64xi32> - linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { - ^bb0(%in: i32, %in_0: i32, %out: i32): - %3 = arith.muli %in, %in_0 : i32 - %4 = arith.addi %out, %3 : i32 - linalg.yield %4 : i32 + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.alloca() : memref<515x67xi32> + %1 = memref.alloca() : memref<4x4xi32> + %2 = memref.alloca() : memref<512x64xi32> + linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %3 = arith.muli %in, %in_0 : i32 + %4 = arith.addi %out, %3 : i32 + linalg.yield %4 : i32 + } + return %c0_i32 : i32 + } +} + +module @harris_score_with_gradient_extra_kernel { + memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %score = memref.alloca() : memref<512x512xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + %2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + // 2nd variant + // %0 = memref.alloca() : memref<3x3xi32> + // %1 = memref.alloca() : memref<3x3xi32> + // %2 = memref.alloca() : memref<5x5xi32> + linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.addi %out_7, %4 : i32 + %6 = arith.muli %in, %in_6 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7, %5 : i32, i32 + } + linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, 
memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): + %4 = arith.muli %in, %in : i32 + %5 = arith.muli %4, %in_6 : i32 + %6 = arith.addi %out_8, %5 : i32 + %7 = arith.muli %in_5, %in_5 : i32 + %8 = arith.muli %7, %in_6 : i32 + %9 = arith.addi %out_7, %8 : i32 + %10 = arith.muli %in, %in_5 : i32 + %11 = arith.muli %10, %in_6 : i32 + %12 = arith.addi %out, %11 : i32 + linalg.yield %12, %9, %6 : i32, i32, i32 + } + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.muli %in_6, %in_6 : i32 + %6 = arith.subi %4, %5 : i32 + %7 = arith.addi %in, %in_5 : i32 + %8 = arith.muli %7, %c4_i32 : i32 + %9 = arith.muli %8, %7 : i32 + %10 = arith.subi %6, %9 : i32 + linalg.yield %10 : i32 + } + return %c0_i32 : i32 } - return %c0_i32 : i32 - } -} \ No newline at end of file + } \ No newline at end of file diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index c5e04a67af5b..c7c55a465cdc 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -29,65 +29,32 @@ using namespace linalg; using namespace tensor; using namespace bufferization; +std::vector getSortedUsers(Operation *op) { + if(!op) return {}; + // Find the parent function + auto funcOp = op->getParentOfType(); + if (!funcOp) return {}; + + //Map to store order of operations + llvm::DenseMap opOrder; + size_t order = 0; + + funcOp.walk([&](Operation *curOp) { + opOrder[curOp] = order++; + }); + + std::vector sortedUsers(op->getUsers().begin(), op->getUsers().end()); + + std::sort(sortedUsers.begin(), sortedUsers.end(), + [&](Operation *a, Operation *b) { + return opOrder[a] < opOrder[b]; + } + ); + + return sortedUsers; +} -//module @harris_score_with_gradient_extra_kernel { -// memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> -// memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> -// memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> -// memref.global @score : memref<512x512xi32> = uninitialized -// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c4_i32 = arith.constant 4 : i32 -// %c0_i32 = arith.constant 0 : i32 -// %alloca = memref.alloca() : memref<512x512xi32> -// %alloca_0 = memref.alloca() : memref<512x512xi32> -// %alloca_1 = memref.alloca() : memref<512x512xi32> -// %alloca_2 = memref.alloca() : memref<516x516xi32> -// %alloca_3 = memref.alloca() : memref<516x516xi32> -// %alloca_4 = memref.alloca() : memref<518x518xi32> -// %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> -// %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> -// %2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> -// // 2nd variant -// // %0 = memref.alloca() : memref<3x3xi32> -// // %1 = memref.alloca() : memref<3x3xi32> -// // %2 = memref.alloca() : memref<5x5xi32> -// linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, 
memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): -// %4 = arith.muli %in, %in_5 : i32 -// %5 = arith.addi %out_7, %4 : i32 -// %6 = arith.muli %in, %in_6 : i32 -// %7 = arith.addi %out, %6 : i32 -// linalg.yield %7, %5 : i32, i32 -// } -// linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): -// %4 = arith.muli %in, %in : i32 -// %5 = arith.muli %4, %in_6 : i32 -// %6 = arith.addi %out_8, %5 : i32 -// %7 = arith.muli %in_5, %in_5 : i32 -// %8 = arith.muli %7, %in_6 : i32 -// %9 = arith.addi %out_7, %8 : i32 -// %10 = arith.muli %in, %in_5 : i32 -// %11 = arith.muli %10, %in_6 : i32 -// %12 = arith.addi %out, %11 : i32 -// linalg.yield %12, %9, %6 : i32, i32, i32 -// } -// %3 = memref.get_global @score : memref<512x512xi32> -// linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%3 : memref<512x512xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): -// %4 = arith.muli %in, %in_5 : i32 -// %5 = arith.muli %in_6, %in_6 : i32 -// %6 = arith.subi %4, %5 : i32 -// %7 = arith.addi %in, %in_5 : i32 -// %8 = arith.muli %7, %c4_i32 : i32 -// %9 = arith.muli %8, %7 : i32 -// %10 = arith.subi %6, %9 : i32 -// linalg.yield %10 : i32 -// } -// return %c0_i32 : i32 -// } -// } struct LinalgDebufferization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -123,8 +90,10 @@ struct LinalgDebufferization : public OpRewritePattern { allocaOp); Value currentTensor = toTensorOp; + auto sortedUsers = getSortedUsers(allocaOp); + //Check if allocaOp is an output in current genericOp - for (auto user : allocaOp->getUsers()) { + for (auto user : sortedUsers) { if (auto genericOp = dyn_cast(user)) { //auto genericOp = cast(user); @@ -147,7 +116,7 @@ struct LinalgDebufferization : public OpRewritePattern { int index = 0; for(auto output : genericOp.getOutputs()){ newOutputs.push_back(output == allocaOp ? currentTensor : output); - resultTypes.push_back(currentTensor.getType()); + resultTypes.push_back(output == allocaOp ? 
currentTensor.getType() : output.getType()); if(output == allocaOp) { newCurrentTensorIndex = index; } @@ -163,15 +132,16 @@ struct LinalgDebufferization : public OpRewritePattern { rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), newGenericOp.getRegion().end()); //Replace all uses of original generic op with the new one - int idxOldGeneric=0; - int idxNewGeneric=0; + //int idxOldGeneric=0; + //int idxNewGeneric=0; for (unsigned i = 0; i < genericOp->getNumResults(); ++i) { - if(i == newCurrentTensorIndex) { - idxNewGeneric++; - } + //if(i == newCurrentTensorIndex) { + // idxNewGeneric++; + //} + //genericOp->getResult(idxOldGeneric).replaceAllUsesWith(newGenericOp->getResult(idxNewGeneric)); + //idxOldGeneric++; + //idxNewGeneric++; genericOp->getResult(i).replaceAllUsesWith(newGenericOp->getResult(i)); - idxOldGeneric++; - idxNewGeneric++; } //Delete the original genericOp From fde88fe53360f4919174fb10c6132725606b1219 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Mon, 13 Jan 2025 13:29:39 -0800 Subject: [PATCH 36/77] Fixed clang format --- lib/polygeist/Ops.cpp | 426 ++++++++++++--------- lib/polygeist/Passes/LinalgDebufferize.cpp | 214 ++++++----- lib/polygeist/Passes/RaiseToLinalg.cpp | 238 +++++++----- lib/polygeist/Passes/RemoveIterArgs.cpp | 191 ++++----- tools/polygeist-opt/polygeist-opt.cpp | 2 +- 5 files changed, 583 insertions(+), 488 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 4010e58330cb..c65e5a9d3afb 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -23,10 +23,10 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/AffineMap.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" @@ -5738,10 +5738,12 @@ struct SubMapOpCanonicalize : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SubmapOp op, PatternRewriter &rewriter) const override { - /// if submap %x is identity map and has the same size as the static size of %x + /// if submap %x is identity map and has the same size as the static size of + /// %x ///. replace submap with memref.cast of memref<4x5xf32> to memref /// %x = ... : memref<4x5xf32> - // %y = polygeist.submap %x(#identity_map, %constant_4, %constant_5) : memref<4x5xf32> -> memref + // %y = polygeist.submap %x(#identity_map, %constant_4, %constant_5) : + // memref<4x5xf32> -> memref // //. 
becomes // @@ -5750,19 +5752,23 @@ struct SubMapOpCanonicalize : public OpRewritePattern { // auto source_memref = op.getMemref(); bool isIdentity = op.getMap().isIdentity(); - bool isInputSameDim = llvm::all_of(llvm::zip_equal(op.getSizes(), cast(source_memref.getType()).getShape()), [&](auto pair) { - if (std::get<1>(pair) == -1) - return false; - APInt matched; - if (matchPattern(std::get<0>(pair), m_ConstantInt(&matched))) { - return std::get<1>(pair) == matched; - } - return false; - }); + bool isInputSameDim = llvm::all_of( + llvm::zip_equal(op.getSizes(), + cast(source_memref.getType()).getShape()), + [&](auto pair) { + if (std::get<1>(pair) == -1) + return false; + APInt matched; + if (matchPattern(std::get<0>(pair), m_ConstantInt(&matched))) { + return std::get<1>(pair) == matched; + } + return false; + }); if (isIdentity && isInputSameDim) { - rewriter.replaceOpWithNewOp(op, op.getType(), op.getMemref()); + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getMemref()); return success(); - } + } if (auto sapOp = source_memref.getDefiningOp()) { auto load_map = op.getMap(); auto submap_map = sapOp.getMap(); @@ -5771,141 +5777,147 @@ struct SubMapOpCanonicalize : public OpRewritePattern { operands.append(op.getSymbols().begin(), op.getSymbols().end()); operands.append(op.getSymbols().begin(), op.getSymbols().end()); operands.append(op.getSizes().begin(), op.getSizes().end()); - rewriter.replaceOpWithNewOp(op, op.getType(), sapOp.getMemref(), operands, new_map); + rewriter.replaceOpWithNewOp( + op, op.getType(), sapOp.getMemref(), operands, new_map); return success(); } return failure(); } }; - - struct LinalgOfSubmap : public OpRewritePattern { +struct LinalgOfSubmap : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { - //Check body content - auto module = genericOp->getParentOfType(); - Region &genericBody = genericOp.getRegion(); - Block &entryBlock = genericBody.front(); - ValueRange blockArgs = entryBlock.getArguments(); - auto inputs = genericOp.getInputs(); - auto outputs = genericOp.getOutputs(); - SmallVector listOfAllocas; - SmallVector listOfNewMaps; - SmallVector listOfNewInputs, listOfNewOutputs; - //auto mapAttrsArr = genericOp.getIndexingMaps(); - //for(auto mapAttr: mapAttrsArr) { - // AffineMap map = mapAttr.cast().getValue(); - // if(map == convMap[0] && !mapped[0]) { - // } - //} - for(auto inp: inputs) { - if(auto blkArg = dyn_cast(inp)) { - listOfNewInputs.push_back(inp); - } - else if(auto subMap = dyn_cast(inp.getDefiningOp())) { - auto source_memref = subMap.getMemref(); - //if (auto blockArg = dyn_cast_or_null(op)) { - //if(auto source_alloca = dyn_cast(source_memref.getDefiningOp())) - //{ - auto map = subMap.getMap(); - listOfNewMaps.push_back(map); - listOfNewInputs.push_back(source_memref); - //} - //else { - // assert(false && "Only expect allocaOp as source for submap canonicalization right now"); - // return failure(); - //} - } - else { - listOfNewInputs.push_back(inp); - } - } - - for(auto out: outputs) { - if(auto blkArg = dyn_cast(out)) { - listOfNewOutputs.push_back(out); - } - else if(auto subMap = dyn_cast(out.getDefiningOp())) { - auto source_memref = subMap.getMemref(); - auto map = subMap.getMap(); - listOfNewMaps.push_back(map); - listOfNewOutputs.push_back(source_memref); - } - else { - listOfNewOutputs.push_back(out); - } + // Check body content + auto module = genericOp->getParentOfType(); + Region &genericBody = 
genericOp.getRegion(); + Block &entryBlock = genericBody.front(); + ValueRange blockArgs = entryBlock.getArguments(); + auto inputs = genericOp.getInputs(); + auto outputs = genericOp.getOutputs(); + SmallVector listOfAllocas; + SmallVector listOfNewMaps; + SmallVector listOfNewInputs, listOfNewOutputs; + // auto mapAttrsArr = genericOp.getIndexingMaps(); + // for(auto mapAttr: mapAttrsArr) { + // AffineMap map = mapAttr.cast().getValue(); + // if(map == convMap[0] && !mapped[0]) { + // } + // } + for (auto inp : inputs) { + if (auto blkArg = dyn_cast(inp)) { + listOfNewInputs.push_back(inp); + } else if (auto subMap = + dyn_cast(inp.getDefiningOp())) { + auto source_memref = subMap.getMemref(); + // if (auto blockArg = dyn_cast_or_null(op)) { + // if(auto source_alloca = + // dyn_cast(source_memref.getDefiningOp())) + //{ + auto map = subMap.getMap(); + listOfNewMaps.push_back(map); + listOfNewInputs.push_back(source_memref); + //} + // else { + // assert(false && "Only expect allocaOp as source for submap + // canonicalization right now"); return failure(); + //} + } else { + listOfNewInputs.push_back(inp); } - ArrayRef maps(listOfNewMaps); - //No submap ops detected - if(maps.size() == 0) - return failure(); - //If inverse permutation exists, then we can canonicalize the linalg of submap to linalg - //TODO: Fails for: - // 1. Maps with symbols - // 2. Maps which are not resolvable 1 to 1 with memref for all dims - if(inversePermutation(concatAffineMaps(maps))) { - StringAttr empty = StringAttr::get(genericOp.getContext()); - auto newGenericOp = rewriter.create(genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, listOfNewMaps, genericOp.getIteratorTypesArray(), - empty, empty); - rewriter.inlineRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), newGenericOp.getRegion().end()); - - //auto &block = newGenericOp.getRegion().front(); - //block.addArguments(newGenericOp.getOperandTypes(), SmallVector(newGenericOp.getNumOperands(), genericOp.getLoc())); - - rewriter.replaceOp(genericOp, newGenericOp.getResults()); - return success(); + } + + for (auto out : outputs) { + if (auto blkArg = dyn_cast(out)) { + listOfNewOutputs.push_back(out); + } else if (auto subMap = + dyn_cast(out.getDefiningOp())) { + auto source_memref = subMap.getMemref(); + auto map = subMap.getMap(); + listOfNewMaps.push_back(map); + listOfNewOutputs.push_back(source_memref); + } else { + listOfNewOutputs.push_back(out); } - //for(iterate over inputs) - //{ - // gather maps - // gather submaps - // Gather affine maps from submaps - // Check over 2 iterations if all the indexes can be solved. - // Use the same logic as linalg.generic to do this. - // if success in getting vars - // replace affine map from submap to linalg.generic - // replace input memref as direct input to linalg.generic - //} - //assert(false && "inversePermutation doesn't exists for the given linalg generic"); + } + ArrayRef maps(listOfNewMaps); + // No submap ops detected + if (maps.size() == 0) return failure(); + // If inverse permutation exists, then we can canonicalize the linalg of + // submap to linalg + // TODO: Fails for: + // 1. Maps with symbols + // 2. 
Maps which are not resolvable 1 to 1 with memref for all dims + if (inversePermutation(concatAffineMaps(maps))) { + StringAttr empty = StringAttr::get(genericOp.getContext()); + auto newGenericOp = rewriter.create( + genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, + listOfNewMaps, genericOp.getIteratorTypesArray(), empty, empty); + rewriter.inlineRegionBefore(genericOp.getRegion(), + newGenericOp.getRegion(), + newGenericOp.getRegion().end()); + + // auto &block = newGenericOp.getRegion().front(); + // block.addArguments(newGenericOp.getOperandTypes(), + // SmallVector(newGenericOp.getNumOperands(), + // genericOp.getLoc())); + + rewriter.replaceOp(genericOp, newGenericOp.getResults()); + return success(); } - }; + // for(iterate over inputs) + //{ + // gather maps + // gather submaps + // Gather affine maps from submaps + // Check over 2 iterations if all the indexes can be solved. + // Use the same logic as linalg.generic to do this. + // if success in getting vars + // replace affine map from submap to linalg.generic + // replace input memref as direct input to linalg.generic + // } + // assert(false && "inversePermutation doesn't exists for the given linalg + // generic"); + return failure(); + } +}; // struct LinalgOfSubmap : public OpRewritePattern { // using OpRewritePattern::OpRewritePattern; // LogicalResult matchAndRewrite(linalg::GenericOp gen, // PatternRewriter &rewriter) const override { -// // Canonicalization 1 linalg.generic of map of submap. -> linalg.generic of map of submap +// // Canonicalization 1 linalg.generic of map of submap. -> linalg.generic +// of map of submap // //. iff the submap's affine map != identity // //. replace inner affine map with composition - -// // Canonicalizeation 3: submap which only sets bounds, of an input memref with the same bounds -> noop / cast - +// // Canonicalizeation 3: submap which only sets bounds, of an input memref +// with the same bounds -> noop / cast // // Canonicalization 1.5 (mix of 1/2) // //. linalg_map = identity a[i,j,x,y] -> u[i+x][j+y] -// //. linalg_map = [i,j,x,y]->(i+x,j+y) a[i,j] -> u[i,j]. # but still keeping the upper loop limit +// //. linalg_map = [i,j,x,y]->(i+x,j+y) a[i,j] -> u[i,j]. # but still +// keeping the upper loop limit // //. 1 - // // a[i] -> x[] // // a[1] -> x[] // // a[2] -> x[] - // // a[i,j] = x[map(i,j)]. ; the subbmap op -// //a[i+x][j+y] : submap defines iteration var 0 goes from 0 ... A0. and var 1 goes from 0 ... A1 +// //a[i+x][j+y] : submap defines iteration var 0 goes from 0 ... A0. and +// var 1 goes from 0 ... A1 // //b[x][y] // //c[i+x][j+y] // // here we have 4 iteration variables that linalg is doing i, j, x, y // // for (i : ...) // //. for (j : ...) // //. for (x : ...) -// //. for (y : ...) +// //. for (y : ...) // // c[i+x][j+y] += a[i+x][j+y] * b[x][y] // // a[i+x][j+y] @@ -5913,35 +5925,36 @@ struct SubMapOpCanonicalize : public OpRewritePattern { // // for (i : ...) // //. for (j : ...) // //. for (x : ...) -// //. for (y : ...) +// //. for (y : ...) // // c[i+x][j+y] += a[i+x][j+y] - -// //x[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// //x[map(i+x,j+y)] pass in the outermost one with correspondidng composed +// maps // //b[x][y] // //c[i+x][j+y] - -// // requirement here, is that all linalg.generic loop bounds must be solvable after replacement +// // requirement here, is that all linalg.generic loop bounds must be +// solvable after replacement // // for example, this would not be permissible // // a[i] -> x[]. 
; a = submap memref -> memref<100xf32> -// // out[] +// // out[] -// // This cannot be replaced since now the linalg generic iteration variable i cannot be solved for +// // This cannot be replaced since now the linalg generic iteration variable +// i cannot be solved for - - // for (auto &&[op, opmap] : gen.getInputsAndMaps()) { // if (auto submap = op.getDefiningOp()) { // bool solvable = false; // /// Cannoicalization 2: index removal // //. x[i, j] -> v[i]. can we get rid of j? -// //. Are input indices defined by other ops, and if so, can we simplify +// //. Are input indices defined by other ops, and if so, can we +// simplify // //. 1) Take all other input memrefs // // 2) Determine all solvable indices from those input memrefs // //. For each index which is solvable from 2) -// // if it can either be removed from the submap, or combined with another index in the submap, +// // if it can either be removed from the submap, or combined +// with another index in the submap, // // remove it from the submap // SmallVector exprs; @@ -5963,53 +5976,64 @@ struct SubMapOpCanonicalize : public OpRewritePattern { // linalg.determineSolvableIndices(solvable, exprs); // SmallSet notsolvable = allvariables - solvable; - -// //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps + +// //a[map(i+x,j+y)] pass in the outermost one with correspondidng +// composed maps // //b[x][y] // //c[i+x][j+y] // // Supose we're solving for a -// // Here exprs would contain all the affineexprs from b and c. (aka inputs - {x}) - +// // Here exprs would contain all the affineexprs from b and c. (aka +// inputs - {x}) + // // {x, y, i+x, j+y} -// // Running a solver allows us to uniquely solve for all of, x, y, i, and j with these expressoin +// // Running a solver allows us to uniquely solve for all of, x, y, i, +// and j with these expressoin // // In this case we can attempt to remove dependence on x, y, i, j -// // If however we had -// //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// // If however we had +// //a[map(i+x,j+y)] pass in the outermost one with correspondidng +// composed maps // //b[x][y] // //c[i+x][y] // // we would solve with {x, y, i+x, y} -// // Running a solver we would be able to sole for {x, y, i} but not solve for j -// // In this case we can attempt to remove dependence on x, y, i, but not on j +// // Running a solver we would be able to sole for {x, y, i} but not +// solve for j +// // In this case we can attempt to remove dependence on x, y, i, but +// not on j -// // let's take easiest one where a is just broadcasting a constant to all input indices +// // let's take easiest one where a is just broadcasting a constant to +// all input indices // // a = submap (m,n) -> u[] -// // a[i+x, j+y] For all input indices which are uniquely solvable, here that is both +// // a[i+x, j+y] For all input indices which are uniquely solvable, here +// that is both // //. index 0 = i + x // //. 
and index 1 = j + y // // set the input map to compose with the submap's affine map - // /// Easy special case // if (notsolvable.size() == 0) { - -// replace opmap with submap.compose(opmap) taking into account the the ConstantIntRanges +// replace opmap with submap.compose(opmap) taking into account the the +// ConstantIntRanges // // Easy case // } // // We now have two maps with different meanings // // Let |N| be the number of loop variables in the linalg.generic // // Let |M| be length(submap.getType().getShape()) -// // Let |Q| be length(submap.getInput().getType().getShape()), number of dimensions of input operand to the submap - -// // opmap from the linalg.generic which takes linalg.generic loop indices |N| -> inputs to the submap op. |M| +// // Let |Q| be length(submap.getInput().getType().getShape()), number +// of dimensions of input operand to the submap + +// // opmap from the linalg.generic which takes linalg.generic loop +// indices |N| -> inputs to the submap op. |M| + +// // submap.map. submap op. which takes input indices |M|. +// -> indices for the corresponing base memref |Q| -// // submap.map. submap op. which takes input indices |M|. -> indices for the corresponing base memref |Q| - // // Example -// //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// //a[map(i+x,j+y)] pass in the outermost one with correspondidng +// composed maps // //b[x][y] // //c[i+x][j+y] @@ -6036,10 +6060,13 @@ struct SubMapOpCanonicalize : public OpRewritePattern { // for (auto var : notsolvable) { // if (linalgexpr.isFunctionOf(var)) { // legal = false; -// // we can pop this from the not solvable since now this index will define +// // we can pop this from the not solvable since now this index +// will define // // the value of var for future iterations. -// // But doing so requires proving it is not a linear combination of previously -// // visited linalgexpr's, so we'll defer this for a later optimization +// // But doing so requires proving it is not a linear +// combination of previously +// // visited linalgexpr's, so we'll defer this for a later +// optimization // // notsolvable.pop(var); // } // } @@ -6050,53 +6077,67 @@ struct SubMapOpCanonicalize : public OpRewritePattern { // // The non-special case version // // j is not solvable -// //a[map(i+x,j+y)] pass in the outermost one with correspondidng composed maps +// //a[map(i+x,j+y)] pass in the outermost one with correspondidng +// composed maps // //b[x][y] // //c[i+x][y] -// // because j is not solvable we cannot move any expressions depending on j (in this case p depends on j) -// //. and the underlying sub expressions depending j, in this case via p are: +// // because j is not solvable we cannot move any expressions depending +// on j (in this case p depends on j) +// //. and the underlying sub expressions depending j, in this case via +// p are: // // a[1] = w + 4 and a[2] = w + 7 // // define a(w,p) -> u[c + 2 * p, w + 4, w + 7] // // with the general case optimization v0. [just moving expressions up] -// //a2[map(i+x, j+y), i + x + 4, i + x + 7] pass in the outermost one with correspondidng composed maps +// //a2[map(i+x, j+y), i + x + 4, i + x + 7] pass in the outermost one +// with correspondidng composed maps // //b[x][y] // //c[i+x][y] // // define a2(w, p) -> u[c + 2 * p] -// // with the general case optimization v1. [just eliminating unnecessary indices] +// // with the general case optimization v1. 
[just eliminating +// unnecessary indices] -// //a2[map(j+y), i + x + 4, i + x + 7] pass in the outermost one with correspondidng composed maps +// //a2[map(j+y), i + x + 4, i + x + 7] pass in the outermost one with +// correspondidng composed maps // //b[x][y] // //c[i+x][y] // // define a2(p) -> u[c + 2 * p] -// // So this optimization generally moves expression from the submap into the linalg map +// // So this optimization generally moves expression from the submap +// into the linalg map // // and it it also removes unnecessary indices into the submap - -// // If the entire submap is legal to inline, the solution is simple, replace the linalg -// // map with itself composed with the submap, and replace the original submap with the identity op -// if (legalIndices.size() == opmap.getExprs().size()) { -// // Note, it isn't 100% as simple as below since we still need to retain any constant op's in the -// // new submap op below, since linalg.generic doesn't support constant value's for the indexing, as far -// // as I (wmoses) know? +// // If the entire submap is legal to inline, the solution is simple, +// replace the linalg +// // map with itself composed with the submap, and replace the original +// submap with the identity op if (legalIndices.size() == +// opmap.getExprs().size()) { +// // Note, it isn't 100% as simple as below since we still need to +// retain any constant op's in the +// // new submap op below, since linalg.generic doesn't support +// constant value's for the indexing, as far +// // as I (wmoses) know? // newLinalgExprs = opmap.compose(submap.getMap()).getExprs(); -// newSubmapExprs = Affine::getIdentityMap(submap.getOperand().getShape().size()).getExprs(); +// newSubmapExprs = +// Affine::getIdentityMap(submap.getOperand().getShape().size()).getExprs(); // } else { // SmallVector illegalIndices = allIndices - legalIndices; -// // We can alternatively re-index maps which are solely functions of legal indices. -// for (auto &&[i, submapexpr] : llvm::enumerate(submap.getAffineMap().getExprs())) { +// // We can alternatively re-index maps which are solely functions of +// legal indices. for (auto &&[i, submapexpr] : +// llvm::enumerate(submap.getAffineMap().getExprs())) { // if (submapexpr is a function of any illegal indicies) { -// // we need to keep this as a submap expr (though re-indexed on the new number of exprs) +// // we need to keep this as a submap expr (though re-indexed on +// the new number of exprs) // newSubmapExprs.push_back(submapexpr.reindex()); // } else { -// // this index can be completely solved for with other inputs, let's move the expression from +// // this index can be completely solved for with other inputs, +// let's move the expression from // // a submap expression into a linalg.generic map expression. 
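+// // A small worked example of the move described above (the maps and values
+// // here are made up for illustration and do not appear in this code):
+// //   submap:      a = polygeist.submap %u(#s)  with  #s = (m, n) -> (n + 2, m)
+// //   linalg map:  #l = (i, j, x, y) -> (i + x, j + y)
+// //   composed:    #s.compose(#l) = (i, j, x, y) -> (j + y + 2, i + x)
+// // so the generic can index %u directly through the composed map and the
+// // submap expression becomes an identity, provided every loop index remains
+// // solvable from the remaining indexing maps.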
// newLinalgExprs.push_back(opmap.compose(submapexpr)); // newSubmapExprs.push_back(Affine::getIdentity()); @@ -6105,26 +6146,23 @@ struct SubMapOpCanonicalize : public OpRewritePattern { // } // if (solvable) { -// // replace the input to the generic with the input to the submap, and the new map -// return success(); +// // replace the input to the generic with the input to the submap, +// and the new map return success(); // } // } // } - - // for (auto op : gen.getOutputs()) { // if (auto submap = op.getDefiningOp()) { // bool solvable = false; // if (solvable) { // do the thing -// // replace the input to the generic with the input to the submap, and the new map -// return success(); +// // replace the input to the generic with the input to the submap, +// and the new map return success(); // } // } // } - // return failure(); // } // }; @@ -6284,28 +6322,31 @@ class LoadSubMap final : public OpRewritePattern { LogicalResult matchAndRewrite(affine::AffineLoadOp op, PatternRewriter &rewriter) const override { auto subMapOp = op.getMemRef().getDefiningOp(); - if (!subMapOp) return failure(); + if (!subMapOp) + return failure(); auto submap_map = subMapOp.getMap(); auto submap_operands = subMapOp.getSymbols(); auto source_memref = subMapOp.getMemref(); - + auto load_map = op.getAffineMap(); auto load_operands = op.getMapOperands(); auto new_map = submap_map.compose(load_map); SmallVector operands; - operands.append(load_operands.begin(), load_operands.begin() + load_map.getNumDims()); + operands.append(load_operands.begin(), + load_operands.begin() + load_map.getNumDims()); operands.append(submap_operands.begin(), submap_operands.end()); - operands.append(load_operands.begin() + load_map.getNumDims(), load_operands.end()); + operands.append(load_operands.begin() + load_map.getNumDims(), + load_operands.end()); - rewriter.replaceOpWithNewOp(op, source_memref, new_map, operands); + rewriter.replaceOpWithNewOp(op, source_memref, + new_map, operands); return success(); } }; - class StoreSubMap final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -6313,34 +6354,38 @@ class StoreSubMap final : public OpRewritePattern { LogicalResult matchAndRewrite(affine::AffineStoreOp op, PatternRewriter &rewriter) const override { auto subMapOp = op.getMemRef().getDefiningOp(); - if (!subMapOp) return failure(); + if (!subMapOp) + return failure(); auto submap_map = subMapOp.getMap(); auto submap_operands = subMapOp.getSymbols(); auto source_memref = subMapOp.getMemref(); - + auto load_map = op.getAffineMap(); auto load_operands = op.getMapOperands(); auto new_map = submap_map.compose(load_map); SmallVector operands; - operands.append(load_operands.begin(), load_operands.begin() + load_map.getNumDims()); + operands.append(load_operands.begin(), + load_operands.begin() + load_map.getNumDims()); operands.append(submap_operands.begin(), submap_operands.end()); - operands.append(load_operands.begin() + load_map.getNumDims(), load_operands.end()); + operands.append(load_operands.begin() + load_map.getNumDims(), + load_operands.end()); - rewriter.replaceOpWithNewOp(op, op.getValue(), source_memref, new_map, operands); + rewriter.replaceOpWithNewOp( + op, op.getValue(), source_memref, new_map, operands); return success(); } }; -OpFoldResult mlir::polygeist::SubmapOp::fold(mlir::polygeist::SubmapOp::FoldAdaptor adaptor) { +OpFoldResult mlir::polygeist::SubmapOp::fold( + mlir::polygeist::SubmapOp::FoldAdaptor adaptor) { // TODO if submap is identity return nothing // if submap of 
submap return new submap return nullptr; } - class DimSubMap final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -6348,10 +6393,12 @@ class DimSubMap final : public OpRewritePattern { LogicalResult matchAndRewrite(memref::DimOp op, PatternRewriter &rewriter) const override { auto subMapOp = op.getSource().getDefiningOp(); - if (!subMapOp) return failure(); + if (!subMapOp) + return failure(); auto idx = op.getIndex().getDefiningOp(); - if (!idx) return failure(); + if (!idx) + return failure(); rewriter.replaceOp(op, subMapOp.getSizes()[idx.value()]); @@ -6359,9 +6406,10 @@ class DimSubMap final : public OpRewritePattern { } }; -void polygeist::SubmapOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - //results.insert(context); +void polygeist::SubmapOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + // results.insert(context); results.insert(context); - //results.insert(context); + // results.insert(context); } diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index c7c55a465cdc..82310052c509 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -2,14 +2,14 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" @@ -30,27 +30,26 @@ using namespace tensor; using namespace bufferization; std::vector getSortedUsers(Operation *op) { - if(!op) return {}; + if (!op) + return {}; // Find the parent function auto funcOp = op->getParentOfType(); - if (!funcOp) return {}; + if (!funcOp) + return {}; - //Map to store order of operations + // Map to store order of operations llvm::DenseMap opOrder; size_t order = 0; - funcOp.walk([&](Operation *curOp) { - opOrder[curOp] = order++; - }); + funcOp.walk([&](Operation *curOp) { opOrder[curOp] = order++; }); - std::vector sortedUsers(op->getUsers().begin(), op->getUsers().end()); + std::vector sortedUsers(op->getUsers().begin(), + op->getUsers().end()); - std::sort(sortedUsers.begin(), sortedUsers.end(), - [&](Operation *a, Operation *b) { - return opOrder[a] < opOrder[b]; - } - ); + std::sort( + sortedUsers.begin(), sortedUsers.end(), + [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; }); return sortedUsers; } @@ -60,106 +59,112 @@ struct LinalgDebufferization : public OpRewritePattern { LogicalResult matchAndRewrite(func::FuncOp funcOp, PatternRewriter &rewriter) const final { - + auto module = funcOp->getParentOfType(); - SmallVector opsToDelete; - llvm::SmallPtrSet opsToDeleteSet; - //Tracks both old linalg.generics and linalg.generics with repeated values in ins and outs - llvm::SmallPtrSet processedGenericOps; + SmallVector opsToDelete; + llvm::SmallPtrSet opsToDeleteSet; + // Tracks both old linalg.generics and linalg.generics with repeated values + // in ins and outs + llvm::SmallPtrSet processedGenericOps; LogicalResult passResult = 
success(); funcOp.walk([&](mlir::memref::AllocaOp allocaOp) -> WalkResult { - auto module = allocaOp->getParentOfType(); - rewriter.setInsertionPointAfter(allocaOp); - auto tensorType = RankedTensorType::get(allocaOp.getType().getShape(), allocaOp.getType().getElementType()); - - //Check to see if only linalg.generic are users of the allocaOp for now. - //TODO: Extend this - if(!llvm::all_of(allocaOp->getUsers(),[](Operation *op) { - return isa(op); - })){ - passResult = failure(); - return WalkResult::interrupt(); - } - - //auto emptyTensor = rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), allocaOp.getType().getElementType()); - auto toTensorOp = rewriter.create( - allocaOp.getLoc(), - tensorType, - allocaOp); - Value currentTensor = toTensorOp; - - auto sortedUsers = getSortedUsers(allocaOp); - - //Check if allocaOp is an output in current genericOp - for (auto user : sortedUsers) { - if (auto genericOp = dyn_cast(user)) { - - //auto genericOp = cast(user); - if(processedGenericOps.count(genericOp) > 0) - continue; - rewriter.setInsertionPointAfter(genericOp); - - SmallVector newInputs; - SmallVector newOutputs; - SmallVector resultTypes; - //Create a new linalg.generic in Destination Style Passing format - - ArrayAttr indexingMaps = genericOp.getIndexingMaps(); - for(auto input : genericOp.getInputs()){ - newInputs.push_back(input == allocaOp ? currentTensor : input); - } + auto module = allocaOp->getParentOfType(); + rewriter.setInsertionPointAfter(allocaOp); + auto tensorType = RankedTensorType::get( + allocaOp.getType().getShape(), allocaOp.getType().getElementType()); + + // Check to see if only linalg.generic are users of the allocaOp for now. + // TODO: Extend this + if (!llvm::all_of(allocaOp->getUsers(), [](Operation *op) { + return isa(op); + })) { + passResult = failure(); + return WalkResult::interrupt(); + } + + // auto emptyTensor = + // rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), + // allocaOp.getType().getElementType()); + auto toTensorOp = rewriter.create( + allocaOp.getLoc(), tensorType, allocaOp); + Value currentTensor = toTensorOp; + + auto sortedUsers = getSortedUsers(allocaOp); + + // Check if allocaOp is an output in current genericOp + for (auto user : sortedUsers) { + if (auto genericOp = dyn_cast(user)) { + + // auto genericOp = cast(user); + if (processedGenericOps.count(genericOp) > 0) + continue; + rewriter.setInsertionPointAfter(genericOp); + + SmallVector newInputs; + SmallVector newOutputs; + SmallVector resultTypes; + // Create a new linalg.generic in Destination Style Passing format + + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + for (auto input : genericOp.getInputs()) { + newInputs.push_back(input == allocaOp ? currentTensor : input); + } - //ArrayRef resultTypes; - int newCurrentTensorIndex = -1; - int index = 0; - for(auto output : genericOp.getOutputs()){ - newOutputs.push_back(output == allocaOp ? currentTensor : output); - resultTypes.push_back(output == allocaOp ? currentTensor.getType() : output.getType()); - if(output == allocaOp) { - newCurrentTensorIndex = index; - } - index++; + // ArrayRef resultTypes; + int newCurrentTensorIndex = -1; + int index = 0; + for (auto output : genericOp.getOutputs()) { + newOutputs.push_back(output == allocaOp ? currentTensor : output); + resultTypes.push_back(output == allocaOp ? 
currentTensor.getType() + : output.getType()); + if (output == allocaOp) { + newCurrentTensorIndex = index; } + index++; + } - StringAttr empty = StringAttr::get(genericOp.getContext()); - ArrayRef resultTypesRef(resultTypes); - auto newGenericOp = rewriter.create(genericOp.getLoc(), resultTypesRef, newInputs, newOutputs, - genericOp.getIndexingMaps(), genericOp.getIteratorTypes(), empty, empty); - - Region &opRegion = newGenericOp.getRegion(); - rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), newGenericOp.getRegion().end()); - - //Replace all uses of original generic op with the new one - //int idxOldGeneric=0; - //int idxNewGeneric=0; - for (unsigned i = 0; i < genericOp->getNumResults(); ++i) { - //if(i == newCurrentTensorIndex) { - // idxNewGeneric++; - //} - //genericOp->getResult(idxOldGeneric).replaceAllUsesWith(newGenericOp->getResult(idxNewGeneric)); - //idxOldGeneric++; - //idxNewGeneric++; - genericOp->getResult(i).replaceAllUsesWith(newGenericOp->getResult(i)); - } - - //Delete the original genericOp - opsToDelete.push_back(genericOp.getOperation()); - if(newCurrentTensorIndex != -1) - currentTensor = newGenericOp.getResult(newCurrentTensorIndex); - - processedGenericOps.insert(genericOp.getOperation()); + StringAttr empty = StringAttr::get(genericOp.getContext()); + ArrayRef resultTypesRef(resultTypes); + auto newGenericOp = rewriter.create( + genericOp.getLoc(), resultTypesRef, newInputs, newOutputs, + genericOp.getIndexingMaps(), genericOp.getIteratorTypes(), empty, + empty); + + Region &opRegion = newGenericOp.getRegion(); + rewriter.cloneRegionBefore(genericOp.getRegion(), + newGenericOp.getRegion(), + newGenericOp.getRegion().end()); + + // Replace all uses of original generic op with the new one + // int idxOldGeneric=0; + // int idxNewGeneric=0; + for (unsigned i = 0; i < genericOp->getNumResults(); ++i) { + // if(i == newCurrentTensorIndex) { + // idxNewGeneric++; + // } + // genericOp->getResult(idxOldGeneric).replaceAllUsesWith(newGenericOp->getResult(idxNewGeneric)); + // idxOldGeneric++; + // idxNewGeneric++; + genericOp->getResult(i).replaceAllUsesWith( + newGenericOp->getResult(i)); } + + // Delete the original genericOp + opsToDelete.push_back(genericOp.getOperation()); + if (newCurrentTensorIndex != -1) + currentTensor = newGenericOp.getResult(newCurrentTensorIndex); + + processedGenericOps.insert(genericOp.getOperation()); } + } - auto toMemrefOp = rewriter.create( - allocaOp.getLoc(), - allocaOp.getType(), - currentTensor); - rewriter.create(allocaOp.getLoc(), toMemrefOp, allocaOp); - //opsToDelete.push_back(allocaOp.getOperation()); - return WalkResult::advance(); + auto toMemrefOp = rewriter.create( + allocaOp.getLoc(), allocaOp.getType(), currentTensor); + rewriter.create(allocaOp.getLoc(), toMemrefOp, allocaOp); + // opsToDelete.push_back(allocaOp.getOperation()); + return WalkResult::advance(); }); for (Operation *op : opsToDelete) { op->erase(); @@ -171,8 +176,7 @@ struct LinalgDebufferization : public OpRewritePattern { }; namespace { -struct LinalgDebufferize - : public LinalgDebufferizeBase { +struct LinalgDebufferize : public LinalgDebufferizeBase { void runOnOperation() override; }; } // namespace diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 46021a556717..3d99e6f67029 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -101,21 +101,25 @@ AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { } // 
Given an affine map `oldmap`, memref `val`, and corresponding input values -// (which are a list of indicies, then symbols), and a set of loop indices `indices` produce -// the following: +// (which are a list of indicies, then symbols), and a set of loop indices +// `indices` produce the following: // 1. A (potentially new) memref value `newval` which does not have any // dependence on `indices` // and -// 2. an affine map `newmap` which takes size(indices) values (`indices`) and produces -// indices into `newval` such that +// 2. an affine map `newmap` which takes size(indices) values (`indices`) and +// produces indices into `newval` such that // indexing `newval[map(indices)]` produces the same result as indexing the // original map. -// check_reduction is set true, when passed from store/linalg.generic's output variable. -// And it is returned true, only if index was not encountered in oldmap operands and check_reduction was set true. +// check_reduction is set true, when passed from store/linalg.generic's output +// variable. And it is returned true, only if index was not encountered in +// oldmap operands and check_reduction was set true. Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, - Value memref_val, Value index, Value bound, int firstNDims, ValueRange oldmap_operands, Value origmemref, bool &check_reduction) { - assert(oldmap_operands.size() == oldmap.getNumSymbols() + oldmap.getNumDims()); - //Operands which don't correspond to indices + Value memref_val, Value index, Value bound, + int firstNDims, ValueRange oldmap_operands, + Value origmemref, bool &check_reduction) { + assert(oldmap_operands.size() == + oldmap.getNumSymbols() + oldmap.getNumDims()); + // Operands which don't correspond to indices SmallVector operands_without_indices; ssize_t dimidx = -1; for (auto [i, v] : llvm::enumerate(oldmap_operands)) { @@ -125,10 +129,12 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, } assert(i >= firstNDims); if (v != index) { - // Check if the symbol value is read-only or defined in a scope where it is always visible. + // Check if the symbol value is read-only or defined in a scope where it + // is always visible. 
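+      // For instance (illustrative cases, not tied to a particular test): a
+      // block argument of the enclosing function dominates the new submap and
+      // can be reused as-is; a read-only op such as an arith.muli computed
+      // inside the loop can be cloned next to the submap; a value produced by
+      // an op with side effects cannot be hoisted, so the rewrite is rejected.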
if (auto ba = dyn_cast(v)) { // check if it dominates the current scope - if (ba.getParentBlock()->getParent()->isAncestor(builder.getBlock()->getParent())) + if (ba.getParentBlock()->getParent()->isAncestor( + builder.getBlock()->getParent())) operands_without_indices.push_back(v); else { assert(false); @@ -138,14 +144,17 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, } else { auto op = v.getDefiningOp(); // check if this dominates the current scope - if (op->getParentRegion()->isAncestor(builder.getBlock()->getParent())) { + if (op->getParentRegion()->isAncestor( + builder.getBlock()->getParent())) { operands_without_indices.push_back(v); } else if (isReadOnly(op)) { // if not, check if it is readnone - // Technically this isn't quite sufficient yet, and does require that the operands to this op are also able to be hoisted, - // but for now we will assume this + // Technically this isn't quite sufficient yet, and does require that + // the operands to this op are also able to be hoisted, but for now we + // will assume this auto op2 = builder.clone(*op); - operands_without_indices.push_back(op2->getResult(cast(v).getResultNumber())); + operands_without_indices.push_back( + op2->getResult(cast(v).getResultNumber())); } else { // if so clone it in the right scope // otherwise set illegal and don't continue @@ -157,15 +166,15 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, } else dimidx = i; } - if((dimidx == -1) && (check_reduction)) + if ((dimidx == -1) && (check_reduction)) check_reduction = true; - else + else check_reduction = false; SmallVector dimReplacements; size_t validSims = 0; size_t validDims = 0; - for (int i=0; i symReplacements; - for (int i=0; i idx_sizes; - for (size_t i=0; i()) idx_sizes.push_back(submap.getSizes()[i]); else llvm_unreachable("Won't reach this case"); - //idx_sizes.push_back(builder.create(origmemref.getLoc(), origmemref, i)); + // idx_sizes.push_back(builder.create(origmemref.getLoc(), + // origmemref, i)); } idx_sizes.push_back(bound); legal = true; SmallVector sizes(idx_sizes.size(), mlir::ShapedType::kDynamic); for (auto sz : idx_sizes) { - // Check if the symbol value is read-only or defined in a scope where it is always visible. + // Check if the symbol value is read-only or defined in a scope where it is + // always visible. 
if (auto ba = dyn_cast(sz)) { // check if it dominates the current scope - if (ba.getParentBlock()->getParent()->isAncestor(builder.getBlock()->getParent())) + if (ba.getParentBlock()->getParent()->isAncestor( + builder.getBlock()->getParent())) operands_without_indices.push_back(sz); else { llvm::errs() << " value is a non-dominating block arg: " << sz << "\n"; legal = false; - assert(false); + assert(false); return nullptr; } } else { @@ -243,23 +261,27 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, operands_without_indices.push_back(sz); } else if (isReadOnly(op)) { // if not, check if it is readnone - // Technically this isn't quite sufficient yet, and does require that the operands to this op are also able to be hoisted, - // but for now we will assume this + // Technically this isn't quite sufficient yet, and does require that + // the operands to this op are also able to be hoisted, but for now we + // will assume this auto op2 = builder.clone(*op); - operands_without_indices.push_back(op2->getResult(cast(sz).getResultNumber())); + operands_without_indices.push_back( + op2->getResult(cast(sz).getResultNumber())); } else { llvm::errs() << " op is not readonly: " << *op << "\n"; // if so clone it in the right scope // otherwise set illegal and don't continue legal = false; - assert(false); + assert(false); return nullptr; } } } - auto ty = MemRefType::get(sizes, cast(memref_val.getType()).getElementType()); + auto ty = MemRefType::get( + sizes, cast(memref_val.getType()).getElementType()); - return builder.create(memref_val.getLoc(), ty, memref_val, operands_without_indices, map2); + return builder.create( + memref_val.getLoc(), ty, memref_val, operands_without_indices, map2); } // store A[...] @@ -317,7 +339,7 @@ linalg.generic %[[[memref]]] [[[[#map]]]]([[[[operands]]]]) { output_memref = memref_base output_map = subvmap() - compose + compose # uts are memref, map, and operands # outputs are o memref[map(operands)] ==== output_memref[output_map(output_operands)] @@ -367,8 +389,8 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, if (auto SM = dyn_cast(defOp)) { auto submap = SM.getMap(); - //TODO: Do we achieve anything with this compose? - //As lgMap in our case is 1 to 1 identity map + // TODO: Do we achieve anything with this compose? + // As lgMap in our case is 1 to 1 identity map auto composeMap = submap.compose(lgMap); SmallVector operands0; @@ -392,7 +414,7 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, continue; } - //if (auto SV = dyn_cast(defOp)) { + // if (auto SV = dyn_cast(defOp)) { // // TODO update map with the new indexing from here @@ -407,8 +429,10 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, // SmallVector strideExprs; // SmallVector dimOperands; // SmallVector symOperands; - // for (auto &&[first, second] : llvm::zip(SV.getOffsets(), SV.getStrides())) { - // for (auto &&[index, val] : llvm::enumerate(SmallVector({first, second}))) { + // for (auto &&[first, second] : llvm::zip(SV.getOffsets(), + // SV.getStrides())) { + // for (auto &&[index, val] : llvm::enumerate(SmallVector({first, + // second}))) { // auto &exprOutput = (index == 0) ? 
startExprs : strideExprs; // // Only support constants, symbols, or affine apply as offsets // if (auto cop = val.getDefiningOp()) { @@ -420,7 +444,8 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, // } // if (auto ba = dyn_cast(val)) { // Block *parentBlock = ba.getOwner(); - // if (isa(parentBlock->getParentOp())) { + // if (isa(parentBlock->getParentOp())) { // exprOutput.push_back( // builder.getAffineDimExpr(dimOperands.size())); // dimOperands.push_back(ba); @@ -439,15 +464,18 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, // continue; // } - // //TODO: Maybe it's a case to add, but are we sure we need it for starts and offsets + // //TODO: Maybe it's a case to add, but are we sure we need it for + // starts and offsets // // and not for operands // if (auto apply = dyn_cast(valOp)) { // auto map = apply.getAffineMap(); // auto *scope = affine::getAffineScope(valOp)->getParentOp(); // DominanceInfo DI(scope); // auto map_operands = apply.getOperands(); - // //fully2ComposeAffineMapAndOperands(builder, &map, &map_operands, DI); - //// Instead of using loop step we are using 1 (Assumption as the stride size) + // //fully2ComposeAffineMapAndOperands(builder, &map, &map_operands, + // DI); + //// Instead of using loop step we are using 1 (Assumption as the stride + ///size) // auto newexpr = map.shiftDims(dimOperands.size()) // .shiftSymbols(symOperands.size()); @@ -459,7 +487,8 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, // dimOperands.push_back(apply.getOperands()[i]); // for (size_t i = 0; i < map.getNumSymbols(); i++) - // symOperands.push_back(apply.getOperands()[i + map.getNumDims()]); + // symOperands.push_back(apply.getOperands()[i + + // map.getNumDims()]); // continue; // } @@ -479,7 +508,6 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, // for (size_t i = 0; i < lgMap.getNumSymbols(); i++) // symOperands.push_back(lgOperands[i + lgMap.getNumDims()]); - // SmallVector mergedExprs; // for (auto && [start, stride, idx] : // llvm::zip(startExprs, strideExprs, inputExprs)) { @@ -487,15 +515,16 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, // } // lgMap = - // AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs, loop->getContext()); + // AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs, + // loop->getContext()); // lgOperands.clear(); - // lgOperands.insert(lgOperands.begin(), dimOperands.begin(), dimOperands.end()); - // lgOperands.insert(lgOperands.begin()+lgOperands.size(), symOperands.begin(), symOperands.end()); - // input = SV.getSource(); - // break; + // lgOperands.insert(lgOperands.begin(), dimOperands.begin(), + // dimOperands.end()); + // lgOperands.insert(lgOperands.begin()+lgOperands.size(), + // symOperands.begin(), symOperands.end()); input = SV.getSource(); break; //} - //return failure(); + // return failure(); } assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); return success(); @@ -506,7 +535,7 @@ struct AffineForOpRaising : public OpRewritePattern { LogicalResult matchAndRewrite(affine::AffineForOp loop, PatternRewriter &rewriter) const final { - + auto module = loop->getParentOfType(); // Don't handle accumulations in registers for the moment, we can have @@ -519,7 +548,7 @@ struct AffineForOpRaising : public OpRewritePattern { SmallVector, AffineStoreOp>> stores; SmallVector, GenericOp>> linalgGenerics; bool check_reduction; - + // 
TODO Also collect all the linalg generics! // Check that the only operations within the region are either: @@ -668,7 +697,6 @@ struct AffineForOpRaising : public OpRewritePattern { // loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), // *ub, *lb); - for (auto &&[conds, lg] : linalgGenerics) { // This captures the indexing map attribute from the linalg.generic being @@ -686,21 +714,22 @@ struct AffineForOpRaising : public OpRewritePattern { // lgMap comes from offset of memref.subview, // lgOperands comes from operands of memref.subview - const AffineMap lgMap0 = cast(indexingMapsAttr[idx]).getAffineMap(); + const AffineMap lgMap0 = + cast(indexingMapsAttr[idx]).getAffineMap(); AffineMap lgMap = lgMap0; SmallVector lgOperands; - for (int i=0; i { return failure(); bool legal = true; - - // Takes input's/output's, affineMap of load/store (here lgMap ?), + + // Takes input's/output's, affineMap of load/store (here lgMap ?), // induction variable corresponding to the loop // Memref corresponding the the memory accessed (in this case subview ?) // loopSize, lower and upper bounds // Get operands for load/store (here ?) to find dependent dim // Gives output newMemref which is a subviewOp, - // newAffineMap which is the LG's indexing map corresponding this inp/output - - // This takes load and store maps and then creates affine.apply+subview+linalg.generic - // For this case: LG within ForOp - + // newAffineMap which is the LG's indexing map corresponding this + // inp/output + + // This takes load and store maps and then creates + // affine.apply+subview+linalg.generic For this case: LG within ForOp - // Inputs should be : load map extracted from subviewOp - // Returns LG with indexingMap and subview with affine.apply - which are correct - - //TODO: Or is it num dims? - //size_t firstNDims = lgMap.getResults().size(); + // Returns LG with indexingMap and subview with affine.apply - which + // are correct + + // TODO: Or is it num dims? 
+ // size_t firstNDims = lgMap.getResults().size(); size_t firstNDims = lgMap.getNumDims(); check_reduction = false; auto newMemref = remap_in_affine_dim( @@ -733,8 +764,8 @@ struct AffineForOpRaising : public OpRewritePattern { if (!legal) return failure(); - auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); - + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims + 1); + // TODO: need to mergre previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); @@ -747,15 +778,16 @@ struct AffineForOpRaising : public OpRewritePattern { if (conds.size() != 0) return failure(); - const AffineMap lgMap0 = cast(indexingMapsAttr[idx]).getAffineMap(); + const AffineMap lgMap0 = + cast(indexingMapsAttr[idx]).getAffineMap(); AffineMap lgMap = lgMap0; - + SmallVector lgOperands; - for (int i=0; i { size_t firstNDims = lgMap.getNumDims(); check_reduction = true; auto newMemref = remap_in_affine_dim( - legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, firstNDims, ValueRange(lgOperands), output, check_reduction); + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, + firstNDims, ValueRange(lgOperands), output, check_reduction); if (!legal) return failure(); - auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims + 1); // TODO: need to merge previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); outputs.push_back(newMemref); @@ -794,22 +827,22 @@ struct AffineForOpRaising : public OpRewritePattern { check_reduction = false; auto newMemref = remap_in_affine_dim( legal, rewriter, load.getAffineMap(), load.getMemref(), - loop.getInductionVar(), loopSize, firstNDims, - load.getMapOperands(), load.getMemref(), check_reduction); + loop.getInductionVar(), loopSize, firstNDims, load.getMapOperands(), + load.getMemref(), check_reduction); if (!legal) return failure(); - auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims + 1); affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); } // TODO Push all of the inputs to the linalg generics (modifying maps as // needed) - //SmallVector outputs; - // Store we may need to reindex into a splat potentially later, but for now - // we'll be lazy + // SmallVector outputs; + // Store we may need to reindex into a splat potentially later, but for now + // we'll be lazy for (auto &&[conds, store] : stores) { // Only support unconditional loads for the moment if (conds.size() != 0) @@ -818,18 +851,18 @@ struct AffineForOpRaising : public OpRewritePattern { bool legal = true; size_t firstNDims = 0; - + check_reduction = true; auto newMemref = remap_in_affine_dim( legal, rewriter, store.getAffineMap(), store.getMemref(), - loop.getInductionVar(), loopSize, firstNDims, - store.getMapOperands(), store.getMemref(), check_reduction); + loop.getInductionVar(), loopSize, firstNDims, store.getMapOperands(), + store.getMemref(), check_reduction); if (!legal) { return failure(); } - auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims+1); + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims + 1); affineMaps.push_back(newAffineMap); outputs.push_back(newMemref); } @@ -837,14 +870,14 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO presently if linalg generic exists, assert there are no load/stores if ((linalgGenerics.size() > 0) && - ((loads.size() != 0) || 
(stores.size() != 0))) { + ((loads.size() != 0) || (stores.size() != 0))) { assert(false); return failure(); } // TODO assert only zero or one linalg generic exists if (!(linalgGenerics.size() == 1 || linalgGenerics.size() == 0)) { - //assert(false); + // assert(false); return failure(); } @@ -852,21 +885,21 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO if linalg generic exists, make this iterator type prepend to the // existing iterators - //TODO: Just store check is not sufficient, there has to be a check for - //bool is_parallel = stores_map.size() == 0; - // TODO determine if linalg generic, whether to create parallel or reduction by looking at memory patterns of maps + // TODO: Just store check is not sufficient, there has to be a check for + // bool is_parallel = stores_map.size() == 0; + // TODO determine if linalg generic, whether to create parallel or + // reduction by looking at memory patterns of maps if (linalgGenerics.size() == 1) { // determine whether now we write to ourselves } - iteratorTypes.push_back(check_reduction - ? utils::IteratorType::reduction - : utils::IteratorType::parallel); + iteratorTypes.push_back(check_reduction ? utils::IteratorType::reduction + : utils::IteratorType::parallel); if (linalgGenerics.size() == 1) { for (auto attr : linalgGenerics[0].second.getIteratorTypesArray()) - iteratorTypes.push_back(attr); + iteratorTypes.push_back(attr); } StringAttr empty = StringAttr::get(loop.getContext()); @@ -924,8 +957,7 @@ struct AffineForOpRaising : public OpRewritePattern { auto term = genBlock.getTerminator(); mlir::IRMapping map; for (auto arg : genBlock.getArguments()) { - auto arg2 = - blk->addArgument(arg.getType(), arg.getLoc()); + auto arg2 = blk->addArgument(arg.getType(), arg.getLoc()); map.map(arg, arg2); } for (auto &op : genBlock.without_terminator()) { @@ -934,7 +966,7 @@ struct AffineForOpRaising : public OpRewritePattern { for (auto op : term->getOperands()) { toreturn.push_back(map.lookup(op)); } - //llvm::errs() << genOp->getParentOfType() << "\n"; + // llvm::errs() << genOp->getParentOfType() << "\n"; rewriter.eraseOp(genOp); } diff --git a/lib/polygeist/Passes/RemoveIterArgs.cpp b/lib/polygeist/Passes/RemoveIterArgs.cpp index b3b0ac7302a4..0a4784c6c599 100644 --- a/lib/polygeist/Passes/RemoveIterArgs.cpp +++ b/lib/polygeist/Passes/RemoveIterArgs.cpp @@ -29,7 +29,7 @@ struct RemoveSCFIterArgs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override { - + ModuleOp module = forOp->getParentOfType(); if (!forOp.getRegion().hasOneBlock()) return failure(); @@ -44,29 +44,34 @@ struct RemoveSCFIterArgs : public OpRewritePattern { auto init = forOp.getInits()[i]; auto lastOp = yieldOp->getOperand(i); - //General Case(TODO): - //ALGo: - // 1. Create an alloca(stack) variable - // How to know it's dims? It should be based on number of reduction loops - // 2. Initialize it with init value just outside the for loop if init value is non-zero - // 3. memref.load that value in the for loop - // 4. Replace all the uses of the iter_arg with the loaded value - // 5. Add a memref.store for the value to be yielded - // 6. Replace all uses of for-loops yielded value with a single inserted memref.load - //Special case: - //ALGo: - //Optimize away memref.store and memref.load, if the only users of memref.load are memref.store (can use affine-scalrep pass for that ? 
No it does store to load forwarding) - //What we need is forwarding of local store to final store and deleting the intermediate alloca created. This is only possible if the user of alloca is a storeOp. - // 1. Identify the single store of the for loop result - // 2. Initialize it with iter arg init, outside the for loop. (TODO) - // 3. Do a load from the memref - // 4. move the store to memref inside the loop. + // General Case(TODO): + // ALGo: + // 1. Create an alloca(stack) variable + // How to know it's dims? It should be based on number of reduction + // loops + // 2. Initialize it with init value just outside the for loop if init + // value is non-zero + // 3. memref.load that value in the for loop + // 4. Replace all the uses of the iter_arg with the loaded value + // 5. Add a memref.store for the value to be yielded + // 6. Replace all uses of for-loops yielded value with a single inserted + // memref.load + // Special case: + // ALGo: + // Optimize away memref.store and memref.load, if the only users of + // memref.load are memref.store (can use affine-scalrep pass for that ? No + // it does store to load forwarding) What we need is forwarding of local + // store to final store and deleting the intermediate alloca created. This + // is only possible if the user of alloca is a storeOp. + // 1. Identify the single store of the for loop result + // 2. Initialize it with iter arg init, outside the for loop. (TODO) + // 3. Do a load from the memref + // 4. move the store to memref inside the loop. auto result = forOp.getResult(i); - if(result.hasOneUse()) { + if (result.hasOneUse()) { auto storeOp = dyn_cast(*result.getUsers().begin()); - if(storeOp) - { + if (storeOp) { { rewriter.setInsertionPointToStart(forOp.getBody()); auto memrefLoad = rewriter.create( @@ -75,26 +80,25 @@ struct RemoveSCFIterArgs : public OpRewritePattern { } { rewriter.setInsertionPoint(yieldOp); - rewriter.create(forOp.getLoc(), lastOp, storeOp.getMemref(), - storeOp.getIndices()); - storeOp.erase(); + rewriter.create(forOp.getLoc(), lastOp, + storeOp.getMemref(), + storeOp.getIndices()); + storeOp.erase(); } - } - else{ + } else { return failure(); } } - //else{ - // alloca = rewriter.create( - // forOp.getLoc(), MemRefType::get(ArrayRef(), forOp.getType()), - // ValueRange()); - // //Skipping init for now - + // else{ + // alloca = rewriter.create( + // forOp.getLoc(), MemRefType::get(ArrayRef(), + // forOp.getType()), ValueRange()); + // //Skipping init for now // auto memrefLoad = rewriter.create( // forOp.getLoc(), alloca.getMemref(), op.getIndices()); // rewriter.replaceOp(op, memrefLoad.getResult()); - + // rewriter.create(forOp.getLoc(), lastOp, alloca, // forOp.getBody()->getArguments()); @@ -102,7 +106,7 @@ struct RemoveSCFIterArgs : public OpRewritePattern { //} rewriter.setInsertionPointToStart(forOp.getBody()); - //rewriter.replaceAllUsesWith(ba, replacementIV); + // rewriter.replaceAllUsesWith(ba, replacementIV); changed = true; } @@ -110,19 +114,18 @@ struct RemoveSCFIterArgs : public OpRewritePattern { return failure(); rewriter.setInsertionPoint(forOp); - auto newForOp = rewriter.create(loc, forOp.getLowerBound(), - forOp.getUpperBound(), - forOp.getStep()); + auto newForOp = rewriter.create( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()); if (!newForOp.getRegion().empty()) newForOp.getRegion().front().erase(); assert(newForOp.getRegion().empty()); rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), newForOp.getRegion().begin()); - //Delete region args + 
// Delete region args llvm::BitVector toDelete(numIterArgs + 1); for (unsigned i = 0; i < numIterArgs; i++) - toDelete[i + 1] = true; + toDelete[i + 1] = true; newForOp.getBody()->eraseArguments(toDelete); SmallVector newYields; @@ -130,7 +133,7 @@ struct RemoveSCFIterArgs : public OpRewritePattern { ValueRange empty; rewriter.setInsertionPoint(yieldOp); auto newYieldOp = rewriter.create(loc); - //rewriter.replaceOpWithNewOp(yieldOp, newYieldOp); + // rewriter.replaceOpWithNewOp(yieldOp, newYieldOp); rewriter.eraseOp(yieldOp); } @@ -145,7 +148,7 @@ struct RemoveAffineIterArgs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(affine::AffineForOp forOp, PatternRewriter &rewriter) const override { - + ModuleOp module = forOp->getParentOfType(); if (!forOp.getRegion().hasOneBlock()) return failure(); @@ -154,63 +157,70 @@ struct RemoveAffineIterArgs : public OpRewritePattern { bool changed = false; llvm::SetVector removed; llvm::MapVector steps; - auto yieldOp = cast(forOp.getBody()->getTerminator()); + auto yieldOp = + cast(forOp.getBody()->getTerminator()); for (unsigned i = 0; i < numIterArgs; i++) { auto ba = forOp.getRegionIterArgs()[i]; auto init = forOp.getInits()[i]; auto lastOp = yieldOp->getOperand(i); - //General Case(TODO): - //ALGo: - // 1. Create an alloca(stack) variable - // How to know it's dims? It should be based on number of reduction loops - // 2. Initialize it with init value just outside the for loop if init value is non-zero - // 3. memref.load that value in the for loop - // 4. Replace all the uses of the iter_arg with the loaded value - // 5. Add a memref.store for the value to be yielded - // 6. Replace all uses of for-loops yielded value with a single inserted memref.load - //Special case: - //ALGo: - //Optimize away memref.store and memref.load, if the only users of memref.load are memref.store (can use affine-scalrep pass for that ? No it does store to load forwarding) - //What we need is forwarding of local store to final store and deleting the intermediate alloca created. This is only possible if the user of alloca is a storeOp. - // 1. Identify the single store of the for loop result - // 2. Initialize it with iter arg init, outside the for loop. (TODO) - // 3. Do a load from the memref - // 4. move the store to memref inside the loop. + // General Case(TODO): + // ALGo: + // 1. Create an alloca(stack) variable + // How to know it's dims? It should be based on number of reduction + // loops + // 2. Initialize it with init value just outside the for loop if init + // value is non-zero + // 3. memref.load that value in the for loop + // 4. Replace all the uses of the iter_arg with the loaded value + // 5. Add a memref.store for the value to be yielded + // 6. Replace all uses of for-loops yielded value with a single inserted + // memref.load + // Special case: + // ALGo: + // Optimize away memref.store and memref.load, if the only users of + // memref.load are memref.store (can use affine-scalrep pass for that ? No + // it does store to load forwarding) What we need is forwarding of local + // store to final store and deleting the intermediate alloca created. This + // is only possible if the user of alloca is a storeOp. + // 1. Identify the single store of the for loop result + // 2. Initialize it with iter arg init, outside the for loop. (TODO) + // 3. Do a load from the memref + // 4. move the store to memref inside the loop. 
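      // A minimal sketch of the special case implemented below, on hypothetical
      // IR (%A, %out and %init are made-up names; the init value is still
      // ignored, per the TODO above). Before the rewrite:
      //
      //   %sum = affine.for %i = 0 to 128 iter_args(%acc = %init) -> (f32) {
      //     %v = affine.load %A[%i] : memref<128xf32>
      //     %r = arith.addf %acc, %v : f32
      //     affine.yield %r : f32
      //   }
      //   affine.store %sum, %out[0] : memref<1xf32>
      //
      // After the rewrite, the iter_arg becomes a load from the destination
      // memref and the final store moves inside the loop, so nothing is
      // yielded:
      //
      //   affine.for %i = 0 to 128 {
      //     %acc = affine.load %out[0] : memref<1xf32>
      //     %v = affine.load %A[%i] : memref<128xf32>
      //     %r = arith.addf %acc, %v : f32
      //     affine.store %r, %out[0] : memref<1xf32>
      //   }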
auto result = forOp.getResult(i); - if(result.hasOneUse()) { - auto storeOp = dyn_cast(*result.getUsers().begin()); - if(storeOp) - { + if (result.hasOneUse()) { + auto storeOp = + dyn_cast(*result.getUsers().begin()); + if (storeOp) { { rewriter.setInsertionPointToStart(forOp.getBody()); auto memrefLoad = rewriter.create( - forOp.getLoc(), storeOp.getMemref(), storeOp.getMap(), storeOp.getMapOperands()); + forOp.getLoc(), storeOp.getMemref(), storeOp.getMap(), + storeOp.getMapOperands()); rewriter.replaceAllUsesWith(ba, memrefLoad.getResult()); } { rewriter.setInsertionPoint(yieldOp); - rewriter.create(forOp.getLoc(), lastOp, storeOp.getMemref(), - storeOp.getMap(), storeOp.getMapOperands()); - storeOp.erase(); + rewriter.create( + forOp.getLoc(), lastOp, storeOp.getMemref(), storeOp.getMap(), + storeOp.getMapOperands()); + storeOp.erase(); } - } - else{ + } else { return failure(); } } - //else{ - // alloca = rewriter.create( - // forOp.getLoc(), MemRefType::get(ArrayRef(), forOp.getType()), - // ValueRange()); - // //Skipping init for now - + // else{ + // alloca = rewriter.create( + // forOp.getLoc(), MemRefType::get(ArrayRef(), + // forOp.getType()), ValueRange()); + // //Skipping init for now // auto memrefLoad = rewriter.create( // forOp.getLoc(), alloca.getMemref(), op.getIndices()); // rewriter.replaceOp(op, memrefLoad.getResult()); - + // rewriter.create(forOp.getLoc(), lastOp, alloca, // forOp.getBody()->getArguments()); @@ -218,7 +228,7 @@ struct RemoveAffineIterArgs : public OpRewritePattern { //} rewriter.setInsertionPointToStart(forOp.getBody()); - //rewriter.replaceAllUsesWith(ba, replacementIV); + // rewriter.replaceAllUsesWith(ba, replacementIV); changed = true; } @@ -226,20 +236,21 @@ struct RemoveAffineIterArgs : public OpRewritePattern { return failure(); rewriter.setInsertionPoint(forOp); - auto newForOp = rewriter.create(loc, forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), - forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), - forOp.getStep()); - + auto newForOp = rewriter.create( + loc, forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), + forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), + forOp.getStep()); + if (!newForOp.getRegion().empty()) newForOp.getRegion().front().erase(); assert(newForOp.getRegion().empty()); rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), newForOp.getRegion().begin()); - //Delete region args + // Delete region args llvm::BitVector toDelete(numIterArgs + 1); for (unsigned i = 0; i < numIterArgs; i++) - toDelete[i + 1] = true; + toDelete[i + 1] = true; newForOp.getBody()->eraseArguments(toDelete); SmallVector newYields; @@ -247,7 +258,8 @@ struct RemoveAffineIterArgs : public OpRewritePattern { ValueRange empty; rewriter.setInsertionPoint(yieldOp); auto newYieldOp = rewriter.create(loc); - //rewriter.replaceOpWithNewOp(yieldOp, newYieldOp); + // rewriter.replaceOpWithNewOp(yieldOp, + // newYieldOp); rewriter.eraseOp(yieldOp); } @@ -259,8 +271,7 @@ struct RemoveAffineIterArgs : public OpRewritePattern { }; namespace { -struct RemoveIterArgs - : public RemoveIterArgsBase { +struct RemoveIterArgs : public RemoveIterArgsBase { void runOnOperation() override { GreedyRewriteConfig config; @@ -269,11 +280,11 @@ struct RemoveIterArgs ConversionTarget target(*context); patterns.insert(patterns.getContext()); patterns.insert(patterns.getContext()); - + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { - signalPassFailure(); - return; + config))) { + 
signalPassFailure(); + return; } } }; diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 7759db83c573..b5aba75c9264 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -19,9 +19,9 @@ #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" From cf9f9531d50f94d75751caff67799bdcac2b663a Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Mon, 13 Jan 2025 13:48:39 -0800 Subject: [PATCH 37/77] Ran git clang format locally to fix regression failures --- lib/polygeist/Passes/RaiseToLinalg.cpp | 2 +- debufferize.mlir => test/polygeist-opt/debufferize.mlir | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename debufferize.mlir => test/polygeist-opt/debufferize.mlir (100%) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 3d99e6f67029..fee0e4d157a7 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -475,7 +475,7 @@ LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, // //fully2ComposeAffineMapAndOperands(builder, &map, &map_operands, // DI); //// Instead of using loop step we are using 1 (Assumption as the stride - ///size) + /// size) // auto newexpr = map.shiftDims(dimOperands.size()) // .shiftSymbols(symOperands.size()); diff --git a/debufferize.mlir b/test/polygeist-opt/debufferize.mlir similarity index 100% rename from debufferize.mlir rename to test/polygeist-opt/debufferize.mlir From f10c47a612f93b53ff2c02c1935d801faeb9b0eb Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 16 Jan 2025 19:32:53 -0800 Subject: [PATCH 38/77] Working implementation for function args memrefType with noinline attribute --- lib/polygeist/Ops.cpp | 32 +++- lib/polygeist/Passes/LinalgDebufferize.cpp | 171 +++++++++++++++++---- test/polygeist-opt/debufferize.mlir | 47 ++++-- 3 files changed, 197 insertions(+), 53 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index c65e5a9d3afb..7ec7a9d2af3e 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -822,19 +822,37 @@ bool mayAlias(Value v, Value v2) { if ((isAlloca[0] || isGlobal[0]) && (isAlloca[1] || isGlobal[1])) return false; - bool isArg[2]; - isArg[0] = v.isa() && - isa( - v.cast().getOwner()->getParentOp()); + bool isArg[2] = {false, false}; + bool isNoAliasArg[2] = {false, false}; - isArg[1] = v.isa() && - isa( - v.cast().getOwner()->getParentOp()); + if (auto ba = dyn_cast(v)) { + if (auto fn = dyn_cast(ba.getOwner()->getParentOp())) { + isArg[0] = true; + if (fn.getArgAttr(ba.getArgNumber(), LLVM::LLVMDialect::getNoAliasAttrName())) { + isNoAliasArg[0] = true; + } + } + } + + if (auto ba = dyn_cast(v2)) { + if (auto fn = dyn_cast(ba.getOwner()->getParentOp())) { + isArg[0] = true; + if (fn.getArgAttr(ba.getArgNumber(), LLVM::LLVMDialect::getNoAliasAttrName())) { + isNoAliasArg[0] = true; + } + } + } // Stack allocations cannot have been passed as an argument. 
if ((isAlloca[0] && isArg[1]) || (isAlloca[1] && isArg[0])) return false; + if ((isArg[0] && isNoAliasArg[1]) || (isArg[1] && isNoAliasArg[0])) + return false; + + if ((isGlobal[0] && isNoAliasArg[1]) || (isGlobal[1] && isNoAliasArg[0])) + return false; + // Non captured base allocas cannot conflict with another base value. if (isAlloca[0] && !isCaptured(v)) return false; diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 82310052c509..d6255c790954 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "polygeist/Ops.h" #include "polygeist/Passes/Passes.h" #include "llvm/Support/Debug.h" @@ -29,10 +30,25 @@ using namespace linalg; using namespace tensor; using namespace bufferization; -std::vector getSortedUsers(Operation *op) { - if (!op) - return {}; +bool isCaptured(Value v, Operation *potentialUser = nullptr, + bool *seenuse = nullptr); + +std::vector getSortedUsers(Value val) { + std::vector users; + for (Operation *user : val.getUsers()) { + users.push_back(user); + } + + // Sort the users based on their topological order + std::sort(users.begin(), users.end(), [](Operation *a, Operation *b) { + return a->isBeforeInBlock(b); + }); + + return users; +} + +std::vector getSortedUsers(Operation *op) { // Find the parent function auto funcOp = op->getParentOfType(); if (!funcOp) @@ -54,6 +70,44 @@ std::vector getSortedUsers(Operation *op) { return sortedUsers; } +struct debufferizationAllocaRemoval : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::AllocaOp allocaOp, + PatternRewriter &rewriter) const final { + Value allocaResult = allocaOp.getResult(); + bool userToTensorOp = false; + bool userCopyOp = false; + bool userOtherOp = false; + Value copyOp; + Value toTensorOp; + for (Operation *user : allocaResult.getUsers()) { + if (isa(user)) { + userToTensorOp = true; + toTensorOp = user->getResult(0); + } + else if (isa(user)) { + userCopyOp = true; + copyOp = user->getResult(0); + } + else + userOtherOp = true; + } + + if(!(!userOtherOp&&userCopyOp&&userToTensorOp)) + return failure(); + + auto emptyTensor = + rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), + allocaOp.getType().getElementType()); + + rewriter.replaceAllUsesWith(toTensorOp, emptyTensor.getResult()); + rewriter.eraseOp(copyOp.getDefiningOp()); + rewriter.eraseOp(toTensorOp.getDefiningOp()); + return success(); + } +}; + struct LinalgDebufferization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -68,30 +122,70 @@ struct LinalgDebufferization : public OpRewritePattern { // in ins and outs llvm::SmallPtrSet processedGenericOps; - LogicalResult passResult = success(); - funcOp.walk([&](mlir::memref::AllocaOp allocaOp) -> WalkResult { - auto module = allocaOp->getParentOfType(); - rewriter.setInsertionPointAfter(allocaOp); + LogicalResult passResult = failure(); + + auto handleMemref = [&](Value memVal) -> LogicalResult { + auto module = memVal.getParentRegion()->getParentOfType(); + + if (!memVal.getType().isa()) { + return failure(); + } + + bool isNoalias = false; + if (auto mem = memVal.getDefiningOp()) { + if (auto defOp = memVal.getDefiningOp()) {//if (mem has allocation like) { + if (isa(defOp)) { + isNoalias = true; + } + } + } else if (auto ba = 
dyn_cast(memVal)) { + if (auto fn = dyn_cast(ba.getOwner()->getParentOp())) { + if (fn.getArgAttr(ba.getArgNumber(), LLVM::LLVMDialect::getNoAliasAttrName())) { + isNoalias = true; + } + } + } else if (memVal.getDefiningOp() || + memVal.getDefiningOp()) { + isNoalias = true; //TODO: is this correct? + } + + // if we are no alias we can just look at all users of the value + // if we are not noalias, or we are captured, then we have to look at all users that + // could read or write + if (!isNoalias) { //|| isCaptured(memVal)) { TODO: need to improve isCaptured to include linalg.generic + return failure(); //|| isCaptured(memVal)) { TODO: need to improve isCaptured to include linalg.generic + } + + MemRefType memrefType; + if (auto blockArg = memVal.dyn_cast()) { + memrefType = blockArg.getType().dyn_cast(); + } else if (auto allocaOp = memVal.getDefiningOp()) { + memrefType = allocaOp.getType(); + } else { + return failure(); + } + + + rewriter.setInsertionPointAfterValue(memVal); auto tensorType = RankedTensorType::get( - allocaOp.getType().getShape(), allocaOp.getType().getElementType()); + memrefType.getShape(), memrefType.getElementType()); - // Check to see if only linalg.generic are users of the allocaOp for now. + // Check to see if only linalg.generic are users of the Value op for now. // TODO: Extend this - if (!llvm::all_of(allocaOp->getUsers(), [](Operation *op) { + if (!llvm::all_of(memVal.getUsers(), [](Operation *op) { return isa(op); })) { - passResult = failure(); - return WalkResult::interrupt(); + return failure(); } // auto emptyTensor = // rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), // allocaOp.getType().getElementType()); auto toTensorOp = rewriter.create( - allocaOp.getLoc(), tensorType, allocaOp); + memVal.getLoc(), tensorType, memVal); Value currentTensor = toTensorOp; - auto sortedUsers = getSortedUsers(allocaOp); + auto sortedUsers = getSortedUsers(memVal); // Check if allocaOp is an output in current genericOp for (auto user : sortedUsers) { @@ -109,17 +203,17 @@ struct LinalgDebufferization : public OpRewritePattern { ArrayAttr indexingMaps = genericOp.getIndexingMaps(); for (auto input : genericOp.getInputs()) { - newInputs.push_back(input == allocaOp ? currentTensor : input); + newInputs.push_back(input == memVal ? currentTensor : input); } // ArrayRef resultTypes; int newCurrentTensorIndex = -1; int index = 0; for (auto output : genericOp.getOutputs()) { - newOutputs.push_back(output == allocaOp ? currentTensor : output); - resultTypes.push_back(output == allocaOp ? currentTensor.getType() + newOutputs.push_back(output == memVal ? currentTensor : output); + resultTypes.push_back(output == memVal ? 
currentTensor.getType() : output.getType()); - if (output == allocaOp) { + if (output == memVal) { newCurrentTensorIndex = index; } index++; @@ -138,34 +232,48 @@ struct LinalgDebufferization : public OpRewritePattern { newGenericOp.getRegion().end()); // Replace all uses of original generic op with the new one - // int idxOldGeneric=0; - // int idxNewGeneric=0; for (unsigned i = 0; i < genericOp->getNumResults(); ++i) { - // if(i == newCurrentTensorIndex) { - // idxNewGeneric++; - // } - // genericOp->getResult(idxOldGeneric).replaceAllUsesWith(newGenericOp->getResult(idxNewGeneric)); - // idxOldGeneric++; - // idxNewGeneric++; genericOp->getResult(i).replaceAllUsesWith( newGenericOp->getResult(i)); } // Delete the original genericOp - opsToDelete.push_back(genericOp.getOperation()); if (newCurrentTensorIndex != -1) currentTensor = newGenericOp.getResult(newCurrentTensorIndex); processedGenericOps.insert(genericOp.getOperation()); + // Delete the original genericOp + //genericOp.erase(); + //WalkResult::interrupt(); + opsToDelete.push_back(genericOp.getOperation()); } } auto toMemrefOp = rewriter.create( - allocaOp.getLoc(), allocaOp.getType(), currentTensor); - rewriter.create(allocaOp.getLoc(), toMemrefOp, allocaOp); + memVal.getLoc(), memrefType, currentTensor); + rewriter.create(memVal.getLoc(), toMemrefOp, memVal); // opsToDelete.push_back(allocaOp.getOperation()); - return WalkResult::advance(); - }); + return success(); + }; + + + bool changed; + do { + changed = funcOp.walk([&](memref::AllocaOp alloca) { + //if (handleMemref(alloca.getResult()).succeeded()) + // return WalkResult::advance(); + //return WalkResult::interrupt(); + handleMemref(alloca.getResult()).succeeded(); + return WalkResult::advance(); + }).wasInterrupted(); + + if (changed) + passResult = success(); + } while (changed); + + if (llvm::any_of(llvm::map_range(funcOp.getArguments(), handleMemref), [](LogicalResult res) {return res.succeeded();})) + + passResult = success(); for (Operation *op : opsToDelete) { op->erase(); } @@ -184,6 +292,7 @@ struct LinalgDebufferize : public LinalgDebufferizeBase { void LinalgDebufferize::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); + //patterns.insert(&getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config); diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index 96f278038f9f..a4961ba7d39e 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -23,21 +23,38 @@ module @in_place_add{ } } -module @conv_2 { - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - %0 = memref.alloca() : memref<515x67xi32> - %1 = memref.alloca() : memref<4x4xi32> - %2 = memref.alloca() : memref<512x64xi32> - linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { - ^bb0(%in: i32, %in_0: i32, %out: i32): - %3 = arith.muli %in, %in_0 : i32 - %4 = arith.addi %out, %3 : i32 - linalg.yield %4 : i32 - } - return %c0_i32 : i32 - } -} +// module @in_place_add2{ +// func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32) { +// %c0 = arith.constant 0 : index +// //%buffer = memref.alloca() : memref<128xf32> +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], 
+// iterator_types = ["parallel"] +// } ins(%buffer : memref<128xf32>) +// outs(%buffer : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// linalg.yield %sum : f32 +// } +// return +// } +// } + +// module @conv_2 { +// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.alloca() : memref<515x67xi32> +// %1 = memref.alloca() : memref<4x4xi32> +// %2 = memref.alloca() : memref<512x64xi32> +// linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { +// ^bb0(%in: i32, %in_0: i32, %out: i32): +// %3 = arith.muli %in, %in_0 : i32 +// %4 = arith.addi %out, %3 : i32 +// linalg.yield %4 : i32 +// } +// return %c0_i32 : i32 +// } +// } module @harris_score_with_gradient_extra_kernel { memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> From 490f924a64f624c6480e6263d2be5a24a81f0a8e Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 16 Jan 2025 20:00:41 -0800 Subject: [PATCH 39/77] Added debufferization Alloc Removal pass, add working examples with llvm.noalias inputs to func --- lib/polygeist/Passes/LinalgDebufferize.cpp | 17 +- test/polygeist-opt/debufferize.mlir | 199 ++++++++++----------- 2 files changed, 105 insertions(+), 111 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index d6255c790954..51fb17d75087 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -79,16 +79,16 @@ struct debufferizationAllocaRemoval : public OpRewritePattern bool userToTensorOp = false; bool userCopyOp = false; bool userOtherOp = false; - Value copyOp; - Value toTensorOp; + memref::CopyOp copyOp; + bufferization::ToTensorOp toTensorOp; for (Operation *user : allocaResult.getUsers()) { if (isa(user)) { userToTensorOp = true; - toTensorOp = user->getResult(0); + toTensorOp = cast(user); } else if (isa(user)) { userCopyOp = true; - copyOp = user->getResult(0); + copyOp = cast(user); } else userOtherOp = true; @@ -101,9 +101,10 @@ struct debufferizationAllocaRemoval : public OpRewritePattern rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), allocaOp.getType().getElementType()); - rewriter.replaceAllUsesWith(toTensorOp, emptyTensor.getResult()); - rewriter.eraseOp(copyOp.getDefiningOp()); - rewriter.eraseOp(toTensorOp.getDefiningOp()); + rewriter.replaceAllUsesWith(toTensorOp.getResult(), emptyTensor.getResult()); + + rewriter.eraseOp(copyOp); + rewriter.eraseOp(toTensorOp); return success(); } }; @@ -292,7 +293,7 @@ struct LinalgDebufferize : public LinalgDebufferizeBase { void LinalgDebufferize::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); - //patterns.insert(&getContext()); + patterns.insert(&getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config); diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index a4961ba7d39e..5490c75bd86f 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -6,109 +6,102 @@ #map19 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> #map22 = affine_map<(d0, d1) -> (d1, d0)> -module @in_place_add{ - func.func @in_place_add(%value: f32) { - %c0 = arith.constant 0 : index - %buffer = memref.alloca() : 
memref<128xf32> - linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"] - } ins(%buffer : memref<128xf32>) - outs(%buffer : memref<128xf32>) { - ^bb0(%in: f32, %out: f32): - %sum = arith.addf %in, %value : f32 - linalg.yield %sum : f32 - } - return - } -} + module @in_place_add{ + func.func @in_place_add(%value: f32) { + %c0 = arith.constant 0 : index + %buffer = memref.alloca() : memref<128xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } -// module @in_place_add2{ -// func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32) { -// %c0 = arith.constant 0 : index -// //%buffer = memref.alloca() : memref<128xf32> -// linalg.generic { -// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], -// iterator_types = ["parallel"] -// } ins(%buffer : memref<128xf32>) -// outs(%buffer : memref<128xf32>) { -// ^bb0(%in: f32, %out: f32): -// %sum = arith.addf %in, %value : f32 -// linalg.yield %sum : f32 -// } -// return -// } -// } + module @in_place_add2{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } -// module @conv_2 { -// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c0_i32 = arith.constant 0 : i32 -// %0 = memref.alloca() : memref<515x67xi32> -// %1 = memref.alloca() : memref<4x4xi32> -// %2 = memref.alloca() : memref<512x64xi32> -// linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { -// ^bb0(%in: i32, %in_0: i32, %out: i32): -// %3 = arith.muli %in, %in_0 : i32 -// %4 = arith.addi %out, %3 : i32 -// linalg.yield %4 : i32 -// } -// return %c0_i32 : i32 -// } -// } + module @conv_2 { + func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %3 = arith.muli %in, %in_0 : i32 + %4 = arith.addi %out, %3 : i32 + linalg.yield %4 : i32 + } + return %c0_i32 : i32 + } + } -module @harris_score_with_gradient_extra_kernel { - memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> - memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> - memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c4_i32 = arith.constant 4 : i32 - %c0_i32 = arith.constant 0 : i32 - 
%alloca = memref.alloca() : memref<512x512xi32> - %alloca_0 = memref.alloca() : memref<512x512xi32> - %alloca_1 = memref.alloca() : memref<512x512xi32> - %alloca_2 = memref.alloca() : memref<516x516xi32> - %alloca_3 = memref.alloca() : memref<516x516xi32> - %alloca_4 = memref.alloca() : memref<518x518xi32> - %score = memref.alloca() : memref<512x512xi32> - %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> - %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> - %2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> - // 2nd variant - // %0 = memref.alloca() : memref<3x3xi32> - // %1 = memref.alloca() : memref<3x3xi32> - // %2 = memref.alloca() : memref<5x5xi32> - linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { - ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): - %4 = arith.muli %in, %in_5 : i32 - %5 = arith.addi %out_7, %4 : i32 - %6 = arith.muli %in, %in_6 : i32 - %7 = arith.addi %out, %6 : i32 - linalg.yield %7, %5 : i32, i32 - } - linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { - ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): - %4 = arith.muli %in, %in : i32 - %5 = arith.muli %4, %in_6 : i32 - %6 = arith.addi %out_8, %5 : i32 - %7 = arith.muli %in_5, %in_5 : i32 - %8 = arith.muli %7, %in_6 : i32 - %9 = arith.addi %out_7, %8 : i32 - %10 = arith.muli %in, %in_5 : i32 - %11 = arith.muli %10, %in_6 : i32 - %12 = arith.addi %out, %11 : i32 - linalg.yield %12, %9, %6 : i32, i32, i32 - } - linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { - ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): - %4 = arith.muli %in, %in_5 : i32 - %5 = arith.muli %in_6, %in_6 : i32 - %6 = arith.subi %4, %5 : i32 - %7 = arith.addi %in, %in_5 : i32 - %8 = arith.muli %7, %c4_i32 : i32 - %9 = arith.muli %8, %7 : i32 - %10 = arith.subi %6, %9 : i32 - linalg.yield %10 : i32 - } - return %c0_i32 : i32 - } - } \ No newline at end of file + module @harris_score_with_gradient_extra_kernel { + //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + func.func @main(%0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : 
memref<518x518xi32> + //%score = memref.alloca() : memref<512x512xi32> + //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.addi %out_7, %4 : i32 + %6 = arith.muli %in, %in_6 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7, %5 : i32, i32 + } + linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): + %4 = arith.muli %in, %in : i32 + %5 = arith.muli %4, %in_6 : i32 + %6 = arith.addi %out_8, %5 : i32 + %7 = arith.muli %in_5, %in_5 : i32 + %8 = arith.muli %7, %in_6 : i32 + %9 = arith.addi %out_7, %8 : i32 + %10 = arith.muli %in, %in_5 : i32 + %11 = arith.muli %10, %in_6 : i32 + %12 = arith.addi %out, %11 : i32 + linalg.yield %12, %9, %6 : i32, i32, i32 + } + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.muli %in_6, %in_6 : i32 + %6 = arith.subi %4, %5 : i32 + %7 = arith.addi %in, %in_5 : i32 + %8 = arith.muli %7, %c4_i32 : i32 + %9 = arith.muli %8, %7 : i32 + %10 = arith.subi %6, %9 : i32 + linalg.yield %10 : i32 + } + return %c0_i32 : i32 + } + } \ No newline at end of file From e20708c5c04c4e0e7cbc3f8537610944212f1366 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 30 Jan 2025 20:33:32 -0800 Subject: [PATCH 40/77] Added support for debufferization across nested regions - working for scf.if --- lib/polygeist/Ops.cpp | 2 + lib/polygeist/Passes/LinalgDebufferize.cpp | 346 ++++++++++++++++++++- test/polygeist-opt/debufferize.mlir | 216 +++++++++---- 3 files changed, 485 insertions(+), 79 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 7ec7a9d2af3e..07f0cab20f0c 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -674,6 +674,8 @@ bool isCaptured(Value v, Operation *potentialUser = nullptr, for (auto u : v.getUsers()) { if (seenuse && u == potentialUser) *seenuse = true; + if (isa(u)) + continue; if (isa(u)) continue; diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 51fb17d75087..7c2a57405d8e 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -34,15 +34,214 @@ using namespace bufferization; bool isCaptured(Value v, Operation *potentialUser = nullptr, bool *seenuse = nullptr); +bool isAncestor(Operation *potentialAncestor, Operation *op) { + Operation *current = op->getParentOp(); + while (current != nullptr) { + if (current == potentialAncestor) + 
return true; + current = current->getParentOp(); + } + return false; +} + +//Checks if a comes before b +bool comesBefore(Operation *a, Operation *b) { + if (a == b) return false; + + if (isAncestor(a, b)) return true; + if (isAncestor(b, a)) return false; + + //Block *aBlock = a->getBlock(); + //Block *bBlock = b->getBlock(); + + //// Same block: compare operation order + //if (aBlock == bBlock) { + // for (Operation &op : aBlock->getOperations()) { + // if (&op == a) return true; + // if (&op == b) return false; + // } + // llvm_unreachable("Operations not found in their parent block"); + //} + + //// Different blocks: compare region hierarchy + //Region *aRegion = aBlock->getParent(); + //Region *bRegion = bBlock->getParent(); + + //// Same region: compare block order + //if (aRegion == bRegion) { + // //auto aBlockIt = std::find(aRegion->begin(), aRegion->end(), aBlock); + // //auto bBlockIt = std::find(aRegion->begin(), aRegion->end(), bBlock); + // //return aBlockIt < bBlockIt; + // //const int aIndex = std::distance(aRegion->begin(), aRegion->find(aBlock)); + // //const int bIndex = std::distance(aRegion->begin(), aRegion->find(bBlock)); + // //return aIndex < bIndex; + // auto get_block_pos = [](Region *region, Block *block) { + // auto &blocks = region->getBlocks(); + // auto it = llvm::find_if(blocks, [block](Block &b) { + // return &b == block; // Address comparison + // }); + // assert(it != blocks.end() && "Block not found in region"); + // return std::distance(blocks.begin(), it); + // //return std::distance(region->getBlocks().begin(), + // // llvm::find(region->getBlocks(), block)); + // }; + // return get_block_pos(aRegion, aBlock) < + // get_block_pos(aRegion, bBlock); + //} + + //// Different regions: compare parent operations + //Operation *aParent = aRegion->getParentOp(); + //Operation *bParent = bRegion->getParentOp(); + + //// Same parent op: compare region order + //if (aParent == bParent) { + // //auto aRegionIt = std::find(aParent->getRegions().begin(), + // // aParent->getRegions().end(), aRegion); + // //auto bRegionIt = std::find(bParent->getRegions().begin(), + // // bParent->getRegions().end(), bRegion); + // //return aRegionIt < bRegionIt; + // //auto get_region_position = [](Operation *parent, Region *target) { + // //return std::distance( + // // parent->getRegions.begin(), + // // llvm::find_if(parent->getRegions(), [&](Region &r) { + // // return &r == target; // Compare region addresses + // // }) + // // ); + // //}; + + // auto get_region_position = [](Operation *parent, Region *target) { + // auto regions = parent->getRegions(); // Get reference to region list + // auto begin = regions.begin(); + // auto it = llvm::find_if(regions, [&](Region &r) { + // return &r == target; + // }); + // return std::distance(begin, it); + // }; + // return get_region_position(aParent, aRegion) < + // get_region_position(aParent, bRegion); + //} + + Operation *aParent = a->getParentOp(); + Operation *bParent = b->getParentOp(); + // Walk up b's hierarchy until we reach a's level + Operation *bAncestor = b; + //We traverse B's ancestors here + while (Operation *parent = bAncestor->getParentOp()) { + if (parent == aParent) { + // Compare positions within aParent's regions/blocks + Region *aRegion = a->getParentRegion(); + Region *bRegion = bAncestor->getParentRegion(); + + if (aRegion == bRegion) { + // Same region: compare block order + Block *aBlock = a->getBlock(); + Block *bBlock = bAncestor->getBlock(); + if (aBlock != bBlock) { + auto get_block_pos = [](Region 
*region, Block *block) { + auto &blocks = region->getBlocks(); + auto it = llvm::find_if(blocks, [block](Block &b) { + return &b == block; // Address comparison + }); + assert(it != blocks.end() && "Block not found in region"); + return std::distance(blocks.begin(), it); + }; + return get_block_pos(aRegion, aBlock) < + get_block_pos(bRegion, bBlock); + }; + // Same block: compare operation order + return a->isBeforeInBlock(bAncestor); + } + + // Different regions: compare region order + auto compareRegions = [parent](Region *x, Region *y) { + auto get_region_position = [](Operation *parent, Region *target) { + auto regions = parent->getRegions(); // Get reference to region list + auto begin = regions.begin(); + auto it = llvm::find_if(regions, [&](Region &r) { + return &r == target; + }); + return std::distance(begin, it); + }; + return get_region_position(parent, x) < + get_region_position(parent, y); + }; + return compareRegions(aRegion, bRegion); + } + bAncestor = parent; + } + + Operation *aAncestor = a; + //We traverse A's ancestors here + while (Operation *parent = aAncestor->getParentOp()) { + if (parent == bParent) { + // Compare positions within aParent's regions/blocks + Region *bRegion = b->getParentRegion(); + Region *aRegion = aAncestor->getParentRegion(); + + if (aRegion == bRegion) { + // Same region: compare block order + Block *bBlock = b->getBlock(); + Block *aBlock = aAncestor->getBlock(); + if (aBlock != bBlock) { + auto get_block_pos = [](Region *region, Block *block) { + auto &blocks = region->getBlocks(); + auto it = llvm::find_if(blocks, [block](Block &b) { + return &b == block; // Address comparison + }); + assert(it != blocks.end() && "Block not found in region"); + return std::distance(blocks.begin(), it); + }; + return !(get_block_pos(bRegion, bBlock) < + get_block_pos(aRegion, aBlock)); + }; + // Same block: compare operation order + return !b->isBeforeInBlock(aAncestor); + } + + // Different regions: compare region order + auto compareRegions = [parent](Region *x, Region *y) { + auto get_region_position = [](Operation *parent, Region *target) { + auto regions = parent->getRegions(); // Get reference to region list + auto begin = regions.begin(); + auto it = llvm::find_if(regions, [&](Region &r) { + return &r == target; + }); + return std::distance(begin, it); + }; + return get_region_position(parent, x) < + get_region_position(parent, y); + }; + return !compareRegions(bRegion, aRegion); + } + aAncestor = parent; + } + + llvm_unreachable("Operations do not share a common ancestor"); + //// Recursive case: compare parent operations + //return comesBefore(aParent, bParent); +} + std::vector getSortedUsers(Value val) { std::vector users; for (Operation *user : val.getUsers()) { users.push_back(user); } + //TODO: problem is this only works for 1 level // Sort the users based on their topological order std::sort(users.begin(), users.end(), [](Operation *a, Operation *b) { - return a->isBeforeInBlock(b); + return comesBefore(a,b); + //if (a->getBlock() == b->getBlock()) { + // return a->isBeforeInBlock(b); + //} + //if (a->getParentRegion() == b->getParentRegion()) { + // Block *blockA = a->getBlock(); + // Block *blockB = b->getBlock(); + // return std::distance(blockA->getParent()->begin(), blockA->getIterator()) < + // std::distance(blockB->getParent()->begin(), blockB->getIterator()); + //} + + //return a->getParentRegion()->isAncestor(b->getParentRegion()); }); return users; @@ -70,6 +269,27 @@ std::vector getSortedUsers(Operation *op) { return sortedUsers; } 
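+// For illustration (hypothetical IR), given a buffer %buf with two users
+//   scf.if %c { linalg.generic ... ins(%buf) ... }   // user A, inside the if
+//   linalg.generic ... ins(%buf) ...                 // user B, after the if
+// getSortedUsers(%buf) should return {A, B}: comesBefore walks up from A to the
+// enclosing scf.if, which shares a block with B and precedes it there.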
+Region* findCommonAncestorRegion(Operation* a, Operation* b) {
+  DenseMap regionCounts;
+
+  // Walk up from operation A
+  Operation* currentOp = a;
+  while (Region* region = currentOp->getParentRegion()) {
+    regionCounts[region]++;
+    currentOp = region->getParentOp();
+  }
+
+  // Walk up from operation B to find common region
+  currentOp = b;
+  while (Region* region = currentOp->getParentRegion()) {
+    if (regionCounts.count(region))
+      return region;
+    currentOp = region->getParentOp();
+  }
+  return nullptr;
+}
+
+
 struct debufferizationAllocaRemoval : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
 
@@ -109,6 +329,20 @@ struct debufferizationAllocaRemoval : public OpRewritePattern
   }
 };
 
+// Problems with this implementation: the way it works is by jumping over the users
+// of the alloca/args. The users we get are not in sorted order, so we write a function to sort the users across
+// regions, blocks and ops as long as they lie in the same ancestry.
+// Now, as we update an op and use its output tensor as input to the next op, this works fine for simple cases with no regions.
+// But things become more complicated when we have nested regions, as in scf.if and scf.for ops.
+// Why? Because we need to update the scf.if and scf.for ops to yield the correct tensors to be used by the next user.
+// So how do we do it? The best way is to traverse the IR in a walk, and whenever we encounter a user that is a linalg.generic we update
+// its params to tensors and generate an output tensor if possible, then move to the next op and repeat until we reach the end of a region.
+// At this point we need to decide whether to yield the tensor or not. This depends on whether there is an external user of the original arg/alloca
+// still left over. I think this can be done by tracking the users of an op and eliminating the ones which have already been handled.
+// In the current implementation we go to the next user and check whether the previous user is in the same block; if not, we need to propagate the previous
+// user's output tensor through the intervening regions with yields.
+// How does this work if the user does not actually output data, i.e. it did not generate an output tensor? In that case the original tensor needs to be carried forward.
+// In the current flow we track the updated output tensor, and we can iteratively yield the value until it reaches the same block as the next user.
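+// A rough sketch of that walk (illustrative only; the shape below is
+// hypothetical and assumes getSortedUsers and the region propagation
+// described above):
+//
+//   Value currentTensor = toTensorOp;              // tensor view of memVal
+//   for (Operation *user : getSortedUsers(memVal)) {
+//     if (auto generic = dyn_cast<linalg::GenericOp>(user)) {
+//       // 1. if currentTensor is not visible in generic's region, extend the
+//       //    yields of the intervening scf.if/scf.for ops to carry it there
+//       // 2. rebuild the generic on tensors; its new result becomes
+//       //    currentTensor for the next user
+//     }
+//   }
+//   // 3. materialize currentTensor back into memVal for any remaining users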
struct LinalgDebufferization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -153,7 +387,7 @@ struct LinalgDebufferization : public OpRewritePattern { // if we are no alias we can just look at all users of the value // if we are not noalias, or we are captured, then we have to look at all users that // could read or write - if (!isNoalias) { //|| isCaptured(memVal)) { TODO: need to improve isCaptured to include linalg.generic + if ((!isNoalias) || isCaptured(memVal)) { //TODO: need to improve isCaptured to include linalg.generic return failure(); //|| isCaptured(memVal)) { TODO: need to improve isCaptured to include linalg.generic } @@ -185,6 +419,7 @@ struct LinalgDebufferization : public OpRewritePattern { auto toTensorOp = rewriter.create( memVal.getLoc(), tensorType, memVal); Value currentTensor = toTensorOp; + Value prevTensor = toTensorOp; auto sortedUsers = getSortedUsers(memVal); @@ -202,6 +437,82 @@ struct LinalgDebufferization : public OpRewritePattern { SmallVector resultTypes; // Create a new linalg.generic in Destination Style Passing format + //check_if_current_tensor_is_available_to_user_if_not_propagate_to_scope() { + // extract_common_ancestor of curentTensor and userOp. + // propagte currentTensor all the way to common ancestor. + // Make the propagated value the current tensor. + //} + auto commonRegion = findCommonAncestorRegion(currentTensor.getDefiningOp(), user); + if (!commonRegion) return failure(); + // Collect regions from source to common ancestor + SmallVector regions; + for (Region* r = currentTensor.getParentRegion(); r != commonRegion; + r = r->getParentOp()->getParentRegion()) { + regions.push_back(r); + } + + // Propagate value through each region + Value currentValue = currentTensor; + for (Region* region : llvm::reverse(regions)) { + Block& block = region->front(); + Operation* terminator = block.getTerminator(); + Operation *parentOp = region->getParentOp(); + + if( auto prevIf = dyn_cast_or_null(parentOp)) { + auto prevResults = prevIf.getResults(); + SmallVector newResultTypes; + for (auto res : prevResults) + newResultTypes.push_back(res.getType()); + newResultTypes.push_back(currentValue.getType()); + + // Yield original results + new value + auto thenYieldArgs = prevIf.thenYield().getOperands(); + SmallVector thenYieldValues; + for (const auto &it :thenYieldArgs) { + thenYieldValues.push_back(it); + } + thenYieldValues.push_back(currentValue); + + SmallVector elseYieldValues; + if(!prevIf.getElseRegion().empty()){ + auto elseYieldArgs = prevIf.elseYield().getOperands(); + for (const auto &it :elseYieldArgs) { + elseYieldValues.push_back(it); + } + } + elseYieldValues.push_back(prevTensor); + + //Create new Ifop + rewriter.setInsertionPoint(prevIf); + auto newIf = rewriter.create(prevIf.getLoc(), + newResultTypes, // Combined types + prevIf.getCondition(), // New condition value + true + ); + if (newIf.thenBlock()) + rewriter.eraseBlock(newIf.thenBlock()); + + newIf.getThenRegion().takeBody(prevIf.getThenRegion()); + if(!prevIf.getElseRegion().empty()) + newIf.getElseRegion().takeBody(prevIf.getElseRegion()); + + + //Update yield ops + rewriter.setInsertionPointToEnd(newIf.thenBlock()); + rewriter.replaceOpWithNewOp(newIf.thenYield(), thenYieldValues); + if(!prevIf.getElseRegion().empty()) { + rewriter.setInsertionPointToEnd(newIf.elseBlock()); + rewriter.replaceOpWithNewOp(newIf.elseYield(), elseYieldValues); + } else { + rewriter.setInsertionPointToEnd(newIf.elseBlock()); + rewriter.create(newIf.getLoc(), elseYieldValues); 
+ } + + currentValue = newIf->getResult(newIf->getNumResults() - 1); + } + } + currentTensor = currentValue; + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); for (auto input : genericOp.getInputs()) { newInputs.push_back(input == memVal ? currentTensor : input); @@ -220,6 +531,7 @@ struct LinalgDebufferization : public OpRewritePattern { index++; } + rewriter.setInsertionPointAfter(genericOp); StringAttr empty = StringAttr::get(genericOp.getContext()); ArrayRef resultTypesRef(resultTypes); auto newGenericOp = rewriter.create( @@ -239,14 +551,16 @@ struct LinalgDebufferization : public OpRewritePattern { } // Delete the original genericOp - if (newCurrentTensorIndex != -1) + if (newCurrentTensorIndex != -1){ + prevTensor = currentTensor; currentTensor = newGenericOp.getResult(newCurrentTensorIndex); + } processedGenericOps.insert(genericOp.getOperation()); // Delete the original genericOp - //genericOp.erase(); + genericOp.erase(); //WalkResult::interrupt(); - opsToDelete.push_back(genericOp.getOperation()); + //opsToDelete.push_back(genericOp.getOperation()); } } @@ -259,19 +573,17 @@ struct LinalgDebufferization : public OpRewritePattern { bool changed; - do { - changed = funcOp.walk([&](memref::AllocaOp alloca) { - //if (handleMemref(alloca.getResult()).succeeded()) - // return WalkResult::advance(); - //return WalkResult::interrupt(); - handleMemref(alloca.getResult()).succeeded(); - return WalkResult::advance(); - }).wasInterrupted(); - - if (changed) - passResult = success(); - } while (changed); + //Fix instead of walk, just get the list of allocaOp users, so that you can easily delete ops inside + SmallVector listOfAllocaOps; + + funcOp.walk([&](memref::AllocaOp alloca) { + listOfAllocaOps.push_back(alloca); + }); + for (auto alloca : listOfAllocaOps) { + handleMemref(alloca); + } + if (llvm::any_of(llvm::map_range(funcOp.getArguments(), handleMemref), [](LogicalResult res) {return res.succeeded();})) passResult = success(); diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index 5490c75bd86f..34e203b9dbb6 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -40,68 +40,160 @@ } } - module @conv_2 { - func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { - ^bb0(%in: i32, %in_0: i32, %out: i32): - %3 = arith.muli %in, %in_0 : i32 - %4 = arith.addi %out, %3 : i32 - linalg.yield %4 : i32 - } - return %c0_i32 : i32 - } + module @in_place_cond_add{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + return + } } - module @harris_score_with_gradient_extra_kernel { - //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> - //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, 
-3], [0, 0, 0], [3, 10, 3]]> - //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> - func.func @main(%0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c4_i32 = arith.constant 4 : i32 - %c0_i32 = arith.constant 0 : i32 - %alloca = memref.alloca() : memref<512x512xi32> - %alloca_0 = memref.alloca() : memref<512x512xi32> - %alloca_1 = memref.alloca() : memref<512x512xi32> - %alloca_2 = memref.alloca() : memref<516x516xi32> - %alloca_3 = memref.alloca() : memref<516x516xi32> - %alloca_4 = memref.alloca() : memref<518x518xi32> - //%score = memref.alloca() : memref<512x512xi32> - //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> - //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> - //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> - linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { - ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): - %4 = arith.muli %in, %in_5 : i32 - %5 = arith.addi %out_7, %4 : i32 - %6 = arith.muli %in, %in_6 : i32 - %7 = arith.addi %out, %6 : i32 - linalg.yield %7, %5 : i32, i32 - } - linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { - ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): - %4 = arith.muli %in, %in : i32 - %5 = arith.muli %4, %in_6 : i32 - %6 = arith.addi %out_8, %5 : i32 - %7 = arith.muli %in_5, %in_5 : i32 - %8 = arith.muli %7, %in_6 : i32 - %9 = arith.addi %out_7, %8 : i32 - %10 = arith.muli %in, %in_5 : i32 - %11 = arith.muli %10, %in_6 : i32 - %12 = arith.addi %out, %11 : i32 - linalg.yield %12, %9, %6 : i32, i32, i32 - } - linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { - ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): - %4 = arith.muli %in, %in_5 : i32 - %5 = arith.muli %in_6, %in_6 : i32 - %6 = arith.subi %4, %5 : i32 - %7 = arith.addi %in, %in_5 : i32 - %8 = arith.muli %7, %c4_i32 : i32 - %9 = arith.muli %8, %7 : i32 - %10 = arith.subi %6, %9 : i32 - linalg.yield %10 : i32 + module @in_place_add_for{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.for %i = %c0 to %c10 step %c1 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + return + } + } + + //Case when buffer is captured + module @in_place_add_for_loop_carried{ + 
func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + //%buffer = memref.alloca() : memref<128xf32> + %result = scf.for %i = %c0 to %c10 step %c1 iter_args(%buf = %buffer) -> (memref<128xf32>) { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf : memref<128xf32>) + outs(%buf : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + scf.yield %buf : memref<128xf32> + } + return + } + } + + module @in_place_cond_add_followed_by_add{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return } - return %c0_i32 : i32 - } - } \ No newline at end of file + } + +// module @conv_2 { +// func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c0_i32 = arith.constant 0 : i32 +// linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { +// ^bb0(%in: i32, %in_0: i32, %out: i32): +// %3 = arith.muli %in, %in_0 : i32 +// %4 = arith.addi %out, %3 : i32 +// linalg.yield %4 : i32 +// } +// return %c0_i32 : i32 +// } +// } + +// module @harris_score_with_gradient_extra_kernel { +// //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> +// //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> +// //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> +// func.func @main(%0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c4_i32 = arith.constant 4 : i32 +// %c0_i32 = arith.constant 0 : i32 +// %alloca = memref.alloca() : memref<512x512xi32> +// %alloca_0 = memref.alloca() : memref<512x512xi32> +// %alloca_1 = memref.alloca() : memref<512x512xi32> +// %alloca_2 = memref.alloca() : memref<516x516xi32> +// %alloca_3 = memref.alloca() : memref<516x516xi32> +// %alloca_4 = memref.alloca() : memref<518x518xi32> +// //%score = memref.alloca() : memref<512x512xi32> +// //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> +// //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> +// //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> +// linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types 
= ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { +// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): +// %4 = arith.muli %in, %in_5 : i32 +// %5 = arith.addi %out_7, %4 : i32 +// %6 = arith.muli %in, %in_6 : i32 +// %7 = arith.addi %out, %6 : i32 +// linalg.yield %7, %5 : i32, i32 +// } +// linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { +// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): +// %4 = arith.muli %in, %in : i32 +// %5 = arith.muli %4, %in_6 : i32 +// %6 = arith.addi %out_8, %5 : i32 +// %7 = arith.muli %in_5, %in_5 : i32 +// %8 = arith.muli %7, %in_6 : i32 +// %9 = arith.addi %out_7, %8 : i32 +// %10 = arith.muli %in, %in_5 : i32 +// %11 = arith.muli %10, %in_6 : i32 +// %12 = arith.addi %out, %11 : i32 +// linalg.yield %12, %9, %6 : i32, i32, i32 +// } +// linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { +// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): +// %4 = arith.muli %in, %in_5 : i32 +// %5 = arith.muli %in_6, %in_6 : i32 +// %6 = arith.subi %4, %5 : i32 +// %7 = arith.addi %in, %in_5 : i32 +// %8 = arith.muli %7, %c4_i32 : i32 +// %9 = arith.muli %8, %7 : i32 +// %10 = arith.subi %6, %9 : i32 +// linalg.yield %10 : i32 +// } +// return %c0_i32 : i32 +// } +// } \ No newline at end of file From 4a7efe78d132f0b8ed49b8a30201b86728ea174e Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 31 Jan 2025 13:47:19 -0800 Subject: [PATCH 41/77] Bug fix for erasing the op correctly --- lib/polygeist/Passes/LinalgDebufferize.cpp | 162 ++++++--------------- test/polygeist-opt/debufferize.mlir | 127 ++++++++-------- 2 files changed, 109 insertions(+), 180 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 7c2a57405d8e..0590f710db3d 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -51,76 +51,6 @@ bool comesBefore(Operation *a, Operation *b) { if (isAncestor(a, b)) return true; if (isAncestor(b, a)) return false; - //Block *aBlock = a->getBlock(); - //Block *bBlock = b->getBlock(); - - //// Same block: compare operation order - //if (aBlock == bBlock) { - // for (Operation &op : aBlock->getOperations()) { - // if (&op == a) return true; - // if (&op == b) return false; - // } - // llvm_unreachable("Operations not found in their parent block"); - //} - - //// Different blocks: compare region hierarchy - //Region *aRegion = aBlock->getParent(); - //Region *bRegion = bBlock->getParent(); - - //// Same region: compare block order - //if (aRegion == bRegion) { - // //auto aBlockIt = std::find(aRegion->begin(), aRegion->end(), aBlock); - // //auto bBlockIt = std::find(aRegion->begin(), aRegion->end(), bBlock); - // //return aBlockIt < bBlockIt; - // //const int aIndex = std::distance(aRegion->begin(), aRegion->find(aBlock)); - // //const int bIndex = 
std::distance(aRegion->begin(), aRegion->find(bBlock)); - // //return aIndex < bIndex; - // auto get_block_pos = [](Region *region, Block *block) { - // auto &blocks = region->getBlocks(); - // auto it = llvm::find_if(blocks, [block](Block &b) { - // return &b == block; // Address comparison - // }); - // assert(it != blocks.end() && "Block not found in region"); - // return std::distance(blocks.begin(), it); - // //return std::distance(region->getBlocks().begin(), - // // llvm::find(region->getBlocks(), block)); - // }; - // return get_block_pos(aRegion, aBlock) < - // get_block_pos(aRegion, bBlock); - //} - - //// Different regions: compare parent operations - //Operation *aParent = aRegion->getParentOp(); - //Operation *bParent = bRegion->getParentOp(); - - //// Same parent op: compare region order - //if (aParent == bParent) { - // //auto aRegionIt = std::find(aParent->getRegions().begin(), - // // aParent->getRegions().end(), aRegion); - // //auto bRegionIt = std::find(bParent->getRegions().begin(), - // // bParent->getRegions().end(), bRegion); - // //return aRegionIt < bRegionIt; - // //auto get_region_position = [](Operation *parent, Region *target) { - // //return std::distance( - // // parent->getRegions.begin(), - // // llvm::find_if(parent->getRegions(), [&](Region &r) { - // // return &r == target; // Compare region addresses - // // }) - // // ); - // //}; - - // auto get_region_position = [](Operation *parent, Region *target) { - // auto regions = parent->getRegions(); // Get reference to region list - // auto begin = regions.begin(); - // auto it = llvm::find_if(regions, [&](Region &r) { - // return &r == target; - // }); - // return std::distance(begin, it); - // }; - // return get_region_position(aParent, aRegion) < - // get_region_position(aParent, bRegion); - //} - Operation *aParent = a->getParentOp(); Operation *bParent = b->getParentOp(); // Walk up b's hierarchy until we reach a's level @@ -224,50 +154,42 @@ bool comesBefore(Operation *a, Operation *b) { std::vector getSortedUsers(Value val) { std::vector users; for (Operation *user : val.getUsers()) { - users.push_back(user); + auto it = std::find_if(users.begin(), users.end(), + [user](const Operation* op) { + return op == user; + }); + if(it == users.end()) + users.push_back(user); } - //TODO: problem is this only works for 1 level - // Sort the users based on their topological order std::sort(users.begin(), users.end(), [](Operation *a, Operation *b) { return comesBefore(a,b); - //if (a->getBlock() == b->getBlock()) { - // return a->isBeforeInBlock(b); - //} - //if (a->getParentRegion() == b->getParentRegion()) { - // Block *blockA = a->getBlock(); - // Block *blockB = b->getBlock(); - // return std::distance(blockA->getParent()->begin(), blockA->getIterator()) < - // std::distance(blockB->getParent()->begin(), blockB->getIterator()); - //} - - //return a->getParentRegion()->isAncestor(b->getParentRegion()); }); return users; } -std::vector getSortedUsers(Operation *op) { - // Find the parent function - auto funcOp = op->getParentOfType(); - if (!funcOp) - return {}; +// std::vector getSortedUsers(Operation *op) { +// // Find the parent function +// auto funcOp = op->getParentOfType(); +// if (!funcOp) +// return {}; - // Map to store order of operations - llvm::DenseMap opOrder; - size_t order = 0; +// // Map to store order of operations +// llvm::DenseMap opOrder; +// size_t order = 0; - funcOp.walk([&](Operation *curOp) { opOrder[curOp] = order++; }); +// funcOp.walk([&](Operation *curOp) { opOrder[curOp] = 
order++; }); - std::vector sortedUsers(op->getUsers().begin(), - op->getUsers().end()); +// std::vector sortedUsers(op->getUsers().begin(), +// op->getUsers().end()); - std::sort( - sortedUsers.begin(), sortedUsers.end(), - [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; }); +// std::sort( +// sortedUsers.begin(), sortedUsers.end(), +// [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; }); - return sortedUsers; -} +// return sortedUsers; +// } Region* findCommonAncestorRegion(Operation* a, Operation* b) { DenseMap regionCounts; @@ -351,15 +273,15 @@ struct LinalgDebufferization : public OpRewritePattern { auto module = funcOp->getParentOfType(); - SmallVector opsToDelete; - llvm::SmallPtrSet opsToDeleteSet; + //SmallVector opsToDelete; + //llvm::SmallPtrSet opsToDeleteSet; // Tracks both old linalg.generics and linalg.generics with repeated values // in ins and outs - llvm::SmallPtrSet processedGenericOps; LogicalResult passResult = failure(); auto handleMemref = [&](Value memVal) -> LogicalResult { + llvm::SmallPtrSet processedGenericOps; auto module = memVal.getParentRegion()->getParentOfType(); if (!memVal.getType().isa()) { @@ -428,8 +350,8 @@ struct LinalgDebufferization : public OpRewritePattern { if (auto genericOp = dyn_cast(user)) { // auto genericOp = cast(user); - if (processedGenericOps.count(genericOp) > 0) - continue; + //if (processedGenericOps.count(genericOp) > 0) + // continue; rewriter.setInsertionPointAfter(genericOp); SmallVector newInputs; @@ -556,17 +478,22 @@ struct LinalgDebufferization : public OpRewritePattern { currentTensor = newGenericOp.getResult(newCurrentTensorIndex); } - processedGenericOps.insert(genericOp.getOperation()); + //processedGenericOps.insert(genericOp.getOperation()); // Delete the original genericOp - genericOp.erase(); + //unsigned numUsers = std::distance(genericOp.getResults().getUsers().begin(), genericOp.getResults().getUsers().end()); + //llvm::outs() << "Number of generic op uses: " << numUsers << "\n"; + //genericOp.erase(); + rewriter.eraseOp(genericOp); //WalkResult::interrupt(); //opsToDelete.push_back(genericOp.getOperation()); } } - - auto toMemrefOp = rewriter.create( - memVal.getLoc(), memrefType, currentTensor); - rewriter.create(memVal.getLoc(), toMemrefOp, memVal); + + //if(currentTensor != prevTensor) { + auto toMemrefOp = rewriter.create( + memVal.getLoc(), memrefType, currentTensor); + rewriter.create(memVal.getLoc(), toMemrefOp, memVal); + //} // opsToDelete.push_back(allocaOp.getOperation()); return success(); }; @@ -584,13 +511,15 @@ struct LinalgDebufferization : public OpRewritePattern { handleMemref(alloca); } - if (llvm::any_of(llvm::map_range(funcOp.getArguments(), handleMemref), [](LogicalResult res) {return res.succeeded();})) + for(auto arg: funcOp.getArguments()){ + handleMemref(arg); + } passResult = success(); - for (Operation *op : opsToDelete) { - op->erase(); - } - opsToDelete.clear(); + //for (Operation *op : opsToDelete) { + // op->erase(); + //} + //opsToDelete.clear(); return passResult; } @@ -603,6 +532,7 @@ struct LinalgDebufferize : public LinalgDebufferizeBase { } // namespace void LinalgDebufferize::runOnOperation() { + auto module = getOperation()->getParentOfType(); RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index 34e203b9dbb6..bd28c13d7c51 100644 --- a/test/polygeist-opt/debufferize.mlir +++ 
b/test/polygeist-opt/debufferize.mlir @@ -132,68 +132,67 @@ } } -// module @conv_2 { -// func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c0_i32 = arith.constant 0 : i32 -// linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { -// ^bb0(%in: i32, %in_0: i32, %out: i32): -// %3 = arith.muli %in, %in_0 : i32 -// %4 = arith.addi %out, %3 : i32 -// linalg.yield %4 : i32 -// } -// return %c0_i32 : i32 -// } -// } + module @conv_2 { + func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %3 = arith.muli %in, %in_0 : i32 + %4 = arith.addi %out, %3 : i32 + linalg.yield %4 : i32 + } + return %c0_i32 : i32 + } + } -// module @harris_score_with_gradient_extra_kernel { -// //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> -// //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> -// //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> -// func.func @main(%0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c4_i32 = arith.constant 4 : i32 -// %c0_i32 = arith.constant 0 : i32 -// %alloca = memref.alloca() : memref<512x512xi32> -// %alloca_0 = memref.alloca() : memref<512x512xi32> -// %alloca_1 = memref.alloca() : memref<512x512xi32> -// %alloca_2 = memref.alloca() : memref<516x516xi32> -// %alloca_3 = memref.alloca() : memref<516x516xi32> -// %alloca_4 = memref.alloca() : memref<518x518xi32> -// //%score = memref.alloca() : memref<512x512xi32> -// //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> -// //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> -// //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> -// linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): -// %4 = arith.muli %in, %in_5 : i32 -// %5 = arith.addi %out_7, %4 : i32 -// %6 = arith.muli %in, %in_6 : i32 -// %7 = arith.addi %out, %6 : i32 -// linalg.yield %7, %5 : i32, i32 -// } -// linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): -// %4 = arith.muli %in, %in : i32 -// %5 = 
arith.muli %4, %in_6 : i32 -// %6 = arith.addi %out_8, %5 : i32 -// %7 = arith.muli %in_5, %in_5 : i32 -// %8 = arith.muli %7, %in_6 : i32 -// %9 = arith.addi %out_7, %8 : i32 -// %10 = arith.muli %in, %in_5 : i32 -// %11 = arith.muli %10, %in_6 : i32 -// %12 = arith.addi %out, %11 : i32 -// linalg.yield %12, %9, %6 : i32, i32, i32 -// } -// linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): -// %4 = arith.muli %in, %in_5 : i32 -// %5 = arith.muli %in_6, %in_6 : i32 -// %6 = arith.subi %4, %5 : i32 -// %7 = arith.addi %in, %in_5 : i32 -// %8 = arith.muli %7, %c4_i32 : i32 -// %9 = arith.muli %8, %7 : i32 -// %10 = arith.subi %6, %9 : i32 -// linalg.yield %10 : i32 -// } -// return %c0_i32 : i32 -// } -// } \ No newline at end of file + module @harris_score_with_gradient_extra_kernel { + //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + func.func @main(%input: memref<518x518xi32>, %0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + //%score = memref.alloca() : memref<512x512xi32> + //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.addi %out_7, %4 : i32 + %6 = arith.muli %in, %in_6 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7, %5 : i32, i32 + } + linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): + %4 = arith.muli %in, %in : i32 + %5 = arith.muli %4, %in_6 : i32 + %6 = arith.addi %out_8, %5 : i32 + %7 = arith.muli %in_5, %in_5 : i32 + %8 = arith.muli %7, %in_6 : i32 + %9 = arith.addi %out_7, %8 : i32 + %10 = arith.muli %in, %in_5 : i32 + %11 = arith.muli %10, %in_6 : i32 + %12 = arith.addi %out, %11 : i32 + linalg.yield %12, %9, %6 : i32, i32, i32 + } + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = 
["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.muli %in_6, %in_6 : i32 + %6 = arith.subi %4, %5 : i32 + %7 = arith.addi %in, %in_5 : i32 + %8 = arith.muli %7, %c4_i32 : i32 + %9 = arith.muli %8, %7 : i32 + %10 = arith.subi %6, %9 : i32 + linalg.yield %10 : i32 + } + return %c0_i32 : i32 + } + } \ No newline at end of file From 6d8832f150825c38b02ba39a1d0c4be65ea11d1a Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 31 Jan 2025 14:53:20 -0800 Subject: [PATCH 42/77] Bug fixes for 1. recursive parent search in sorting users 2. traversing regions to propagate values in correct order --- lib/polygeist/Passes/LinalgDebufferize.cpp | 6 +- test/polygeist-opt/debufferize.mlir | 75 ++++++++++++++++++++++ 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 0590f710db3d..64358256d7df 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -146,9 +146,9 @@ bool comesBefore(Operation *a, Operation *b) { aAncestor = parent; } - llvm_unreachable("Operations do not share a common ancestor"); + //llvm_unreachable("Operations do not share a common ancestor"); //// Recursive case: compare parent operations - //return comesBefore(aParent, bParent); + return comesBefore(aParent, bParent); } std::vector getSortedUsers(Value val) { @@ -375,7 +375,7 @@ struct LinalgDebufferization : public OpRewritePattern { // Propagate value through each region Value currentValue = currentTensor; - for (Region* region : llvm::reverse(regions)) { + for (Region* region : regions) { Block& block = region->front(); Operation* terminator = block.getTerminator(); Operation *parentOp = region->getParentOp(); diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index bd28c13d7c51..183e81d98489 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -132,6 +132,81 @@ } } + module @in_place_cond_add_followed_by_add2{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1, %cond2: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond2 { + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + + module @in_place_cond_add_followed_by_add3{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1, %cond2: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond2 { + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + 
outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + } + scf.if %cond2 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + %sum3 = arith.addf %sum2, %value : f32 + linalg.yield %sum3 : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + module @conv_2 { func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { %c0_i32 = arith.constant 0 : i32 From 6ca2aebb6dd1e5bd16491fb827460a4a099c6f9c Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 31 Jan 2025 15:58:01 -0800 Subject: [PATCH 43/77] Added cases of buffer capture which doesn't debufferize --- lib/polygeist/Passes/LinalgDebufferize.cpp | 7 ++- test/polygeist-opt/debufferize.mlir | 63 ++++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 64358256d7df..5e7b6e1c1a98 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -309,8 +309,8 @@ struct LinalgDebufferization : public OpRewritePattern { // if we are no alias we can just look at all users of the value // if we are not noalias, or we are captured, then we have to look at all users that // could read or write - if ((!isNoalias) || isCaptured(memVal)) { //TODO: need to improve isCaptured to include linalg.generic - return failure(); //|| isCaptured(memVal)) { TODO: need to improve isCaptured to include linalg.generic + if ((!isNoalias) || isCaptured(memVal)) { + return failure(); } MemRefType memrefType; @@ -432,6 +432,9 @@ struct LinalgDebufferization : public OpRewritePattern { currentValue = newIf->getResult(newIf->getNumResults() - 1); } + // else if( auto prevFor = dyn_cast_or_null(parentOp)) { + + // } } currentTensor = currentValue; diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index 183e81d98489..ffe31157e7c0 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -80,6 +80,7 @@ } } + //TODO: not debufferized //Case when buffer is captured module @in_place_add_for_loop_carried{ func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { @@ -103,6 +104,68 @@ } } + //TODO: not debufferized + module @in_place_add_for_loop_carried2{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %buffer2 = memref.alloca() : memref<128xf32> + %result:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%buf = %buffer, %buf2 = %buffer2) -> (memref<128xf32>, memref<128xf32>) { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf : 
memref<128xf32>) + outs(%buf2 : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf2 : memref<128xf32>) + outs(%buf : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + scf.yield %buf, %buf2 : memref<128xf32>, memref<128xf32> + } + return + } + } + + module @cross_buffer_add{ + func.func @in_place_add(%buf: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %buf2 = memref.alloca() : memref<128xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf : memref<128xf32>) + outs(%buf2 : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf2 : memref<128xf32>) + outs(%buf : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + return + } + } + module @in_place_cond_add_followed_by_add{ func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { %c0 = arith.constant 0 : index From 803ec30c8b53d58996cb25882668bc8d9e43f713 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 31 Jan 2025 16:14:29 -0800 Subject: [PATCH 44/77] Canonicalization gets rid of memref capture by loop --- lib/polygeist/Passes/LinalgDebufferize.cpp | 3 --- test/polygeist-opt/debufferize.mlir | 4 +--- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 5e7b6e1c1a98..2d412df21328 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -432,9 +432,6 @@ struct LinalgDebufferization : public OpRewritePattern { currentValue = newIf->getResult(newIf->getNumResults() - 1); } - // else if( auto prevFor = dyn_cast_or_null(parentOp)) { - - // } } currentTensor = currentValue; diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index ffe31157e7c0..4d582dced9e8 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -1,4 +1,4 @@ -//polygeist-opt --linalg-debufferize debufferize.mlir +//polygeist-opt --canonicalize --linalg-debufferize --canonicalize debufferize.mlir #map16 = affine_map<(d0, d1, d2) -> (d2, d1)> #map17 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)> @@ -80,7 +80,6 @@ } } - //TODO: not debufferized //Case when buffer is captured module @in_place_add_for_loop_carried{ func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { @@ -104,7 +103,6 @@ } } - //TODO: not debufferized module @in_place_add_for_loop_carried2{ func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { %c0 = arith.constant 0 : index From fb0ac185fdcb198fa380c4e9df9a67887f2ee5de Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 6 Feb 2025 21:27:06 -0800 Subject: [PATCH 45/77] Working implementation for scf.for op and scf.if op; added 
bug fix to propagate values to the top region before bufferization.to_memref --- lib/polygeist/Passes/LinalgDebufferize.cpp | 261 ++++++++++++++++----- 1 file changed, 198 insertions(+), 63 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 2d412df21328..0616b85963fd 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -250,6 +250,181 @@ struct debufferizationAllocaRemoval : public OpRewritePattern return success(); } }; + +void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVector regions, PatternRewriter &rewriter) { + auto module = currentValue.getDefiningOp()->getParentOfType(); + for (Region* region : regions) { + Block& block = region->front(); + Operation* terminator = block.getTerminator(); + Operation *parentOp = region->getParentOp(); + + if( auto prevIf = dyn_cast_or_null(parentOp)) { + auto prevResults = prevIf.getResults(); + SmallVector newResultTypes; + for (auto res : prevResults) + newResultTypes.push_back(res.getType()); + newResultTypes.push_back(currentValue.getType()); + + // Yield original results + new value + auto thenYieldArgs = prevIf.thenYield().getOperands(); + SmallVector thenYieldValues; + for (const auto &it :thenYieldArgs) { + thenYieldValues.push_back(it); + } + thenYieldValues.push_back(currentValue); + + SmallVector elseYieldValues; + if(!prevIf.getElseRegion().empty()){ + auto elseYieldArgs = prevIf.elseYield().getOperands(); + for (const auto &it :elseYieldArgs) { + elseYieldValues.push_back(it); + } + } + elseYieldValues.push_back(prevTensor);//TODO: Need to replace this with earliest use of op in the + // given region, prevTensor doesn't work - since this won't work for a chain of connected ops. + + //Create new Ifop + rewriter.setInsertionPoint(prevIf); + auto newIf = rewriter.create(prevIf.getLoc(), + newResultTypes, // Combined types + prevIf.getCondition(), // New condition value + true + ); + if (newIf.thenBlock()) + rewriter.eraseBlock(newIf.thenBlock()); + + newIf.getThenRegion().takeBody(prevIf.getThenRegion()); + if(!prevIf.getElseRegion().empty()) + newIf.getElseRegion().takeBody(prevIf.getElseRegion()); + + + //Update yield ops + rewriter.setInsertionPointToEnd(newIf.thenBlock()); + rewriter.replaceOpWithNewOp(newIf.thenYield(), thenYieldValues); + if(!prevIf.getElseRegion().empty()) { + rewriter.setInsertionPointToEnd(newIf.elseBlock()); + rewriter.replaceOpWithNewOp(newIf.elseYield(), elseYieldValues); + } else { + rewriter.setInsertionPointToEnd(newIf.elseBlock()); + rewriter.create(newIf.getLoc(), elseYieldValues); + } + + //TODO: need to update results of prevIf and else with the new ones + currentValue = newIf->getResult(newIf->getNumResults() - 1); + } + else if (auto prevFor = dyn_cast_or_null(parentOp)) { + SmallVector newInitOperands = prevFor.getInitArgs(); + newInitOperands.push_back(prevTensor); //Needs to be the earliest use inside the region. + //TODO: Does this require fix in if as well? 
+ + SmallVector newResultTypes(prevFor.getResultTypes().begin(), prevFor.getResultTypes().end()); + newResultTypes.push_back(currentValue.getType()); + + rewriter.setInsertionPoint(prevFor); + scf::ForOp newLoop = rewriter.create( + prevFor.getLoc(), + prevFor.getLowerBound(), + prevFor.getUpperBound(), + prevFor.getStep(), + newInitOperands + ); + newLoop->setAttrs(prevFor.getOperation()->getAttrs()); + + // Create block with induction variable + original args + new arg + SmallVector blockArgTypes; + blockArgTypes.push_back(newLoop.getInductionVar().getType()); // IV + llvm::append_range(blockArgTypes, newLoop.getResultTypes()); // Original args + //blockArgTypes.push_back(prevTensor.getType()); // New arg + + //Block *newBlock = rewriter.createBlock( + // &newLoop.getRegion(), + // newLoop.getRegion().end(), + // blockArgTypes, + // {newLoop.getLoc(), newLoop.getLoc()} // Locations + //); + + //rewriter.inlineRegionBefore( + // prevFor.getRegion(), + // newLoop.getRegion(), + // newLoop.getRegion().end() + //); + + // Transfer operations from original block to new block + Block *newBlock = &newLoop.getRegion().front(); + Block *originalBlock = &prevFor.getRegion().front(); + newBlock->getOperations().splice( + newBlock->end(), + originalBlock->getOperations() + ); + + // Replace uses of original block arguments with new ones + for (unsigned i = 0; i < originalBlock->getNumArguments()-1; ++i) { + originalBlock->getArgument(i + 1) // +1 for IV + .replaceAllUsesWith(newBlock->getArgument(i + 1)); + } + + auto yieldOp = cast(newBlock->getTerminator()); + SmallVector newYieldValues = yieldOp.getOperands(); + // Add new iteration arg from block arguments + newYieldValues.push_back(currentValue); + + //OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(yieldOp); + + rewriter.replaceOpWithNewOp(yieldOp, newYieldValues); + // rewriter.replaceOp(prevFor, newLoop.getResults()); + //Update block args + //newLoop.getBody()->getArguments().front().replaceAllUsesWith(newLoop.getInductionVar()); + // IRMapping mapper; + // mapper.map(prevFor.getInductionVar(), newLoop.getInductionVar()); + // rewriter.setInsertionPointToStart(newLoop.getBody()); + // for (auto [oldArg, newArg] : llvm::zip(prevFor.getRegionIterArgs(), + // newLoop.getRegionIterArgs().drop_back())) { + // mapper.map(oldArg, newArg); + // } + // //for (unsigned i = 0, e = prevFor.getNumRegionIterArgs(); i < e; ++i) + // // newLoop.getBody()->getArguments()[i + 1].replaceAllUsesWith(newLoop.getRegionIterArg(i)); + + + // //rewriter.inlineRegionBefore(prevFor.getRegion(), newLoop.getRegion(), newLoop.getRegion().end()); + // for (auto &op : prevFor.getBody()->without_terminator()) { + // rewriter.clone(op, mapper); + // } + + // //Update use of new iter arg + // Value newIterArg = newLoop.getRegionIterArgs().back(); + + // auto origYield = cast(prevFor.getBody()->getTerminator()); + // SmallVector newYieldOperands; + // for (Value operand : origYield.getOperands()) { + // newYieldOperands.push_back(mapper.lookupOrDefault(operand)); + // } + // // Add value for new iteration argument + // newYieldOperands.push_back(currentValue); + + // rewriter.setInsertionPointToEnd(newLoop.getBody()); + // rewriter.create(origYield.getLoc(), newYieldOperands); + + // for (auto [oldResult, newResult] : + // llvm::zip(prevFor.getResults(), newLoop.getResults().drop_back())) { + // rewriter.replaceAllUsesWith(oldResult, newResult); + // } + // //auto yieldOp = cast(newLoop.getBody()->getTerminator()); + // //OpBuilder::InsertionGuard 
g(rewriter); + // //rewriter.setInsertionPoint(yieldOp); + + // //SmallVector newYieldedValues = yieldOp.getResults(); + // //newYieldedValues.push_back(currentValue); + + // //rewriter.replaceOpWithNewOp(yieldOp, newYieldedValues); + // rewriter.replaceOp(prevFor, newLoop.getResults()); + rewriter.eraseOp(prevFor); + //Update the current value + currentValue = newLoop.getResults().back(); + } + } +} + // Problems with this implementation: The way this implementation works is by jumping over users // of alloca/args. The users we get are not in sorted order. We write a function to sort out the users across @@ -374,67 +549,9 @@ struct LinalgDebufferization : public OpRewritePattern { } // Propagate value through each region + //TODO: Need this in function form so we can call this after the loop as well Value currentValue = currentTensor; - for (Region* region : regions) { - Block& block = region->front(); - Operation* terminator = block.getTerminator(); - Operation *parentOp = region->getParentOp(); - - if( auto prevIf = dyn_cast_or_null(parentOp)) { - auto prevResults = prevIf.getResults(); - SmallVector newResultTypes; - for (auto res : prevResults) - newResultTypes.push_back(res.getType()); - newResultTypes.push_back(currentValue.getType()); - - // Yield original results + new value - auto thenYieldArgs = prevIf.thenYield().getOperands(); - SmallVector thenYieldValues; - for (const auto &it :thenYieldArgs) { - thenYieldValues.push_back(it); - } - thenYieldValues.push_back(currentValue); - - SmallVector elseYieldValues; - if(!prevIf.getElseRegion().empty()){ - auto elseYieldArgs = prevIf.elseYield().getOperands(); - for (const auto &it :elseYieldArgs) { - elseYieldValues.push_back(it); - } - } - elseYieldValues.push_back(prevTensor); - - //Create new Ifop - rewriter.setInsertionPoint(prevIf); - auto newIf = rewriter.create(prevIf.getLoc(), - newResultTypes, // Combined types - prevIf.getCondition(), // New condition value - true - ); - if (newIf.thenBlock()) - rewriter.eraseBlock(newIf.thenBlock()); - - newIf.getThenRegion().takeBody(prevIf.getThenRegion()); - if(!prevIf.getElseRegion().empty()) - newIf.getElseRegion().takeBody(prevIf.getElseRegion()); - - - //Update yield ops - rewriter.setInsertionPointToEnd(newIf.thenBlock()); - rewriter.replaceOpWithNewOp(newIf.thenYield(), thenYieldValues); - if(!prevIf.getElseRegion().empty()) { - rewriter.setInsertionPointToEnd(newIf.elseBlock()); - rewriter.replaceOpWithNewOp(newIf.elseYield(), elseYieldValues); - } else { - rewriter.setInsertionPointToEnd(newIf.elseBlock()); - rewriter.create(newIf.getLoc(), elseYieldValues); - } - - currentValue = newIf->getResult(newIf->getNumResults() - 1); - } - } - currentTensor = currentValue; - + propagateValueThroughRegion(currentTensor, prevTensor, regions, rewriter); ArrayAttr indexingMaps = genericOp.getIndexingMaps(); for (auto input : genericOp.getInputs()) { newInputs.push_back(input == memVal ? 
currentTensor : input); @@ -489,10 +606,28 @@ struct LinalgDebufferization : public OpRewritePattern { } } + //For adding yields for the last use all the way to the outer most region + auto commonRegion = findCommonAncestorRegion(currentTensor.getDefiningOp(), toTensorOp); + if (!commonRegion) return failure(); + // Collect regions from source to common ancestor + SmallVector regions; + for (Region* r = currentTensor.getParentRegion(); r != commonRegion; + r = r->getParentOp()->getParentRegion()) { + regions.push_back(r); + } + + propagateValueThroughRegion(currentTensor, prevTensor, regions, rewriter); + + //if(!regions.empty()) { + // auto lastRegion = regions.back(); + // Operation *parentOp = lastRegion->getParentOp(); + // rewriter.setInsertionPointAfter(parentOp); + //} //if(currentTensor != prevTensor) { - auto toMemrefOp = rewriter.create( - memVal.getLoc(), memrefType, currentTensor); - rewriter.create(memVal.getLoc(), toMemrefOp, memVal); + rewriter.setInsertionPointAfter(currentTensor.getDefiningOp()); + auto toMemrefOp = rewriter.create( + memVal.getLoc(), memrefType, currentTensor); + rewriter.create(memVal.getLoc(), toMemrefOp, memVal); //} // opsToDelete.push_back(allocaOp.getOperation()); return success(); From 0472c34348327997919358f57348ca14de29a18b Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 7 Feb 2025 14:28:29 -0800 Subject: [PATCH 46/77] Added data structures to track expandedUsers that can include for loops and ifs (ops that have regions), this helps in recursive update of region when Linalg generic transformed inside the loop- working for a single loop case --- lib/polygeist/Passes/LinalgDebufferize.cpp | 141 ++++++++++----------- 1 file changed, 67 insertions(+), 74 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 0616b85963fd..89b9f617ba74 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -30,6 +30,7 @@ using namespace linalg; using namespace tensor; using namespace bufferization; +using opTuple = std::tuple; //First: result, Second: prev_tensor ? 
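//Presumably std::tuple<mlir::Value, mlir::Value>: the tensor result produced for an
//already-rewritten user, paired with the tensor that user consumed. opResultMap uses
//this to recover, for any processed user op, both the value it now yields and the
//value it started from.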
bool isCaptured(Value v, Operation *potentialUser = nullptr, bool *seenuse = nullptr); @@ -154,6 +155,7 @@ bool comesBefore(Operation *a, Operation *b) { std::vector getSortedUsers(Value val) { std::vector users; for (Operation *user : val.getUsers()) { + //This logic is to prevent duplication of users auto it = std::find_if(users.begin(), users.end(), [user](const Operation* op) { return op == user; @@ -251,13 +253,16 @@ struct debufferizationAllocaRemoval : public OpRewritePattern } }; -void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVector regions, PatternRewriter &rewriter) { +void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVector regions, std::vector expandedUserList, llvm::DenseMap opResultMap, PatternRewriter &rewriter) { auto module = currentValue.getDefiningOp()->getParentOfType(); for (Region* region : regions) { Block& block = region->front(); Operation* terminator = block.getTerminator(); Operation *parentOp = region->getParentOp(); + //Find prevTensor + //Compare use Values with + if( auto prevIf = dyn_cast_or_null(parentOp)) { auto prevResults = prevIf.getResults(); SmallVector newResultTypes; @@ -310,11 +315,31 @@ void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVe } //TODO: need to update results of prevIf and else with the new ones + opResultMap[newIf] = std::make_tuple(newIf->getResult(newIf->getNumResults() - 1), currentValue); currentValue = newIf->getResult(newIf->getNumResults() - 1); + } else if (auto prevFor = dyn_cast_or_null(parentOp)) { + mlir::Value initTensor; + int insertIdx = 0; + int opOperandIdx = 0; + mlir::Operation * earliestUser; + for(auto user: expandedUserList) { + mlir::Region *opRegion = user->getParentRegion(); + if(region->isAncestor(opRegion)) { + //Maintain a map data structure for tracking every user and if they have been processed then the corresponding result + auto it = opResultMap.find(user); + earliestUser = user; + auto keys_value = it->second; + auto op_result = std::get<0>(keys_value); + initTensor = std::get<1>(keys_value); + break; + } + insertIdx++; + } + SmallVector newInitOperands = prevFor.getInitArgs(); - newInitOperands.push_back(prevTensor); //Needs to be the earliest use inside the region. + newInitOperands.push_back(initTensor); //Needs to be the earliest use inside the region. //TODO: Does this require fix in if as well? 
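      //In outline (as implemented just below): rebuild prevFor as a new scf.for whose
      //init operands are the originals plus initTensor, splice the old body into it,
      //append currentValue to the yield, redirect the earliest in-region user of
      //initTensor to the new iter_arg, then erase prevFor and continue with the new
      //loop's last result as the running tensor.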
SmallVector newResultTypes(prevFor.getResultTypes().begin(), prevFor.getResultTypes().end()); @@ -334,20 +359,6 @@ void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVe SmallVector blockArgTypes; blockArgTypes.push_back(newLoop.getInductionVar().getType()); // IV llvm::append_range(blockArgTypes, newLoop.getResultTypes()); // Original args - //blockArgTypes.push_back(prevTensor.getType()); // New arg - - //Block *newBlock = rewriter.createBlock( - // &newLoop.getRegion(), - // newLoop.getRegion().end(), - // blockArgTypes, - // {newLoop.getLoc(), newLoop.getLoc()} // Locations - //); - - //rewriter.inlineRegionBefore( - // prevFor.getRegion(), - // newLoop.getRegion(), - // newLoop.getRegion().end() - //); // Transfer operations from original block to new block Block *newBlock = &newLoop.getRegion().front(); @@ -368,63 +379,32 @@ void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVe // Add new iteration arg from block arguments newYieldValues.push_back(currentValue); - //OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp(yieldOp, newYieldValues); - // rewriter.replaceOp(prevFor, newLoop.getResults()); - //Update block args - //newLoop.getBody()->getArguments().front().replaceAllUsesWith(newLoop.getInductionVar()); - // IRMapping mapper; - // mapper.map(prevFor.getInductionVar(), newLoop.getInductionVar()); - // rewriter.setInsertionPointToStart(newLoop.getBody()); - // for (auto [oldArg, newArg] : llvm::zip(prevFor.getRegionIterArgs(), - // newLoop.getRegionIterArgs().drop_back())) { - // mapper.map(oldArg, newArg); - // } - // //for (unsigned i = 0, e = prevFor.getNumRegionIterArgs(); i < e; ++i) - // // newLoop.getBody()->getArguments()[i + 1].replaceAllUsesWith(newLoop.getRegionIterArg(i)); - - - // //rewriter.inlineRegionBefore(prevFor.getRegion(), newLoop.getRegion(), newLoop.getRegion().end()); - // for (auto &op : prevFor.getBody()->without_terminator()) { - // rewriter.clone(op, mapper); - // } - - // //Update use of new iter arg - // Value newIterArg = newLoop.getRegionIterArgs().back(); - - // auto origYield = cast(prevFor.getBody()->getTerminator()); - // SmallVector newYieldOperands; - // for (Value operand : origYield.getOperands()) { - // newYieldOperands.push_back(mapper.lookupOrDefault(operand)); - // } - // // Add value for new iteration argument - // newYieldOperands.push_back(currentValue); - - // rewriter.setInsertionPointToEnd(newLoop.getBody()); - // rewriter.create(origYield.getLoc(), newYieldOperands); - // for (auto [oldResult, newResult] : - // llvm::zip(prevFor.getResults(), newLoop.getResults().drop_back())) { - // rewriter.replaceAllUsesWith(oldResult, newResult); - // } - // //auto yieldOp = cast(newLoop.getBody()->getTerminator()); - // //OpBuilder::InsertionGuard g(rewriter); - // //rewriter.setInsertionPoint(yieldOp); - - // //SmallVector newYieldedValues = yieldOp.getResults(); - // //newYieldedValues.push_back(currentValue); - - // //rewriter.replaceOpWithNewOp(yieldOp, newYieldedValues); - // rewriter.replaceOp(prevFor, newLoop.getResults()); + //Update prevTensor to use iter_arg + OpOperand &operand = earliestUser->getOpOperand(opOperandIdx); + Value newValue = newLoop.getRegionIterArg(newLoop.getRegion().front().getNumArguments()-2); //-1 for IV + operand.set(newValue); + rewriter.eraseOp(prevFor); - //Update the current value currentValue = newLoop.getResults().back(); + + //Store this in the user list for this region, need to create a data structure for 
users + opResultMap[newLoop] = std::make_tuple(currentValue, initTensor); + //Update the user list with the for Loop + expandedUserList.insert(expandedUserList.begin() + insertIdx, newLoop); } } } +bool isDirectUser(Operation *consumer, Operation *producer) { + for (Value operand : consumer->getOperands()) { + if (operand.getDefiningOp() == producer) + return true; + } + return false; +} // Problems with this implementation: The way this implementation works is by jumping over users // of alloca/args. The users we get are not in sorted order. We write a function to sort out the users across @@ -520,7 +500,22 @@ struct LinalgDebufferization : public OpRewritePattern { auto sortedUsers = getSortedUsers(memVal); + //Other algorithm: + // 1. Walk over all ops + // 2. If you find a directUser - function defined then do the things for sortedUsers + // 3. If you encounter region based ops, like scf.for op and scf.if op, then track the + // op to be used for yield in scf.if + // For scf.for track the the op to be used for init, as well as the op to be updated by init. + // Op to be used by yield comes at the end. + // Problem walk.break will break things and won't be able to track recursive stuff - so would have to restart every time! + + //Variables to track results and init value with an operation that has been changed to tensor from memref + llvm::DenseMap opResultMap; + + // Check if allocaOp is an output in current genericOp + std::vector expandedUserList(sortedUsers); + int userIdx = 0; for (auto user : sortedUsers) { if (auto genericOp = dyn_cast(user)) { @@ -550,8 +545,8 @@ struct LinalgDebufferization : public OpRewritePattern { // Propagate value through each region //TODO: Need this in function form so we can call this after the loop as well - Value currentValue = currentTensor; - propagateValueThroughRegion(currentTensor, prevTensor, regions, rewriter); + propagateValueThroughRegion(currentTensor, prevTensor, regions, expandedUserList, opResultMap, rewriter); + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); for (auto input : genericOp.getInputs()) { newInputs.push_back(input == memVal ? 
currentTensor : input); @@ -592,17 +587,15 @@ struct LinalgDebufferization : public OpRewritePattern { // Delete the original genericOp if (newCurrentTensorIndex != -1){ prevTensor = currentTensor; + opResultMap[newGenericOp] = std::make_tuple(newGenericOp.getResult(newCurrentTensorIndex), currentTensor); currentTensor = newGenericOp.getResult(newCurrentTensorIndex); } - //processedGenericOps.insert(genericOp.getOperation()); - // Delete the original genericOp - //unsigned numUsers = std::distance(genericOp.getResults().getUsers().begin(), genericOp.getResults().getUsers().end()); - //llvm::outs() << "Number of generic op uses: " << numUsers << "\n"; - //genericOp.erase(); rewriter.eraseOp(genericOp); - //WalkResult::interrupt(); - //opsToDelete.push_back(genericOp.getOperation()); + //Updated expanded user list, as this op is deleted + expandedUserList.insert(expandedUserList.begin() + userIdx, newGenericOp); + userIdx++; + expandedUserList.erase(expandedUserList.begin() + userIdx); } } @@ -616,7 +609,7 @@ struct LinalgDebufferization : public OpRewritePattern { regions.push_back(r); } - propagateValueThroughRegion(currentTensor, prevTensor, regions, rewriter); + propagateValueThroughRegion(currentTensor, prevTensor, regions, expandedUserList, opResultMap, rewriter); //if(!regions.empty()) { // auto lastRegion = regions.back(); From 3272f2c408b8ee8e0dd641b48e638b5371919301 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 7 Feb 2025 16:29:49 -0800 Subject: [PATCH 47/77] Added logic in for loop case to find all users of iter_args and update them --- lib/polygeist/Passes/LinalgDebufferize.cpp | 55 ++++++++++++++++++---- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 89b9f617ba74..b576023a38ef 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -252,7 +252,29 @@ struct debufferizationAllocaRemoval : public OpRewritePattern return success(); } }; - + +void findUsersInRegion( + mlir::Value value, + mlir::Region& region, + llvm::SmallVectorImpl& users +) { + for (mlir::Block& block : region) { + for (mlir::Operation& op : block) { + for (mlir::Value operand : op.getOperands()) { + if (operand == value) { + users.push_back(&op); + break; // No need to check other operands for this op + } + } + + // Recursively check all sub-regions of this operation + for (mlir::Region& subRegion : op.getRegions()) { + findUsersInRegion(value, subRegion, users); + } + } + } +} + void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVector regions, std::vector expandedUserList, llvm::DenseMap opResultMap, PatternRewriter &rewriter) { auto module = currentValue.getDefiningOp()->getParentOfType(); for (Region* region : regions) { @@ -322,21 +344,27 @@ void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVe else if (auto prevFor = dyn_cast_or_null(parentOp)) { mlir::Value initTensor; int insertIdx = 0; - int opOperandIdx = 0; - mlir::Operation * earliestUser; + + //Find init Tensor for the given for loop, i.e first match to expanded user list for(auto user: expandedUserList) { mlir::Region *opRegion = user->getParentRegion(); if(region->isAncestor(opRegion)) { //Maintain a map data structure for tracking every user and if they have been processed then the corresponding result auto it = opResultMap.find(user); - earliestUser = user; + if(it == opResultMap.end()) + continue; auto keys_value = it->second; auto 
op_result = std::get<0>(keys_value); initTensor = std::get<1>(keys_value); break; } + //TODO: Fix this- need to be only updated until we get first region ancestor match insertIdx++; - } + } + + //After first match, now find all the users of the init Tensor in a region. + llvm::SmallVector initOpUsers; + findUsersInRegion(initTensor, *region, initOpUsers); SmallVector newInitOperands = prevFor.getInitArgs(); newInitOperands.push_back(initTensor); //Needs to be the earliest use inside the region. @@ -383,10 +411,21 @@ void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVe rewriter.replaceOpWithNewOp(yieldOp, newYieldValues); //Update prevTensor to use iter_arg - OpOperand &operand = earliestUser->getOpOperand(opOperandIdx); - Value newValue = newLoop.getRegionIterArg(newLoop.getRegion().front().getNumArguments()-2); //-1 for IV - operand.set(newValue); + for(auto initOpUser: initOpUsers) { + // Iterate over all operands (both inputs and outputs) + for (const auto &en : llvm::enumerate(initOpUser->getOperands())) { + if (en.value() == initTensor) { + OpOperand &operand = initOpUser->getOpOperand(en.index()); + Value newValue = newLoop.getRegionIterArg(newLoop.getRegion().front().getNumArguments()-2); //-1 for IV + operand.set(newValue); + } + } + } + //Update users of prev For loops results + for (auto [oldResult, newResult] : llvm::zip(prevFor.getResults(), newLoop.getResults().drop_back())) { + oldResult.replaceAllUsesWith(newResult); + } rewriter.eraseOp(prevFor); currentValue = newLoop.getResults().back(); From da2ae5b89f6c6799429ba93cbab11dbaa5b32fe2 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 7 Feb 2025 17:02:41 -0800 Subject: [PATCH 48/77] Added a bunch of tests with nested regions- all getting connected and debufferized by the debufferization pass --- lib/polygeist/Passes/LinalgDebufferize.cpp | 55 ++++---- test/polygeist-opt/debufferize.mlir | 149 ++++++++++++++++----- 2 files changed, 144 insertions(+), 60 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index b576023a38ef..ce4154a6e6ae 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -275,14 +275,34 @@ void findUsersInRegion( } } -void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVector regions, std::vector expandedUserList, llvm::DenseMap opResultMap, PatternRewriter &rewriter) { +void propagateValueThroughRegion(Value ¤tValue, SmallVector regions, std::vector expandedUserList, llvm::DenseMap opResultMap, PatternRewriter &rewriter) { auto module = currentValue.getDefiningOp()->getParentOfType(); for (Region* region : regions) { Block& block = region->front(); Operation* terminator = block.getTerminator(); Operation *parentOp = region->getParentOp(); - //Find prevTensor + //Find init Tensor for the given for loop, i.e first match to expanded user list + mlir::Value initTensor; + int insertIdx = 0; + bool insertIdxFound = false; + for(auto user: expandedUserList) { + mlir::Region *opRegion = user->getParentRegion(); + if(region->isAncestor(opRegion)) { + insertIdxFound = true; + //Maintain a map data structure for tracking every user and if they have been processed then the corresponding result + auto it = opResultMap.find(user); + if(it == opResultMap.end()) + continue; + auto keys_value = it->second; + auto op_result = std::get<0>(keys_value); + initTensor = std::get<1>(keys_value); + break; + } + if(!insertIdxFound) + insertIdx++; + } + //Compare use 
Values with if( auto prevIf = dyn_cast_or_null(parentOp)) { @@ -307,8 +327,7 @@ void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVe elseYieldValues.push_back(it); } } - elseYieldValues.push_back(prevTensor);//TODO: Need to replace this with earliest use of op in the - // given region, prevTensor doesn't work - since this won't work for a chain of connected ops. + elseYieldValues.push_back(initTensor); //Create new Ifop rewriter.setInsertionPoint(prevIf); @@ -342,25 +361,6 @@ void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVe } else if (auto prevFor = dyn_cast_or_null(parentOp)) { - mlir::Value initTensor; - int insertIdx = 0; - - //Find init Tensor for the given for loop, i.e first match to expanded user list - for(auto user: expandedUserList) { - mlir::Region *opRegion = user->getParentRegion(); - if(region->isAncestor(opRegion)) { - //Maintain a map data structure for tracking every user and if they have been processed then the corresponding result - auto it = opResultMap.find(user); - if(it == opResultMap.end()) - continue; - auto keys_value = it->second; - auto op_result = std::get<0>(keys_value); - initTensor = std::get<1>(keys_value); - break; - } - //TODO: Fix this- need to be only updated until we get first region ancestor match - insertIdx++; - } //After first match, now find all the users of the init Tensor in a region. llvm::SmallVector initOpUsers; @@ -410,7 +410,7 @@ void propagateValueThroughRegion(Value ¤tValue, Value &prevTensor, SmallVe rewriter.setInsertionPoint(yieldOp); rewriter.replaceOpWithNewOp(yieldOp, newYieldValues); - //Update prevTensor to use iter_arg + //Update users of initOp to use iterArgs for(auto initOpUser: initOpUsers) { // Iterate over all operands (both inputs and outputs) for (const auto &en : llvm::enumerate(initOpUser->getOperands())) { @@ -535,7 +535,6 @@ struct LinalgDebufferization : public OpRewritePattern { auto toTensorOp = rewriter.create( memVal.getLoc(), tensorType, memVal); Value currentTensor = toTensorOp; - Value prevTensor = toTensorOp; auto sortedUsers = getSortedUsers(memVal); @@ -583,8 +582,7 @@ struct LinalgDebufferization : public OpRewritePattern { } // Propagate value through each region - //TODO: Need this in function form so we can call this after the loop as well - propagateValueThroughRegion(currentTensor, prevTensor, regions, expandedUserList, opResultMap, rewriter); + propagateValueThroughRegion(currentTensor, regions, expandedUserList, opResultMap, rewriter); ArrayAttr indexingMaps = genericOp.getIndexingMaps(); for (auto input : genericOp.getInputs()) { @@ -625,7 +623,6 @@ struct LinalgDebufferization : public OpRewritePattern { // Delete the original genericOp if (newCurrentTensorIndex != -1){ - prevTensor = currentTensor; opResultMap[newGenericOp] = std::make_tuple(newGenericOp.getResult(newCurrentTensorIndex), currentTensor); currentTensor = newGenericOp.getResult(newCurrentTensorIndex); } @@ -648,7 +645,7 @@ struct LinalgDebufferization : public OpRewritePattern { regions.push_back(r); } - propagateValueThroughRegion(currentTensor, prevTensor, regions, expandedUserList, opResultMap, rewriter); + propagateValueThroughRegion(currentTensor, regions, expandedUserList, opResultMap, rewriter); //if(!regions.empty()) { // auto lastRegion = regions.back(); diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index 4d582dced9e8..cee3f8dd82fc 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -102,8 
+102,36 @@ return } } - - module @in_place_add_for_loop_carried2{ + module @cross_buffer_add{ + func.func @in_place_add(%buf: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %buf2 = memref.alloca() : memref<128xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf : memref<128xf32>) + outs(%buf2 : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf2 : memref<128xf32>) + outs(%buf : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + return + } + } + + module @in_place_add_for_loop_carried_cross_buffer{ func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -134,35 +162,71 @@ return } } - - module @cross_buffer_add{ - func.func @in_place_add(%buf: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index - %buf2 = memref.alloca() : memref<128xf32> - linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"] - } ins(%buf : memref<128xf32>) - outs(%buf2 : memref<128xf32>) { - ^bb0(%in: f32, %out: f32): - %sum = arith.addf %in, %value : f32 - linalg.yield %sum : f32 - } - linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"] - } ins(%buf2 : memref<128xf32>) - outs(%buf : memref<128xf32>) { - ^bb0(%in: f32, %out: f32): - %sum = arith.addf %in, %value : f32 - %sum2 = arith.addf %sum, %value : f32 - linalg.yield %sum2 : f32 - } - return - } - } + +// //TODO: Doesn't bufferize --affine loop carried iter_args doesn't canonicalizes (missing pattern?) 
+// module @in_place_add_for_loop_carried3{ +// func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %c10 = arith.constant 10 : index +// %buffer2 = memref.alloca() : memref<128xf32> +// %result:2 = affine.for %i = %c0 to %c10 iter_args(%buf = %buffer, %buf2 = %buffer2) -> (memref<128xf32>, memref<128xf32>) { +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], +// iterator_types = ["parallel"] +// } ins(%buf : memref<128xf32>) +// outs(%buf2 : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// linalg.yield %sum : f32 +// } +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], +// iterator_types = ["parallel"] +// } ins(%buf2 : memref<128xf32>) +// outs(%buf : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// %sum2 = arith.addf %sum, %value : f32 +// linalg.yield %sum2 : f32 +// } +// affine.yield %buf, %buf2 : memref<128xf32>, memref<128xf32> +// } +// return +// } +// } + +// module @in_place_add_for_loop_affine{ +// func.func @in_place_add(%buf: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %c10 = arith.constant 10 : index +// %buf2 = memref.alloca() : memref<128xf32> +// affine.for %i = %c0 to %c10 { +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], +// iterator_types = ["parallel"] +// } ins(%buf : memref<128xf32>) +// outs(%buf2 : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// linalg.yield %sum : f32 +// } +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], +// iterator_types = ["parallel"] +// } ins(%buf2 : memref<128xf32>) +// outs(%buf : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// %sum2 = arith.addf %sum, %value : f32 +// linalg.yield %sum2 : f32 +// } +// } +// return +// } +// } + module @in_place_cond_add_followed_by_add{ func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { @@ -198,6 +262,17 @@ %c0 = arith.constant 0 : index //%buffer = memref.alloca() : memref<128xf32> scf.if %cond2 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + %sum3 = arith.addf %sum2, %value : f32 + linalg.yield %sum3 : f32 + } scf.if %cond { linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], @@ -211,6 +286,18 @@ } } } + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] From a570c1bf63029460f335d6f6274030db06683ab9 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 7 Feb 2025 17:10:37 -0800 Subject: [PATCH 49/77] Added more 
complex region cases with mix of if-else statements --- test/polygeist-opt/debufferize.mlir | 77 ++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index cee3f8dd82fc..65a5a9ef0adf 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -418,4 +418,79 @@ } return %c0_i32 : i32 } - } \ No newline at end of file + } + + module @for_loop_within_for_loop{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.for %i = %c0 to %c10 step %c1 { + scf.for %j = %c0 to %c10 step %c1 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + return + } + } + + module @for_loop_with_if_with_for{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.for %i = %c0 to %c10 step %c1 { + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + scf.for %j = %c0 to %c10 step %c1 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + } + return + } + } From 7ee707b333dacd0d8e681f4c6b96db64643dfd16 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 8 May 2025 05:51:30 -0700 Subject: [PATCH 50/77] Generic solver to represent linalg.generic as kernel.def ops --- generic_solver/CublasDefnPattern.cpp | 260 +++++++++++++++++++++++++++ generic_solver/CublasOps.td | 85 +++++++++ generic_solver/cublas_example.mlir | 238 ++++++++++++++++++++++++ 3 files changed, 583 insertions(+) create mode 100644 generic_solver/CublasDefnPattern.cpp create mode 100644 generic_solver/CublasOps.td create mode 100644 generic_solver/cublas_example.mlir diff --git a/generic_solver/CublasDefnPattern.cpp b/generic_solver/CublasDefnPattern.cpp new file mode 100644 index 000000000000..c9e583affd4b --- /dev/null +++ b/generic_solver/CublasDefnPattern.cpp @@ -0,0 +1,260 @@ +//===- KernelDefnPattern.cpp - Pattern 
to match linalg.generic with kernel.defn ------===// +// +// This file implements a pattern to rewrite linalg.generic operations to kernel +// operations by matching against patterns defined in kernel.defn_collection. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "KernelOps.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { + +// Helper function to check if two regions are structurally equivalent +bool areRegionsEquivalent(Region &first, Region &second) { + // Compare number of blocks + if (first.getBlocks().size() != second.getBlocks().size()) + return false; + + // Compare corresponding blocks + for (auto blockPair : llvm::zip(first.getBlocks(), second.getBlocks())) { + Block &firstBlock = std::get<0>(blockPair); + Block &secondBlock = std::get<1>(blockPair); + + // Compare number of arguments + if (firstBlock.getNumArguments() != secondBlock.getNumArguments()) + return false; + + // Compare argument types + for (auto argPair : llvm::zip(firstBlock.getArguments(), + secondBlock.getArguments())) { + if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) + return false; + } + + // Compare operations (simplified - real implementation would be more complex) + if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) + return false; + + // For a full implementation, you'd need more sophisticated operation comparison + // based on operands, attributes, and result types + } + + return true; +} + +// Helper to check if indexing maps are equivalent +bool areIndexingMapsEquivalent(ArrayAttr firstMaps, ArrayAttr secondMaps) { + if (firstMaps.size() != secondMaps.size()) + return false; + + for (auto mapPair : llvm::zip(firstMaps, secondMaps)) { + auto firstMap = std::get<0>(mapPair).cast().getValue(); + auto secondMap = std::get<1>(mapPair).cast().getValue(); + + if (firstMap != secondMap) + return false; + } + + return true; +} + +// Helper to check if iterator types are equivalent +bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) { + if (firstTypes.size() != secondTypes.size()) + return false; + + for (auto typePair : llvm::zip(firstTypes, secondTypes)) { + auto firstType = std::get<0>(typePair).cast().getValue(); + auto secondType = std::get<1>(typePair).cast().getValue(); + + if (firstType != secondType) + return false; + } + + return true; +} + +// Check if a linalg.generic operation matches a kernel.defn in a collection +FailureOr matchGenericWithDefn( + GenericOp genericOp, + kernel::DefnCollectionOp collectionOp) { + + // Get attributes from the generic operation + ArrayAttr indexingMaps = genericOp.getIndexingMapsAttr(); + ArrayAttr iteratorTypes = genericOp.getIteratorTypesAttr(); + unsigned numInputs = genericOp.getNumDpsInputs(); + unsigned numOutputs = genericOp.getNumDpsInits(); + + // Walk through each defn in the collection + for (Operation &op : collectionOp.getDefns()) { + auto defnOp = cast(op); + StringAttr opName = defnOp.getNameAttr(); + + // Check for linalg.generic in the defn's body + bool foundMatch = false; + defnOp.getBody().walk([&](GenericOp candidateOp) { + // Skip if already found a match + if (foundMatch) + return; + + // Check if this linalg.generic matches our target + if (candidateOp.getNumDpsInputs() == numInputs && + candidateOp.getNumDpsInits() == numOutputs && + //TODO: 
Generalize to a single dialect, with no special ops + //TODO: Indexing maps and orders might differ + //TODO: More complex case- where extra loops exists around the ops we have + //TODO: Custom cost model ? + //TODO: Constants might require special handling such as bounds + //IDEA: Descheduling / removing tiles + areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) && + areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) && + areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) { + foundMatch = true; + } + }); + + if (foundMatch) + return opName.str(); + } + + return failure(); +} + +// Rewrite pattern to convert linalg.generic to kernel ops +class LinalgGenericToKernelPattern : public OpRewritePattern { +public: + LinalgGenericToKernelPattern(MLIRContext *context, + kernel::DefnCollectionOp collectionOp) + : OpRewritePattern(context), collectionOp(collectionOp) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + // Try to match with a defn in the collection + auto matchResult = matchGenericWithDefn(genericOp, collectionOp); + if (failed(matchResult)) + return failure(); + + std::string opName = *matchResult; + + // Create the appropriate kernel operation based on the matched pattern + if (opName == "Kernel_gemm") { + // Get inputs and outputs + Value outputTensor = genericOp.getDpsInitOperand(0)->get(); + Value inputA = genericOp.getDpsInputOperand(0)->get(); + Value inputB = genericOp.getDpsInputOperand(1)->get(); + + // Default alpha and beta values (could be extracted from pattern) + FloatAttr alpha = rewriter.getF32FloatAttr(1.0); + FloatAttr beta = rewriter.getF32FloatAttr(0.0); + + // Create the kernel.gemm operation + rewriter.replaceOpWithNewOp( + genericOp, genericOp.getResultTypes(), + outputTensor, inputA, inputB, alpha, beta); + + return success(); + } + else if (opName == "Kernel_batched_gemm") { + // Get inputs and outputs + Value outputTensor = genericOp.getDpsInitOperand(0)->get(); + Value inputA = genericOp.getDpsInputOperand(0)->get(); + Value inputB = genericOp.getDpsInputOperand(1)->get(); + + // Default alpha and beta values + FloatAttr alpha = rewriter.getF32FloatAttr(1.0); + FloatAttr beta = rewriter.getF32FloatAttr(0.0); + + // Create the kernel.batched_gemm operation + rewriter.replaceOpWithNewOp( + genericOp, genericOp.getResultTypes(), + outputTensor, inputA, inputB, alpha, beta); + + return success(); + } + else if (opName == "Kernel_iamax") { + // Get input + Value input = genericOp.getDpsInputOperand(0)->get(); + + // Create the kernel.iamax operation + rewriter.replaceOpWithNewOp( + genericOp, genericOp.getResultTypes(), input); + + return success(); + } + else if (opName == "Kernel_iamin") { + // Get input + Value input = genericOp.getDpsInputOperand(0)->get(); + + // Create the kernel.iamin operation + rewriter.replaceOpWithNewOp( + genericOp, genericOp.getResultTypes(), input); + + return success(); + } + else if (opName == "Kernel_asum") { + // Get input + Value input = genericOp.getDpsInputOperand(0)->get(); + + // Create the kernel.asum operation + rewriter.replaceOpWithNewOp( + genericOp, genericOp.getResultTypes(), input); + + return success(); + } + + return failure(); + } + +private: + kernel::DefnCollectionOp collectionOp; +}; + +// Pass to apply the rewrite pattern +class LinalgToKernelPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgToKernelPass) + + void runOnOperation() override { + 
ModuleOp module = getOperation(); + + // Find the kernel.defn_collection in the module + kernel::DefnCollectionOp collectionOp; + module.walk([&](kernel::DefnCollectionOp op) { + collectionOp = op; + return WalkResult::interrupt(); + }); + + if (!collectionOp) { + module.emitError("No kernel.defn_collection found in module"); + return signalPassFailure(); + } + + // Apply the rewrite pattern + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), collectionOp); + + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +// Create a pass to convert linalg.generic to kernel +std::unique_ptr createLinalgToKernelPass() { + return std::make_unique(); +} + +// Register the pass +void registerLinalgToKernelPasses() { + PassRegistration("linalg-to-kernel", + "Convert linalg.generic to kernel operations"); +} \ No newline at end of file diff --git a/generic_solver/CublasOps.td b/generic_solver/CublasOps.td new file mode 100644 index 000000000000..56aaebba0766 --- /dev/null +++ b/generic_solver/CublasOps.td @@ -0,0 +1,85 @@ +//===- KernelOps.td - kernel dialect operation definitions ---*- tablegen -*-===// +// +// This file defines the kernel operation definitions in TableGen format. +// +//===----------------------------------------------------------------------===// + +#ifndef kernel_OPS +#define kernel_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" + +//===----------------------------------------------------------------------===// +// kernel dialect definition +//===----------------------------------------------------------------------===// + +def Kernel_Dialect : Dialect { + let name = "kernel"; + let cppNamespace = "::mlir::kernel"; + let description = [{ + The kernel dialect provides operations for NVIDIA kernel matrix multiplication + routines, including standard and batched GEMM operations. + }]; +} + +//===----------------------------------------------------------------------===// +// Base class for kernel dialect operations +//===----------------------------------------------------------------------===// + +class Kernel_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// kernel ops instantiation collection +//===----------------------------------------------------------------------===// + +def Opinst_DefnCollection : Op { + let summary = "Collection of operation definitions"; + let description = [{ + A collection of operation definitions that can be referenced elsewhere. + This operation serves as a container for multiple operation definitions. + }]; + + let regions = (region SizedRegion<1>:$defns); + + let assemblyFormat = [{ + $defns attr-dict + }]; +} + +def Opinst_Defn : Op { + let summary = "Definition of an operation"; + let description = [{ + A definition of an operation with inputs and arbitrary body code. + Can contain either literal code or a linalg.generic representation. 
+ }]; + + let arguments = (ins + StrAttr:$name, + Variadic:$inputs + ); + + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = [{ + $name `(` $inputs `)` $body attr-dict `:` functional-type($inputs, results) + }]; +} + +//===----------------------------------------------------------------------===// +// Example pattern representation +//===----------------------------------------------------------------------===// + +// Patterns for gemm and batched_gemm expressed in a mathematical notation. +// These are informational and would be used by pattern matchers. + +// Standard GEMM pattern: C(i,k) += alpha * A(i,j) * B(j,k) +// Batched GEMM pattern: C(N, i,k) += alpha * A(N, i,j) * B(N, j,k) + +// Index of max absolute value pattern: result = argmax_i |x_i| +// Index of min absolute value pattern: result = argmin_i |x_i| +// Sum of absolute values pattern: result = sum_i |x_i| + +#endif // kernel_OPS \ No newline at end of file diff --git a/generic_solver/cublas_example.mlir b/generic_solver/cublas_example.mlir new file mode 100644 index 000000000000..f444871c62da --- /dev/null +++ b/generic_solver/cublas_example.mlir @@ -0,0 +1,238 @@ +// Example MLIR module demonstrating kernel operations and their linalg.generic representations +module { + // Define a collection of kernel operation definitions + kernel.defn_collection { + // GEMM operation definition with arbitrary code implementation + kernel.defn "gemm" (%A : tensor, %B : tensor, %C : tensor) { + // This could include arbitrary code to implement the GEMM operation + // For example, calling into the actual kernel library + "some.custom_code"() : () -> () + } : (tensor, tensor, tensor) -> () + + // GEMM operation definition with linalg.generic representation + kernel.defn "gemm" (%A : tensor, %B : tensor, %C : tensor) { + %alpha = arith.constant 1.0 : f32 + %beta = arith.constant 0.0 : f32 + + // Implementation using linalg.generic + linalg.generic { + indexing_maps = [ + affine_map<(i, j, k) -> (i, k)>, // A(i,k) + affine_map<(i, j, k) -> (k, j)>, // B(k,j) + affine_map<(i, j, k) -> (i, j)> // C(i,j) + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %scaled = arith.mulf %product, %alpha : f32 + %scaled_c = arith.mulf %c, %beta : f32 + %result = arith.addf %scaled, %scaled_c : f32 + linalg.yield %result : f32 + } + } : (tensor, tensor, tensor) -> () + + // Batched GEMM operation definition with arbitrary code + kernel.defn "batched_gemm" (%A : tensor, %B : tensor, %C : tensor) { + // This could include arbitrary code to implement the batched GEMM operation + "some.custom_code"() : () -> () + } : (tensor, tensor, tensor) -> () + + // Batched GEMM operation definition with linalg.generic representation + kernel.defn "batched_gemm" (%A : tensor, %B : tensor, %C : tensor) { + %alpha = arith.constant 1.0 : f32 + %beta = arith.constant 0.0 : f32 + + // Implementation using linalg.generic + linalg.generic { + indexing_maps = [ + affine_map<(b, i, j, k) -> (b, i, k)>, // A(b,i,k) + affine_map<(b, i, j, k) -> (b, k, j)>, // B(b,k,j) + affine_map<(b, i, j, k) -> (b, i, j)> // C(b,i,j) + ], + iterator_types = ["parallel", "parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %scaled = arith.mulf %product, %alpha : f32 + %scaled_c = arith.mulf %c, %beta : f32 + %result = arith.addf 
%scaled, %scaled_c : f32 + linalg.yield %result : f32 + } + } : (tensor, tensor, tensor) -> () + + // Index of maximum absolute value operation definition with arbitrary code + kernel.defn "iamax" (%X : tensor) { + // This could include arbitrary code to find the index of max absolute value + "some.custom_code"() : () -> () + } : (tensor) -> tensor + + // Index of maximum absolute value operation definition with linalg.generic representation + kernel.defn "iamax" (%X : tensor) { + // Create an initial tensor to store the result index + %c0 = arith.constant 0 : i32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : i32) outs(%init : tensor) -> tensor + + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (index) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%fill : tensor) { + ^bb0(%in: f32, %out: i32): + %idx = linalg.index 0 : index + %abs_val = math.absf %in : f32 + %curr_max_idx = arith.index_cast %out : i32 to index + %curr_max = tensor.extract %X[%curr_max_idx] : tensor + %curr_max_abs = math.absf %curr_max : f32 + %cmp = arith.cmpf ogt, %abs_val, %curr_max_abs : f32 + %new_idx = arith.select %cmp, %idx, %curr_max_idx : index + %result = arith.index_cast %new_idx : index to i32 + linalg.yield %result : i32 + } + } : (tensor) -> tensor + + // Index of minimum absolute value operation definition with arbitrary code + kernel.defn "iamin" (%X : tensor) { + // This could include arbitrary code to find the index of min absolute value + "some.custom_code"() : () -> () + } : (tensor) -> tensor + + // Index of minimum absolute value operation definition with linalg.generic representation + kernel.defn "iamin" (%X : tensor) { + // Create an initial tensor to store the result index + %c0 = arith.constant 0 : i32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : i32) outs(%init : tensor) -> tensor + + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (index) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%fill : tensor) { + ^bb0(%in: f32, %out: i32): + %idx = linalg.index 0 : index + %abs_val = math.absf %in : f32 + %curr_min_idx = arith.index_cast %out : i32 to index + %curr_min = tensor.extract %X[%curr_min_idx] : tensor + %curr_min_abs = math.absf %curr_min : f32 + %cmp = arith.cmpf olt, %abs_val, %curr_min_abs : f32 + %new_idx = arith.select %cmp, %idx, %curr_min_idx : index + %result = arith.index_cast %new_idx : index to i32 + linalg.yield %result : i32 + } + } : (tensor) -> tensor + + // Sum of absolute values operation definition with arbitrary code + kernel.defn "asum" (%X : tensor) { + // This could include arbitrary code to compute the sum of absolute values + "some.custom_code"() : () -> () + } : (tensor) -> tensor + + // Sum of absolute values operation definition with linalg.generic representation + kernel.defn "asum" (%X : tensor) { + // Create an initial tensor to store the result sum + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (sum) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%fill : tensor) 
{ + ^bb0(%in: f32, %out: f32): + %abs_val = math.absf %in : f32 + %result = arith.addf %abs_val, %out : f32 + linalg.yield %result : f32 + } + } : (tensor) -> tensor + + // Mathematical definitions (commented, for reference) + // kernel.defn "gemm" (...) { + // C(i,j) += alpha * A(i,k) * B(k,j); + // } + + // kernel.defn "batched_gemm" (...) { + // C(b,i,j) += alpha * A(b,i,k) * B(b,k,j); + // } + + // kernel.defn "iamax" (...) { + // result = argmax_i |x_i|; + // } + + // kernel.defn "iamin" (...) { + // result = argmin_i |x_i|; + // } + + // kernel.defn "asum" (...) { + // result = sum_i |x_i|; + // } + } + + // Main function showing usage of the operations + func.func @main() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Allocate tensors for matrices + %A = tensor.empty() : tensor<2x128x64xf32> + %B = tensor.empty() : tensor<2x64x256xf32> + %C = tensor.empty() : tensor<2x128x256xf32> + + // Allocate a vector for vector operations + %X = tensor.empty() : tensor<128xf32> + + // Get slices of the batched tensors + %A0 = tensor.extract_slice %A[0, 0, 0][1, 128, 64][1, 1, 1] : tensor<2x128x64xf32> to tensor<128x64xf32> + %B0 = tensor.extract_slice %B[0, 0, 0][1, 64, 256][1, 1, 1] : tensor<2x64x256xf32> to tensor<64x256xf32> + %C0 = tensor.extract_slice %C[0, 0, 0][1, 128, 256][1, 1, 1] : tensor<2x128x256xf32> to tensor<128x256xf32> + + %A1 = tensor.extract_slice %A[1, 0, 0][1, 128, 64][1, 1, 1] : tensor<2x128x64xf32> to tensor<128x64xf32> + %B1 = tensor.extract_slice %B[1, 0, 0][1, 64, 256][1, 1, 1] : tensor<2x64x256xf32> to tensor<64x256xf32> + %C1 = tensor.extract_slice %C[1, 0, 0][1, 128, 256][1, 1, 1] : tensor<2x128x256xf32> to tensor<128x256xf32> + + // Perform individual GEMM operations on slices + // Using kernel.defn operation + kernel.defn(%A0, %B0, %C0) {kernel_name = "gemm"} : + (tensor<128x64xf32>, tensor<64x256xf32>, tensor<128x256xf32>) -> () + + kernel.defn(%A1, %B1, %C1) {kernel_name = "gemm"} : + (tensor<128x64xf32>, tensor<64x256xf32>, tensor<128x256xf32>) -> () + + // Perform batched GEMM operation + // Using kernel.defn operation + kernel.defn(%A, %B, %C) {kernel_name = "batched_gemm"} : + (tensor<2x128x64xf32>, tensor<2x64x256xf32>, tensor<2x128x256xf32>) -> () + + // Perform vector operations + + // Find index of maximum absolute value + %max_idx = kernel.defn(%X) {kernel_name = "iamax"} : + (tensor<128xf32>) -> tensor + + // Find index of minimum absolute value + %min_idx = kernel.defn(%X) {kernel_name = "iamin"} : + (tensor<128xf32>) -> tensor + + // Calculate sum of absolute values + %abs_sum = kernel.defn(%X) {kernel_name = "asum"} : + (tensor<128xf32>) -> tensor + + return + } +} \ No newline at end of file From c8561b428667f6670212fa4b24a638a172260e28 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Mon, 12 May 2025 15:48:16 -0700 Subject: [PATCH 51/77] Adding cases for generic solver --- generic_solver/CublasDefnPattern.cpp | 115 +++++++++++++++++++++++++-- 1 file changed, 107 insertions(+), 8 deletions(-) diff --git a/generic_solver/CublasDefnPattern.cpp b/generic_solver/CublasDefnPattern.cpp index c9e583affd4b..16515e13d1cd 100644 --- a/generic_solver/CublasDefnPattern.cpp +++ b/generic_solver/CublasDefnPattern.cpp @@ -31,19 +31,31 @@ bool areRegionsEquivalent(Region &first, Region &second) { if (firstBlock.getNumArguments() != secondBlock.getNumArguments()) return false; - // Compare argument types + //// Compare argument types + //for (auto argPair : llvm::zip(firstBlock.getArguments(), + // secondBlock.getArguments())) { + // 
if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) + // return false; + //} + + //Traverse the use-def chain of the arguments and compare the operation names for (auto argPair : llvm::zip(firstBlock.getArguments(), secondBlock.getArguments())) { - if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) + if (std::get<0>(argPair).getName() != std::get<1>(argPair).getName()) return false; + //Traverse the use-def chain of the argument + for (auto use : std::get<0>(argPair).getUses()) { + if (use.getOwner().getName() != std::get<1>(argPair).getName()) + return false; + } } - // Compare operations (simplified - real implementation would be more complex) - if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) - return false; + //// Compare operations (simplified - real implementation would be more complex) + //if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) + // return false; - // For a full implementation, you'd need more sophisticated operation comparison - // based on operands, attributes, and result types + //// For a full implementation, you'd need more sophisticated operation comparison + //// based on operands, attributes, and result types } return true; @@ -81,6 +93,83 @@ bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) { return true; } +// Cases: +// 1. What if they do a*(b+c) as a*b+a*c ? +// 2. What is they do (a+b)/c as a/c+b/c ? +// - The required best form can vary based on a cost model for a given architecture +// - The expectation is that kernel.defn is the best form an op is expected to take +// - The generic solver will employ heuristics to match the best form +// - Heuristics can be as simple as "is the op a commutative operation ?", +// "is the op an associative operation ?", "is the op distributive ?", etc. +// 3. What if the order of operations is different ? add(a,b) as add(b,a) +// - This requires a commutative check for operations, i.e in commutative ops +// we don't need to match positions +// 4. What if order of uses are different for an op? Eg- +// a1 = ... | a2 = ... +// b1 = a1/c1 | d2 = a2*c2 +// d1 = a1*c1 | b2 = a2/c2 +// - In this case, we need to find the corresponding uses of the operands +// 5. 
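// A minimal sketch of the commutative-matching idea in case 3 above, assuming the
// includes and using-directives already present in this file; valuesMatch is a
// hypothetical callback (e.g. a wrapper around the compareUseDefChains helper below)
// and is not part of the patch. Commutative binary ops may pair their operands in
// either order; everything else must match positionally.
static bool operandsMatchUpToCommutativity(
    Operation *a, Operation *b,
    llvm::function_ref<bool(Value, Value)> valuesMatch) {
  if (a->getName() != b->getName() ||
      a->getNumOperands() != b->getNumOperands())
    return false;
  bool commutative =
      a->hasTrait<OpTrait::IsCommutative>() && a->getNumOperands() == 2;
  if (!commutative) {
    // Non-commutative (or non-binary) ops must match operand-for-operand.
    for (auto [lhs, rhs] : llvm::zip(a->getOperands(), b->getOperands()))
      if (!valuesMatch(lhs, rhs))
        return false;
    return true;
  }
  // Either pairing is acceptable for a commutative binary op.
  return (valuesMatch(a->getOperand(0), b->getOperand(0)) &&
          valuesMatch(a->getOperand(1), b->getOperand(1))) ||
         (valuesMatch(a->getOperand(0), b->getOperand(1)) &&
          valuesMatch(a->getOperand(1), b->getOperand(0)));
}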
+ +// Non-recursive traversal of use-def chain using a stack +bool compareUseDefChains(Value firstValue, Value secondValue) { + // Use a std::stack to track operations we need to visit + std::stack> workList; + std::set> visited; + + // Start with the initial values + workList.push({firstValue, secondValue}); + + while (!workList.empty()) { + auto [value1, value2] = workList.top(); + workList.pop(); + + // Skip if we've already processed this pair + auto valuePtrPair = std::make_pair(value1.getImpl(), value2.getImpl()); + if (visited.count(valuePtrPair)) + continue; + visited.insert(valuePtrPair); + + // Compare the values themselves + if (value1.getType() != value2.getType()) + return false; + + // Compare all uses + auto uses1 = value1.getUses(); + auto uses2 = value2.getUses(); + + // Process each use + for (auto &use1 : uses1) { + Operation *op1 = use1.getOwner(); + + // Find corresponding use in second value + bool foundMatch = false; + for (auto &use2 : uses2) { + Operation *op2 = use2.getOwner(); + + // Compare operations (customize based on your definition of equivalence) + if (op1->getName() == op2->getName() && + //This requires a commutative check + use1.getOperandNumber() == use2.getOperandNumber()) { + foundMatch = true; + + // Add results to worklist to continue traversal + for (unsigned i = 0; i < op1->getNumResults(); ++i) { + if (i < op2->getNumResults()) + workList.push({op1->getResult(i), op2->getResult(i)}); + } + break; + } + } + + if (!foundMatch) + return false; + } + } + + return true; +} + // Check if a linalg.generic operation matches a kernel.defn in a collection FailureOr matchGenericWithDefn( GenericOp genericOp, @@ -107,12 +196,22 @@ FailureOr matchGenericWithDefn( // Check if this linalg.generic matches our target if (candidateOp.getNumDpsInputs() == numInputs && candidateOp.getNumDpsInits() == numOutputs && - //TODO: Generalize to a single dialect, with no special ops + //DONE: Generalize to a single dialect, with no special ops //TODO: Indexing maps and orders might differ //TODO: More complex case- where extra loops exists around the ops we have //TODO: Custom cost model ? 
//TODO: Constants might require special handling such as bounds //IDEA: Descheduling / removing tiles + int numOfIndexingMaps = indexingMaps.size(); + int combinations = calculate_combinations(numOfIndexingMaps); + int calculatedCombinations(int numOfPos) { + //Calculate factorial of numOfPos + int result = 1; + for (int i = 1; i <= numOfPos; i++) { + result *= i; + } + return result; + } areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) && areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) && areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) { From 07d0dcb975ebb9a9ce5931203f59c95df6842b4a Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Tue, 27 May 2025 17:42:08 -0700 Subject: [PATCH 52/77] Backup of previous edits --- generic_solver/CublasDefnPattern.cpp | 155 ++++++++++++------------ lib/polygeist/Passes/RaiseToLinalg.cpp | 4 +- lib/polygeist/Passes/RemoveIterArgs.cpp | 152 ++++++++++------------- 3 files changed, 146 insertions(+), 165 deletions(-) diff --git a/generic_solver/CublasDefnPattern.cpp b/generic_solver/CublasDefnPattern.cpp index 16515e13d1cd..4a62fb8345da 100644 --- a/generic_solver/CublasDefnPattern.cpp +++ b/generic_solver/CublasDefnPattern.cpp @@ -16,83 +16,6 @@ using namespace mlir::linalg; namespace { -// Helper function to check if two regions are structurally equivalent -bool areRegionsEquivalent(Region &first, Region &second) { - // Compare number of blocks - if (first.getBlocks().size() != second.getBlocks().size()) - return false; - - // Compare corresponding blocks - for (auto blockPair : llvm::zip(first.getBlocks(), second.getBlocks())) { - Block &firstBlock = std::get<0>(blockPair); - Block &secondBlock = std::get<1>(blockPair); - - // Compare number of arguments - if (firstBlock.getNumArguments() != secondBlock.getNumArguments()) - return false; - - //// Compare argument types - //for (auto argPair : llvm::zip(firstBlock.getArguments(), - // secondBlock.getArguments())) { - // if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) - // return false; - //} - - //Traverse the use-def chain of the arguments and compare the operation names - for (auto argPair : llvm::zip(firstBlock.getArguments(), - secondBlock.getArguments())) { - if (std::get<0>(argPair).getName() != std::get<1>(argPair).getName()) - return false; - //Traverse the use-def chain of the argument - for (auto use : std::get<0>(argPair).getUses()) { - if (use.getOwner().getName() != std::get<1>(argPair).getName()) - return false; - } - } - - //// Compare operations (simplified - real implementation would be more complex) - //if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) - // return false; - - //// For a full implementation, you'd need more sophisticated operation comparison - //// based on operands, attributes, and result types - } - - return true; -} - -// Helper to check if indexing maps are equivalent -bool areIndexingMapsEquivalent(ArrayAttr firstMaps, ArrayAttr secondMaps) { - if (firstMaps.size() != secondMaps.size()) - return false; - - for (auto mapPair : llvm::zip(firstMaps, secondMaps)) { - auto firstMap = std::get<0>(mapPair).cast().getValue(); - auto secondMap = std::get<1>(mapPair).cast().getValue(); - - if (firstMap != secondMap) - return false; - } - - return true; -} - -// Helper to check if iterator types are equivalent -bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) { - if (firstTypes.size() != secondTypes.size()) - return false; - - 
for (auto typePair : llvm::zip(firstTypes, secondTypes)) { - auto firstType = std::get<0>(typePair).cast().getValue(); - auto secondType = std::get<1>(typePair).cast().getValue(); - - if (firstType != secondType) - return false; - } - - return true; -} - // Cases: // 1. What if they do a*(b+c) as a*b+a*c ? // 2. What is they do (a+b)/c as a/c+b/c ? @@ -170,6 +93,84 @@ bool compareUseDefChains(Value firstValue, Value secondValue) { return true; } + +// Helper function to check if two regions are structurally equivalent +bool areRegionsEquivalent(Region &first, Region &second) { + // Compare number of blocks + if (first.getBlocks().size() != second.getBlocks().size()) + return false; + + // Compare corresponding blocks + for (auto blockPair : llvm::zip(first.getBlocks(), second.getBlocks())) { + Block &firstBlock = std::get<0>(blockPair); + Block &secondBlock = std::get<1>(blockPair); + + // Compare number of arguments + if (firstBlock.getNumArguments() != secondBlock.getNumArguments()) + return false; + + //// Compare argument types + //for (auto argPair : llvm::zip(firstBlock.getArguments(), + // secondBlock.getArguments())) { + // if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) + // return false; + //} + + //Traverse the use-def chain of the arguments and compare the operation names + for (auto argPair : llvm::zip(firstBlock.getArguments(), + secondBlock.getArguments())) { + if (std::get<0>(argPair).getName() != std::get<1>(argPair).getName()) + return false; + //Traverse the use-def chain of the argument + for (auto use : std::get<0>(argPair).getUses()) { + if (use.getOwner().getName() != std::get<1>(argPair).getName()) + return false; + } + } + + //// Compare operations (simplified - real implementation would be more complex) + //if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) + // return false; + + //// For a full implementation, you'd need more sophisticated operation comparison + //// based on operands, attributes, and result types + } + + return true; +} + +// Helper to check if indexing maps are equivalent +bool areIndexingMapsEquivalent(ArrayAttr firstMaps, ArrayAttr secondMaps) { + if (firstMaps.size() != secondMaps.size()) + return false; + + for (auto mapPair : llvm::zip(firstMaps, secondMaps)) { + auto firstMap = std::get<0>(mapPair).cast().getValue(); + auto secondMap = std::get<1>(mapPair).cast().getValue(); + + if (firstMap != secondMap) + return false; + } + + return true; +} + +// Helper to check if iterator types are equivalent +bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) { + if (firstTypes.size() != secondTypes.size()) + return false; + + for (auto typePair : llvm::zip(firstTypes, secondTypes)) { + auto firstType = std::get<0>(typePair).cast().getValue(); + auto secondType = std::get<1>(typePair).cast().getValue(); + + if (firstType != secondType) + return false; + } + + return true; +} + // Check if a linalg.generic operation matches a kernel.defn in a collection FailureOr matchGenericWithDefn( GenericOp genericOp, diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index fee0e4d157a7..f638a26c9cd1 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -619,7 +619,7 @@ struct AffineForOpRaising : public OpRewritePattern { stores_map[load] = store; continue; } - return failure(); + //return failure(); } } for (auto &&[_, store2] : stores) { @@ -915,7 +915,7 @@ struct AffineForOpRaising : public 
OpRewritePattern { // This index will replace the use of the affine index auto idx = rewriter.create(loop.getLoc(), - rewriter.getIndexAttr(0)); + 0); rewriter.replaceAllUsesWith(loop.getInductionVar(), idx); auto &body = genericOp.getRegion(); diff --git a/lib/polygeist/Passes/RemoveIterArgs.cpp b/lib/polygeist/Passes/RemoveIterArgs.cpp index 0a4784c6c599..2a3e9ea4edc6 100644 --- a/lib/polygeist/Passes/RemoveIterArgs.cpp +++ b/lib/polygeist/Passes/RemoveIterArgs.cpp @@ -144,128 +144,108 @@ struct RemoveSCFIterArgs : public OpRewritePattern { } }; +// General Case(TODO): +// ALGo: +// 1. Create an alloca(stack) variable +// How to know it's dims? It should be based on number of reduction +// loops +// 2. Initialize it with init value just outside the for loop if init +// value is non-zero +// 3. memref.load that value in the for loop +// 4. Replace all the uses of the iter_arg with the loaded value +// 5. Add a memref.store for the value to be yielded +// 6. Replace all uses of for-loops yielded value with a single inserted +// memref.load +// Special case: +// ALGo: +// Optimize away memref.store and memref.load, if the only users of +// memref.load are memref.store (can use affine-scalrep pass for that ? No +// it does store to load forwarding) What we need is forwarding of local +// store to final store and deleting the intermediate alloca created. This +// is only possible if the user of alloca is a storeOp. +// 1. Identify the single store of the for loop result +// 2. Initialize it with iter arg init, outside the for loop. (TODO) +// 3. Do a load from the memref +// 4. move the store to memref inside the loop. + struct RemoveAffineIterArgs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(affine::AffineForOp forOp, PatternRewriter &rewriter) const override { ModuleOp module = forOp->getParentOfType(); - if (!forOp.getRegion().hasOneBlock()) - return failure(); + rewriter.setInsertionPoint(forOp); + unsigned numIterArgs = forOp.getNumRegionIterArgs(); + if (numIterArgs == 0) + return failure(); + auto loc = forOp->getLoc(); - bool changed = false; - llvm::SetVector removed; - llvm::MapVector steps; auto yieldOp = cast(forOp.getBody()->getTerminator()); - for (unsigned i = 0; i < numIterArgs; i++) { - auto ba = forOp.getRegionIterArgs()[i]; - auto init = forOp.getInits()[i]; - auto lastOp = yieldOp->getOperand(i); - - // General Case(TODO): - // ALGo: - // 1. Create an alloca(stack) variable - // How to know it's dims? It should be based on number of reduction - // loops - // 2. Initialize it with init value just outside the for loop if init - // value is non-zero - // 3. memref.load that value in the for loop - // 4. Replace all the uses of the iter_arg with the loaded value - // 5. Add a memref.store for the value to be yielded - // 6. Replace all uses of for-loops yielded value with a single inserted - // memref.load - // Special case: - // ALGo: - // Optimize away memref.store and memref.load, if the only users of - // memref.load are memref.store (can use affine-scalrep pass for that ? No - // it does store to load forwarding) What we need is forwarding of local - // store to final store and deleting the intermediate alloca created. This - // is only possible if the user of alloca is a storeOp. - // 1. Identify the single store of the for loop result - // 2. Initialize it with iter arg init, outside the for loop. (TODO) - // 3. Do a load from the memref - // 4. move the store to memref inside the loop. 
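For reference, the special case described above only applies when a result of the loop feeds exactly one affine.store. A minimal sketch of that precondition check follows; the helper name and placement are illustrative assumptions, not part of this patch:

// Sketch only: true when the store-forwarding special case described in the
// comment above can be used for a given loop result.
// Before: %r = affine.for ... iter_args(%acc = %init) { ...; affine.yield %next }
//         affine.store %r, %mem[...]
// After:  affine.for ... { %acc = affine.load %mem[...]; ...; affine.store %next, %mem[...] }
// (initializing %mem with %init outside the loop is still marked TODO above)
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Value.h"

static bool hasSingleAffineStoreUser(mlir::Value loopResult) {
  // The loop result must have exactly one user.
  if (!loopResult.hasOneUse())
    return false;
  // That single user must be an affine.store of the loop result.
  return mlir::isa<mlir::affine::AffineStoreOp>(*loopResult.getUsers().begin());
}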
- auto result = forOp.getResult(i); - if (result.hasOneUse()) { - auto storeOp = - dyn_cast(*result.getUsers().begin()); - if (storeOp) { - { - rewriter.setInsertionPointToStart(forOp.getBody()); - auto memrefLoad = rewriter.create( - forOp.getLoc(), storeOp.getMemref(), storeOp.getMap(), - storeOp.getMapOperands()); - rewriter.replaceAllUsesWith(ba, memrefLoad.getResult()); - } - { - rewriter.setInsertionPoint(yieldOp); - rewriter.create( - forOp.getLoc(), lastOp, storeOp.getMemref(), storeOp.getMap(), - storeOp.getMapOperands()); - storeOp.erase(); - } - } else { - return failure(); + auto ba = forOp.getRegionIterArgs()[numIterArgs - 1]; + auto init = forOp.getInits()[numIterArgs - 1]; + auto lastOp = yieldOp->getOperand(numIterArgs - 1); + + auto result = forOp.getResult(numIterArgs - 1); + if (result.hasOneUse()) { + auto storeOp = + dyn_cast(*result.getUsers().begin()); + if (storeOp) { + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(forOp.getBody()); + auto memrefLoad = rewriter.create( + forOp.getLoc(), storeOp.getMemref(), storeOp.getMap(), + storeOp.getMapOperands()); + rewriter.replaceAllUsesWith(ba, memrefLoad.getResult()); + } + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(yieldOp); + rewriter.create( + forOp.getLoc(), lastOp, storeOp.getMemref(), storeOp.getMap(), + storeOp.getMapOperands()); + storeOp.erase(); } + } else { + return failure(); } - // else{ - // alloca = rewriter.create( - // forOp.getLoc(), MemRefType::get(ArrayRef(), - // forOp.getType()), ValueRange()); - // //Skipping init for now - - // auto memrefLoad = rewriter.create( - // forOp.getLoc(), alloca.getMemref(), op.getIndices()); - // rewriter.replaceOp(op, memrefLoad.getResult()); - - // rewriter.create(forOp.getLoc(), lastOp, alloca, - // forOp.getBody()->getArguments()); - - // rewriter.replaceAllUsesWith(result,) - //} - - rewriter.setInsertionPointToStart(forOp.getBody()); - // rewriter.replaceAllUsesWith(ba, replacementIV); - changed = true; } - - if (!changed) + else{ return failure(); + } - rewriter.setInsertionPoint(forOp); + SmallVector newIterArgs(forOp.getInits().drop_back()); auto newForOp = rewriter.create( loc, forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), - forOp.getStep()); + forOp.getStep(), newIterArgs); if (!newForOp.getRegion().empty()) newForOp.getRegion().front().erase(); - assert(newForOp.getRegion().empty()); rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), newForOp.getRegion().begin()); // Delete region args llvm::BitVector toDelete(numIterArgs + 1); - for (unsigned i = 0; i < numIterArgs; i++) - toDelete[i + 1] = true; + toDelete[numIterArgs] = true; newForOp.getBody()->eraseArguments(toDelete); SmallVector newYields; { + OpBuilder::InsertionGuard guard(rewriter); ValueRange empty; rewriter.setInsertionPoint(yieldOp); - auto newYieldOp = rewriter.create(loc); - // rewriter.replaceOpWithNewOp(yieldOp, - // newYieldOp); - rewriter.eraseOp(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp, yieldOp.getOperands().drop_back()); } - rewriter.setInsertionPoint(newForOp); - rewriter.eraseOp(forOp); + for(int i = 0; i < numIterArgs-1; i++){ + rewriter.replaceAllUsesWith(forOp.getResult(i), newForOp.getResult(i)); + } + rewriter.eraseOp(forOp); return success(); } }; From 009ab9be809a138c31020bde22c8a53f2f607dd9 Mon Sep 17 00:00:00 2001 From: arjaiswal Date: Wed, 11 Jun 2025 10:47:04 -0700 Subject: [PATCH 53/77] Temp changes for kernel 
dialect --- include/polygeist/CMakeLists.txt | 3 +- include/polygeist/Kernel/CMakeLists.txt | 1 + include/polygeist/Kernel/KernelDialect.h | 25 ++++++ include/polygeist/Kernel/KernelDialect.td | 36 ++++++++ include/polygeist/Kernel/KernelOps.h | 32 +++++++ include/polygeist/Kernel/KernelOps.td | 103 ++++++++++++++++++++++ lib/polygeist/CMakeLists.txt | 1 + lib/polygeist/Kernel/CMakeLists.txt | 19 ++++ lib/polygeist/Kernel/KernelDialect.cpp | 33 +++++++ lib/polygeist/Kernel/KernelOps.cpp | 79 +++++++++++++++++ tools/polygeist-opt/CMakeLists.txt | 1 + tools/polygeist-opt/polygeist-opt.cpp | 3 + 12 files changed, 335 insertions(+), 1 deletion(-) create mode 100644 include/polygeist/Kernel/CMakeLists.txt create mode 100644 include/polygeist/Kernel/KernelDialect.h create mode 100644 include/polygeist/Kernel/KernelDialect.td create mode 100644 include/polygeist/Kernel/KernelOps.h create mode 100644 include/polygeist/Kernel/KernelOps.td create mode 100644 lib/polygeist/Kernel/CMakeLists.txt create mode 100644 lib/polygeist/Kernel/KernelDialect.cpp create mode 100644 lib/polygeist/Kernel/KernelOps.cpp diff --git a/include/polygeist/CMakeLists.txt b/include/polygeist/CMakeLists.txt index efcf93f70329..06fb9a05da90 100644 --- a/include/polygeist/CMakeLists.txt +++ b/include/polygeist/CMakeLists.txt @@ -2,4 +2,5 @@ add_mlir_dialect(PolygeistOps polygeist) add_mlir_doc(PolygeistDialect -gen-dialect-doc PolygeistDialect Polygeist/) add_mlir_doc(PolygeistOps -gen-op-doc PolygeistOps Polygeist/) -add_subdirectory(Passes) \ No newline at end of file +add_subdirectory(Passes) +add_subdirectory(Kernel) \ No newline at end of file diff --git a/include/polygeist/Kernel/CMakeLists.txt b/include/polygeist/Kernel/CMakeLists.txt new file mode 100644 index 000000000000..6bc7f03a564c --- /dev/null +++ b/include/polygeist/Kernel/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(KernelOps kernel) \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelDialect.h b/include/polygeist/Kernel/KernelDialect.h new file mode 100644 index 000000000000..6dbf888f97fc --- /dev/null +++ b/include/polygeist/Kernel/KernelDialect.h @@ -0,0 +1,25 @@ +//===- KernelDialect.h - Kernel dialect declaration -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef POLYGEIST_KERNEL_KERNELDIALECT_H +#define POLYGEIST_KERNEL_KERNELDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +namespace mlir { +namespace polygeist { +namespace kernel { + +} // namespace kernel +} // namespace polygeist +} // namespace mlir + +#include "polygeist/Kernel/KernelOpsDialect.h.inc" + +#endif // POLYGEIST_KERNEL_KERNELDIALECT_H \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelDialect.td b/include/polygeist/Kernel/KernelDialect.td new file mode 100644 index 000000000000..68ffc856b65f --- /dev/null +++ b/include/polygeist/Kernel/KernelDialect.td @@ -0,0 +1,36 @@ +//===- KernelDialect.td - Kernel dialect definition -------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef KERNEL_DIALECT +#define KERNEL_DIALECT + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Kernel dialect definition +//===----------------------------------------------------------------------===// + +def Kernel_Dialect : Dialect { + let name = "kernel"; + let cppNamespace = "::mlir::polygeist::kernel"; + let description = [{ + The kernel dialect provides operations for NVIDIA kernel matrix multiplication + routines, including standard and batched GEMM operations. This dialect enables + representation and optimization of high-performance linear algebra kernels + within the Polygeist infrastructure. + }]; +} + +//===----------------------------------------------------------------------===// +// Base class for kernel dialect operations +//===----------------------------------------------------------------------===// + +class Kernel_Op traits = []> : + Op; + +#endif // KERNEL_DIALECT \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelOps.h b/include/polygeist/Kernel/KernelOps.h new file mode 100644 index 000000000000..966ef77d6379 --- /dev/null +++ b/include/polygeist/Kernel/KernelOps.h @@ -0,0 +1,32 @@ +//===- KernelOps.h - Kernel dialect operations ------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef POLYGEIST_KERNEL_KERNELOPS_H +#define POLYGEIST_KERNEL_KERNELOPS_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "polygeist/Kernel/KernelDialect.h" + +namespace mlir { +namespace polygeist { +namespace kernel { + +} // namespace kernel +} // namespace polygeist +} // namespace mlir + +#define GET_OP_CLASSES +#include "polygeist/Kernel/KernelOps.h.inc" + +#endif // POLYGEIST_KERNEL_KERNELOPS_H \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelOps.td b/include/polygeist/Kernel/KernelOps.td new file mode 100644 index 000000000000..90ea3912f2c3 --- /dev/null +++ b/include/polygeist/Kernel/KernelOps.td @@ -0,0 +1,103 @@ +//===- KernelOps.td - Kernel dialect operation definitions -*-- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef KERNEL_OPS +#define KERNEL_OPS + +include "polygeist/Kernel/KernelDialect.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" + +//===----------------------------------------------------------------------===// +// Kernel operation definitions +//===----------------------------------------------------------------------===// + +def Kernel_DefnCollectionOp : Kernel_Op<"defn_collection", []> { + let summary = "Collection of kernel operation definitions"; + let description = [{ + A collection of operation definitions that can be referenced elsewhere. + This operation serves as a container for multiple kernel operation definitions, + enabling modular organization of kernel implementations. + }]; + + let regions = (region SizedRegion<1>:$defns); + + let assemblyFormat = [{ + $defns attr-dict + }]; +} + +def Kernel_DefnOp : Kernel_Op<"defn", [SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Definition of a kernel operation"; + let description = [{ + A definition of a kernel operation with inputs and arbitrary body code. + Can contain either literal CUDA/HIP code or a linalg.generic representation + for high-performance linear algebra operations. + + This operation is particularly useful for defining custom GEMM variants, + batched operations, and other specialized linear algebra kernels. + + Example: + ```mlir + %result = kernel.defn { + ^bb0(%A: memref, %B: memref, + %C: memref, %alpha: f32): + // Kernel implementation + kernel.yield %some_result : tensor + } {name = "custom_gemm"} -> tensor + ``` + }]; + + // TODO: can look into gpu call op (separte namespace) + let arguments = (ins + StrAttr:$name, + TypeAttrOf:$function_type + ); + + let results = (outs Variadic:$results); + + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = [{ + $body attr-dict `->` type($results) + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + TypeRange getFunctionResultTypes() { + auto fType = getFunctionType(); + return fType.getResults(); + } + }]; +} + +def Kernel_YieldOp : Kernel_Op<"yield", [Pure, Terminator, + ParentOneOf<["DefnOp"]>]> { + let summary = "Terminator for kernel.defn operation"; + let description = [{ + The `kernel.yield` operation terminates regions within kernel operations. + It optionally returns values from the kernel definition. + }]; + + let arguments = (ins Variadic:$operands); + + let assemblyFormat = [{ + ($operands^ `:` type($operands))? 
attr-dict + }]; + + let builders = [ + OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]> + ]; + + let hasVerifier = 1; +} + +#endif // KERNEL_OPS \ No newline at end of file diff --git a/lib/polygeist/CMakeLists.txt b/lib/polygeist/CMakeLists.txt index 88aea0de4dd5..b2a410a77872 100644 --- a/lib/polygeist/CMakeLists.txt +++ b/lib/polygeist/CMakeLists.txt @@ -19,3 +19,4 @@ MLIRSCFTransforms ) add_subdirectory(Passes) add_subdirectory(ExecutionEngine) +add_subdirectory(Kernel) diff --git a/lib/polygeist/Kernel/CMakeLists.txt b/lib/polygeist/Kernel/CMakeLists.txt new file mode 100644 index 000000000000..371724504a5e --- /dev/null +++ b/lib/polygeist/Kernel/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRPolygeistKernel + KernelDialect.cpp + KernelOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/polygeist/Kernel + + DEPENDS + MLIRKernelOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRMemRefDialect + MLIRArithDialect + MLIRFuncDialect + MLIRControlFlowInterfaces + MLIRSideEffectInterfaces + MLIRSupport +) \ No newline at end of file diff --git a/lib/polygeist/Kernel/KernelDialect.cpp b/lib/polygeist/Kernel/KernelDialect.cpp new file mode 100644 index 000000000000..0e239ff2565c --- /dev/null +++ b/lib/polygeist/Kernel/KernelDialect.cpp @@ -0,0 +1,33 @@ +//===- KernelDialect.cpp - Kernel dialect implementation --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "polygeist/Kernel/KernelDialect.h" +#include "polygeist/Kernel/KernelOps.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Transforms/InliningUtils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +using namespace mlir; +using namespace mlir::polygeist; +using namespace mlir::polygeist::kernel; + +#include "polygeist/Kernel/KernelOpsDialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// Kernel dialect initialization +//===----------------------------------------------------------------------===// + +void KernelDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "polygeist/Kernel/KernelOps.cpp.inc" + >(); +} \ No newline at end of file diff --git a/lib/polygeist/Kernel/KernelOps.cpp b/lib/polygeist/Kernel/KernelOps.cpp new file mode 100644 index 000000000000..7ce6b18998e8 --- /dev/null +++ b/lib/polygeist/Kernel/KernelOps.cpp @@ -0,0 +1,79 @@ +//===- KernelOps.cpp - Kernel dialect operations ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "polygeist/Kernel/KernelOps.h" +#include "polygeist/Kernel/KernelDialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::polygeist; +using namespace mlir::polygeist::kernel; + +//===----------------------------------------------------------------------===// +// DefnOp +//===----------------------------------------------------------------------===// + +LogicalResult DefnOp::verify() { + // Check that the body region has exactly one block + if (!getBody().hasOneBlock()) + return emitOpError("body region must have exactly one block"); + + // The block can have any number of arguments + // No special verification needed for block arguments + + return success(); +} + +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +LogicalResult YieldOp::verify() { + // Get the parent DefnOp + auto defnOp = getParentOp(); + if (!defnOp) + return emitOpError("must be nested within a kernel.defn operation"); + + // Get expected result types from the DefnOp's function type + auto functionType = defnOp.getFunctionType(); + auto expectedTypes = functionType.getResults(); + + // Check that the number of operands matches expected results + if (getOperands().size() != expectedTypes.size()) { + return emitOpError("number of yielded values (") + << getOperands().size() << ") does not match expected number of results (" + << expectedTypes.size() << ")"; + } + + // Check that operand types match expected types + for (auto [idx, operand, expectedType] : + llvm::enumerate(getOperands(), expectedTypes)) { + if (operand.getType() != expectedType) { + return emitOpError("yielded value ") << idx << " has type " + << operand.getType() << " but expected " << expectedType; + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "polygeist/Kernel/KernelOps.cpp.inc" \ No newline at end of file diff --git a/tools/polygeist-opt/CMakeLists.txt b/tools/polygeist-opt/CMakeLists.txt index ccfebd421d81..7a61d5b3b7af 100644 --- a/tools/polygeist-opt/CMakeLists.txt +++ b/tools/polygeist-opt/CMakeLists.txt @@ -5,6 +5,7 @@ set(LIBS ${conversion_libs} MLIROptLib MLIRPolygeist + MLIRPolygeistKernel MLIRPolygeistTransforms MLIRFuncAllExtensions ) diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index b5aba75c9264..2a8eada21811 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -33,6 +33,8 @@ #include "polygeist/Dialect.h" #include "polygeist/Passes/Passes.h" +#include "polygeist/Kernel/KernelDialect.h" +#include "polygeist/Kernel/KernelOps.h" using namespace mlir; @@ -62,6 +64,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); 
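// The dialect registration added in this hunk presumably takes the form
//   registry.insert<mlir::polygeist::kernel::KernelDialect>();
// mirroring the existing insert calls for the Polygeist and upstream dialects;
// the namespace here is an assumption taken from the cppNamespace declared in
// KernelDialect.td rather than from this diff.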
mlir::registerpolygeistPasses(); From c0f36d3ed72c28fcc86b512ff7da8170d6508568 Mon Sep 17 00:00:00 2001 From: arjaiswal Date: Wed, 11 Jun 2025 13:12:47 -0700 Subject: [PATCH 54/77] Enabled kernel dialect correctly running on sample IR with kernel defn collection --- generic_solver/cublas_example.mlir | 138 +++++++++----------------- include/polygeist/Kernel/KernelOps.td | 66 +++++++----- lib/polygeist/Kernel/KernelOps.cpp | 63 +++++++----- 3 files changed, 130 insertions(+), 137 deletions(-) diff --git a/generic_solver/cublas_example.mlir b/generic_solver/cublas_example.mlir index f444871c62da..435819e481bc 100644 --- a/generic_solver/cublas_example.mlir +++ b/generic_solver/cublas_example.mlir @@ -3,19 +3,22 @@ module { // Define a collection of kernel operation definitions kernel.defn_collection { // GEMM operation definition with arbitrary code implementation - kernel.defn "gemm" (%A : tensor, %B : tensor, %C : tensor) { + kernel.defn @gemm(%A: tensor, %B: tensor, %C: tensor) { // This could include arbitrary code to implement the GEMM operation // For example, calling into the actual kernel library "some.custom_code"() : () -> () - } : (tensor, tensor, tensor) -> () + kernel.yield + } // GEMM operation definition with linalg.generic representation - kernel.defn "gemm" (%A : tensor, %B : tensor, %C : tensor) { + kernel.defn @gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { + //TODO: move to function arg + //TODO: We can do const prop for alpha and beta for simple matmul match %alpha = arith.constant 1.0 : f32 %beta = arith.constant 0.0 : f32 // Implementation using linalg.generic - linalg.generic { + %result = linalg.generic { indexing_maps = [ affine_map<(i, j, k) -> (i, k)>, // A(i,k) affine_map<(i, j, k) -> (k, j)>, // B(k,j) @@ -30,47 +33,51 @@ module { %scaled_c = arith.mulf %c, %beta : f32 %result = arith.addf %scaled, %scaled_c : f32 linalg.yield %result : f32 - } - } : (tensor, tensor, tensor) -> () + } -> tensor + kernel.yield %result : tensor + } // Batched GEMM operation definition with arbitrary code - kernel.defn "batched_gemm" (%A : tensor, %B : tensor, %C : tensor) { + kernel.defn @batched_gemm(%A2: tensor, %B2: tensor, %C2: tensor) { // This could include arbitrary code to implement the batched GEMM operation "some.custom_code"() : () -> () - } : (tensor, tensor, tensor) -> () + kernel.yield + } // Batched GEMM operation definition with linalg.generic representation - kernel.defn "batched_gemm" (%A : tensor, %B : tensor, %C : tensor) { + kernel.defn @batched_gemm_linalg(%A2: tensor, %B2: tensor, %C2: tensor) { %alpha = arith.constant 1.0 : f32 %beta = arith.constant 0.0 : f32 // Implementation using linalg.generic - linalg.generic { + %result = linalg.generic { indexing_maps = [ affine_map<(b, i, j, k) -> (b, i, k)>, // A(b,i,k) affine_map<(b, i, j, k) -> (b, k, j)>, // B(b,k,j) affine_map<(b, i, j, k) -> (b, i, j)> // C(b,i,j) ], iterator_types = ["parallel", "parallel", "parallel", "reduction"] - } ins(%A, %B : tensor, tensor) - outs(%C : tensor) { + } ins(%A2, %B2 : tensor, tensor) + outs(%C2 : tensor) { ^bb0(%a: f32, %b: f32, %c: f32): %product = arith.mulf %a, %b : f32 %scaled = arith.mulf %product, %alpha : f32 %scaled_c = arith.mulf %c, %beta : f32 %result = arith.addf %scaled, %scaled_c : f32 linalg.yield %result : f32 - } - } : (tensor, tensor, tensor) -> () + } -> tensor + kernel.yield + } // Index of maximum absolute value operation definition with arbitrary code - kernel.defn "iamax" (%X : tensor) { + kernel.defn @iamax(%X: tensor) -> tensor { // 
This could include arbitrary code to find the index of max absolute value - "some.custom_code"() : () -> () - } : (tensor) -> tensor + %result = "some.custom_code"() : () -> tensor + kernel.yield %result : tensor + } // Index of maximum absolute value operation definition with linalg.generic representation - kernel.defn "iamax" (%X : tensor) { + kernel.defn @iamax_linalg(%X: tensor) -> tensor { // Create an initial tensor to store the result index %c0 = arith.constant 0 : i32 %init = tensor.empty() : tensor @@ -95,17 +102,19 @@ module { %new_idx = arith.select %cmp, %idx, %curr_max_idx : index %result = arith.index_cast %new_idx : index to i32 linalg.yield %result : i32 - } - } : (tensor) -> tensor + } -> tensor + kernel.yield %result : tensor + } // Index of minimum absolute value operation definition with arbitrary code - kernel.defn "iamin" (%X : tensor) { + kernel.defn @iamin(%X: tensor) -> tensor { // This could include arbitrary code to find the index of min absolute value - "some.custom_code"() : () -> () - } : (tensor) -> tensor + %result = "some.custom_code"() : () -> tensor + kernel.yield %result : tensor + } // Index of minimum absolute value operation definition with linalg.generic representation - kernel.defn "iamin" (%X : tensor) { + kernel.defn @iamin_linalg(%X: tensor) -> tensor { // Create an initial tensor to store the result index %c0 = arith.constant 0 : i32 %init = tensor.empty() : tensor @@ -130,17 +139,19 @@ module { %new_idx = arith.select %cmp, %idx, %curr_min_idx : index %result = arith.index_cast %new_idx : index to i32 linalg.yield %result : i32 - } - } : (tensor) -> tensor + } -> tensor + kernel.yield %result : tensor + } // Sum of absolute values operation definition with arbitrary code - kernel.defn "asum" (%X : tensor) { + kernel.defn @asum(%X: tensor) -> tensor { // This could include arbitrary code to compute the sum of absolute values - "some.custom_code"() : () -> () - } : (tensor) -> tensor + %result = "some.custom_code"() : () -> tensor + kernel.yield %result : tensor + } // Sum of absolute values operation definition with linalg.generic representation - kernel.defn "asum" (%X : tensor) { + kernel.defn @asum_linalg(%X: tensor) -> tensor { // Create an initial tensor to store the result sum %c0 = arith.constant 0.0 : f32 %init = tensor.empty() : tensor @@ -159,80 +170,29 @@ module { %abs_val = math.absf %in : f32 %result = arith.addf %abs_val, %out : f32 linalg.yield %result : f32 - } - } : (tensor) -> tensor + } -> tensor + kernel.yield %result : tensor + } // Mathematical definitions (commented, for reference) - // kernel.defn "gemm" (...) { + // kernel.defn @gemm(...) { // C(i,j) += alpha * A(i,k) * B(k,j); // } - // kernel.defn "batched_gemm" (...) { + // kernel.defn @batched_gemm(...) { // C(b,i,j) += alpha * A(b,i,k) * B(b,k,j); // } - // kernel.defn "iamax" (...) { + // kernel.defn @iamax(...) { // result = argmax_i |x_i|; // } - // kernel.defn "iamin" (...) { + // kernel.defn @iamin(...) { // result = argmin_i |x_i|; // } - // kernel.defn "asum" (...) { + // kernel.defn @asum(...) 
{ // result = sum_i |x_i|; // } } - - // Main function showing usage of the operations - func.func @main() { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - // Allocate tensors for matrices - %A = tensor.empty() : tensor<2x128x64xf32> - %B = tensor.empty() : tensor<2x64x256xf32> - %C = tensor.empty() : tensor<2x128x256xf32> - - // Allocate a vector for vector operations - %X = tensor.empty() : tensor<128xf32> - - // Get slices of the batched tensors - %A0 = tensor.extract_slice %A[0, 0, 0][1, 128, 64][1, 1, 1] : tensor<2x128x64xf32> to tensor<128x64xf32> - %B0 = tensor.extract_slice %B[0, 0, 0][1, 64, 256][1, 1, 1] : tensor<2x64x256xf32> to tensor<64x256xf32> - %C0 = tensor.extract_slice %C[0, 0, 0][1, 128, 256][1, 1, 1] : tensor<2x128x256xf32> to tensor<128x256xf32> - - %A1 = tensor.extract_slice %A[1, 0, 0][1, 128, 64][1, 1, 1] : tensor<2x128x64xf32> to tensor<128x64xf32> - %B1 = tensor.extract_slice %B[1, 0, 0][1, 64, 256][1, 1, 1] : tensor<2x64x256xf32> to tensor<64x256xf32> - %C1 = tensor.extract_slice %C[1, 0, 0][1, 128, 256][1, 1, 1] : tensor<2x128x256xf32> to tensor<128x256xf32> - - // Perform individual GEMM operations on slices - // Using kernel.defn operation - kernel.defn(%A0, %B0, %C0) {kernel_name = "gemm"} : - (tensor<128x64xf32>, tensor<64x256xf32>, tensor<128x256xf32>) -> () - - kernel.defn(%A1, %B1, %C1) {kernel_name = "gemm"} : - (tensor<128x64xf32>, tensor<64x256xf32>, tensor<128x256xf32>) -> () - - // Perform batched GEMM operation - // Using kernel.defn operation - kernel.defn(%A, %B, %C) {kernel_name = "batched_gemm"} : - (tensor<2x128x64xf32>, tensor<2x64x256xf32>, tensor<2x128x256xf32>) -> () - - // Perform vector operations - - // Find index of maximum absolute value - %max_idx = kernel.defn(%X) {kernel_name = "iamax"} : - (tensor<128xf32>) -> tensor - - // Find index of minimum absolute value - %min_idx = kernel.defn(%X) {kernel_name = "iamin"} : - (tensor<128xf32>) -> tensor - - // Calculate sum of absolute values - %abs_sum = kernel.defn(%X) {kernel_name = "asum"} : - (tensor<128xf32>) -> tensor - - return - } } \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelOps.td b/include/polygeist/Kernel/KernelOps.td index 90ea3912f2c3..df118618b1a3 100644 --- a/include/polygeist/Kernel/KernelOps.td +++ b/include/polygeist/Kernel/KernelOps.td @@ -12,12 +12,15 @@ include "polygeist/Kernel/KernelDialect.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/OpAsmInterface.td" //===----------------------------------------------------------------------===// // Kernel operation definitions //===----------------------------------------------------------------------===// -def Kernel_DefnCollectionOp : Kernel_Op<"defn_collection", []> { +def Kernel_DefnCollectionOp : Kernel_Op<"defn_collection", [NoTerminator]> { let summary = "Collection of kernel operation definitions"; let description = [{ A collection of operation definitions that can be referenced elsewhere. 
@@ -32,7 +35,13 @@ def Kernel_DefnCollectionOp : Kernel_Op<"defn_collection", []> { }]; } -def Kernel_DefnOp : Kernel_Op<"defn", [SingleBlockImplicitTerminator<"YieldOp">]> { +def Kernel_DefnOp : Kernel_Op<"defn", [ + AffineScope, + AutomaticAllocationScope, + IsolatedFromAbove, + FunctionOpInterface, + Symbol +]> { let summary = "Definition of a kernel operation"; let description = [{ A definition of a kernel operation with inputs and arbitrary body code. @@ -44,41 +53,54 @@ def Kernel_DefnOp : Kernel_Op<"defn", [SingleBlockImplicitTerminator<"YieldOp">] Example: ```mlir - %result = kernel.defn { - ^bb0(%A: memref, %B: memref, - %C: memref, %alpha: f32): + kernel.defn @custom_gemm(%A: memref, %B: memref, + %C: memref, %alpha: f32) -> tensor { // Kernel implementation kernel.yield %some_result : tensor - } {name = "custom_gemm"} -> tensor + } ``` }]; - // TODO: can look into gpu call op (separte namespace) let arguments = (ins - StrAttr:$name, - TypeAttrOf:$function_type + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); - let results = (outs Variadic:$results); + let regions = (region AnyRegion:$body); - let regions = (region SizedRegion<1>:$body); + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; - let assemblyFormat = [{ - $body attr-dict `->` type($results) - }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; let extraClassDeclaration = [{ - TypeRange getFunctionResultTypes() { - auto fType = getFunctionType(); - return fType.getResults(); - } + /// Returns the argument types of this kernel. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this kernel. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the current operation that is callable. + ::mlir::Region *getCallableRegion() { return &getBody(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return getBody().empty(); } }]; } -def Kernel_YieldOp : Kernel_Op<"yield", [Pure, Terminator, - ParentOneOf<["DefnOp"]>]> { +def Kernel_YieldOp : Kernel_Op<"yield", [Pure, HasParent<"DefnOp">, + MemRefsNormalizable, ReturnLike, Terminator]> { let summary = "Terminator for kernel.defn operation"; let description = [{ The `kernel.yield` operation terminates regions within kernel operations. @@ -87,9 +109,7 @@ def Kernel_YieldOp : Kernel_Op<"yield", [Pure, Terminator, let arguments = (ins Variadic:$operands); - let assemblyFormat = [{ - ($operands^ `:` type($operands))? 
attr-dict - }]; + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; let builders = [ OpBuilder<(ins), [{ diff --git a/lib/polygeist/Kernel/KernelOps.cpp b/lib/polygeist/Kernel/KernelOps.cpp index 7ce6b18998e8..55c91f5804df 100644 --- a/lib/polygeist/Kernel/KernelOps.cpp +++ b/lib/polygeist/Kernel/KernelOps.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -38,36 +39,48 @@ LogicalResult DefnOp::verify() { return success(); } +ParseResult DefnOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = [](Builder &builder, ArrayRef argTypes, + ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { + return builder.getFunctionType(argTypes, results); + }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void DefnOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// LogicalResult YieldOp::verify() { - // Get the parent DefnOp - auto defnOp = getParentOp(); - if (!defnOp) - return emitOpError("must be nested within a kernel.defn operation"); - - // Get expected result types from the DefnOp's function type - auto functionType = defnOp.getFunctionType(); - auto expectedTypes = functionType.getResults(); - - // Check that the number of operands matches expected results - if (getOperands().size() != expectedTypes.size()) { - return emitOpError("number of yielded values (") - << getOperands().size() << ") does not match expected number of results (" - << expectedTypes.size() << ")"; - } - - // Check that operand types match expected types - for (auto [idx, operand, expectedType] : - llvm::enumerate(getOperands(), expectedTypes)) { - if (operand.getType() != expectedType) { - return emitOpError("yielded value ") << idx << " has type " - << operand.getType() << " but expected " << expectedType; - } - } - + auto defnOp = cast((*this)->getParentOp()); + + // The operand number and types must match the kernel signature. 
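+  // For illustration (hypothetical signature, a sketch rather than code from
+  // this patch): a definition declared as
+  //   kernel.defn @scale(%x: tensor<4xf32>, %s: f32) -> tensor<4xf32> { ... }
+  // must end with `kernel.yield %r : tensor<4xf32>`; yielding no value, extra
+  // values, or a value of a different type is rejected by the count and type
+  // checks below.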
+ const auto &results = defnOp.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing kernel (@" + << defnOp.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of yield operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match kernel result type (" + << results[i] << ")" + << " in kernel @" << defnOp.getName(); + return success(); } From 6a673796e6b3929633312bc3da79a3063d9a3ae5 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 11 Jun 2025 18:17:54 -0700 Subject: [PATCH 55/77] Added linalgToKernel pass- compile failure --- include/polygeist/Passes/Passes.h | 11 ++ include/polygeist/Passes/Passes.td | 41 +++++ lib/polygeist/Passes/CMakeLists.txt | 4 + lib/polygeist/Passes/LinalgToKernel.cpp | 199 ++++++++++++++++++++++++ 4 files changed, 255 insertions(+) create mode 100644 lib/polygeist/Passes/LinalgToKernel.cpp diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index 7a95484a2fdb..829e2d75fbbd 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -73,6 +73,8 @@ createGpuSerializeToHsacoPass(StringRef arch, StringRef features, int llvmOptLevel, int hsaOptLevel, std::string rocmPath, bool outputIntermediate); +std::unique_ptr createLinalgToKernelPass(); + void registerGpuSerializeToCubinPass(); void registerGpuSerializeToHsacoPass(); @@ -98,6 +100,11 @@ namespace omp { class OpenMPDialect; } // end namespace omp +namespace polygeist { +namespace kernel { +class KernelDialect; +} // end namespace kernel +} namespace polygeist { class PolygeistDialect; } // end namespace polygeist @@ -130,6 +137,10 @@ namespace linalg { class LinalgDialect; } +namespace tensor { +class TensorDialect; +} + namespace bufferization { class BufferizationDialect; } diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 5b8251c616b8..b994c0d20506 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -255,6 +255,47 @@ def RemoveTrivialUse : Pass<"trivialuse"> { let constructor = "mlir::polygeist::createRemoveTrivialUsePass()"; } +def LinalgToKernel : Pass<"linalg-to-kernel", "mlir::ModuleOp"> { + let summary = "Convert linalg.generic operations to kernel operations by matching with kernel.defn patterns"; + let description = [{ + This pass matches linalg.generic operations against patterns defined in + kernel.defn_collection operations and converts them to the corresponding + specialized kernel operations (e.g., kernel.gemm, kernel.batched_gemm). 
+ + The pass performs semantic matching of linalg.generic operations by: + - Comparing indexing maps and iterator types + - Matching the operation structure within regions + - Checking input/output operand counts + + Example transformation: + ```mlir + // Input: linalg.generic performing matrix multiplication + linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %mul = arith.mulf %a, %b : f32 + %add = arith.addf %mul, %c : f32 + linalg.yield %add : f32 + } -> tensor + + // Output: Specialized kernel operation + %result = kernel.gemm %C, %A, %B, %alpha, %beta : tensor + ``` + }]; + let constructor = "mlir::polygeist::createLinalgToKernelPass()"; + let dependentDialects = [ + "linalg::LinalgDialect", + "polygeist::kernel::KernelDialect", + "tensor::TensorDialect", + "arith::ArithDialect", + ]; +} + def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { let summary = "Convert scalar and vector operations from the Standard to the " "LLVM dialect"; diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index ae74300af7a1..07a559ae00e8 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms RemoveIterArgs.cpp RaiseToLinalg.cpp LinalgDebufferize.cpp + LinalgToKernel.cpp ParallelLower.cpp TrivialUse.cpp ConvertPolygeistToLLVM.cpp @@ -45,15 +46,18 @@ add_mlir_dialect_library(MLIRPolygeistTransforms MLIRGPUToNVVMTransforms MLIRIR MLIRLLVMDialect + MLIRLinalgDialect MLIRMathDialect MLIRMathToLLVM MLIRMemRefDialect MLIRNVVMDialect MLIRPass MLIRPolygeist + MLIRPolygeistKernel MLIRSideEffectInterfaces MLIRSCFToControlFlow MLIRTargetLLVMIRImport + MLIRTensorDialect MLIRTransformUtils MLIRGPUToROCDLTransforms MLIRControlFlowToLLVM diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp new file mode 100644 index 000000000000..24dab3e6d53a --- /dev/null +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -0,0 +1,199 @@ +//===- LinalgToKernel.cpp - Pattern to match linalg.generic with kernel.defn ------===// +// +// This file implements a pattern to rewrite linalg.generic operations to kernel +// operations by matching against patterns defined in kernel.defn_collection. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" +#include "polygeist/Kernel/KernelDialect.h" +#include "polygeist/Kernel/KernelOps.h" +#include "polygeist/Passes/Passes.h" + +#include +#include + +using namespace mlir; +using namespace mlir::linalg; +using namespace mlir::polygeist; +using namespace mlir::polygeist::kernel; + +namespace { + +// Helper function to check if two regions are structurally equivalent +bool areRegionsEquivalent(Region &first, Region &second) { + // Compare number of blocks + if (first.getBlocks().size() != second.getBlocks().size()) + return false; + + // Compare corresponding blocks + for (auto blockPair : llvm::zip(first.getBlocks(), second.getBlocks())) { + Block &firstBlock = std::get<0>(blockPair); + Block &secondBlock = std::get<1>(blockPair); + + // Compare number of arguments + if (firstBlock.getNumArguments() != secondBlock.getNumArguments()) + return false; + + // Compare argument types + for (auto argPair : llvm::zip(firstBlock.getArguments(), + secondBlock.getArguments())) { + if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) + return false; + } + + // Compare operations (simplified - real implementation would be more complex) + if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) + return false; + + // For a full implementation, you'd need more sophisticated operation comparison + // based on operands, attributes, and result types + } + + return true; +} + +// Helper to check if indexing maps are equivalent +bool areIndexingMapsEquivalent(ArrayAttr firstMaps, ArrayAttr secondMaps) { + if (firstMaps.size() != secondMaps.size()) + return false; + + for (auto mapPair : llvm::zip(firstMaps, secondMaps)) { + auto firstMap = std::get<0>(mapPair).cast().getValue(); + auto secondMap = std::get<1>(mapPair).cast().getValue(); + + if (firstMap != secondMap) + return false; + } + + return true; +} + +// Helper to check if iterator types are equivalent +bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) { + if (firstTypes.size() != secondTypes.size()) + return false; + + for (auto typePair : llvm::zip(firstTypes, secondTypes)) { + auto firstType = std::get<0>(typePair).cast().getValue(); + auto secondType = std::get<1>(typePair).cast().getValue(); + + if (firstType != secondType) + return false; + } + + return true; +} + +// Check if a linalg.generic operation matches a kernel.defn in a collection +FailureOr matchGenericWithDefn( + GenericOp genericOp, + kernel::DefnCollectionOp collectionOp) { + + // Get attributes from the generic operation + ArrayAttr indexingMaps = genericOp.getIndexingMapsAttr(); + ArrayAttr iteratorTypes = genericOp.getIteratorTypesAttr(); + unsigned numInputs = genericOp.getNumDpsInputs(); + unsigned numOutputs = genericOp.getNumDpsInits(); + + // Walk through each defn in the collection + for (Operation &op : collectionOp.getDefns()) { + auto defnOp = cast(op); + StringRef opName = defnOp.getSymName(); + + // Check for linalg.generic in the defn's body + bool foundMatch = false; + defnOp.getBody().walk([&](GenericOp candidateOp) { + // Skip if already found a match + if (foundMatch) + return; + + // Check if this linalg.generic matches our target + if (candidateOp.getNumDpsInputs() == 
numInputs && + candidateOp.getNumDpsInits() == numOutputs && + areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) && + areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) && + areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) { + foundMatch = true; + } + }); + + if (foundMatch) + return opName; + } + + return failure(); +} + +// Rewrite pattern to convert linalg.generic to kernel ops +class LinalgGenericToKernelPattern : public OpRewritePattern { +public: + LinalgGenericToKernelPattern(MLIRContext *context, + kernel::DefnCollectionOp collectionOp) + : OpRewritePattern(context), collectionOp(collectionOp) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + // Try to match with a defn in the collection + auto matchResult = matchGenericWithDefn(genericOp, collectionOp); + if (failed(matchResult)) + return failure(); + + StringRef opName = *matchResult; + + // For now, just emit a diagnostic indicating we found a match + // In the future, this would create the appropriate kernel operation + genericOp.emitRemark() << "Matched linalg.generic with kernel pattern: " << opName; + + // TODO: Create the appropriate kernel operation based on the matched pattern + // This would require implementing kernel operations in the kernel dialect + + return success(); + } + +private: + kernel::DefnCollectionOp collectionOp; +}; + +// Pass to apply the rewrite pattern +struct LinalgToKernelPass : public LinalgToKernelBase { + void runOnOperation() override { + ModuleOp module = getOperation(); + + // Find the kernel.defn_collection in the module + kernel::DefnCollectionOp collectionOp; + module.walk([&](kernel::DefnCollectionOp op) { + collectionOp = op; + return WalkResult::interrupt(); + }); + + if (!collectionOp) { + module.emitError("No kernel.defn_collection found in module"); + return signalPassFailure(); + } + + // Apply the rewrite pattern + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), collectionOp); + + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir::polygeist { + +// Create a pass to convert linalg.generic to kernel +std::unique_ptr createLinalgToKernelPass() { + return std::make_unique(); +} + +} // namespace mlir::polygeist \ No newline at end of file From 7f9d00fedb2c95cdb35c725713e44c627f26ea07 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Wed, 11 Jun 2025 22:46:10 -0700 Subject: [PATCH 56/77] Working pattern matching and replacement for linalg generics --- generic_solver/cublas_example.mlir | 41 ++++++++ include/polygeist/Kernel/KernelOps.td | 79 ++++++++++++++- lib/polygeist/Kernel/KernelOps.cpp | 58 +++++++++++ lib/polygeist/Passes/LinalgToKernel.cpp | 128 +++++++++++++++++++----- 4 files changed, 278 insertions(+), 28 deletions(-) diff --git a/generic_solver/cublas_example.mlir b/generic_solver/cublas_example.mlir index 435819e481bc..8c6ef4b52e20 100644 --- a/generic_solver/cublas_example.mlir +++ b/generic_solver/cublas_example.mlir @@ -2,6 +2,27 @@ module { // Define a collection of kernel operation definitions kernel.defn_collection { + + // GEMM operation definition with linalg.generic representation + kernel.defn @simple_gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i, j, k) -> (i, k)>, // A(i,k) + affine_map<(i, j, k) -> 
(k, j)>, // B(k,j) + affine_map<(i, j, k) -> (i, j)> // C(i,j) + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + // GEMM operation definition with arbitrary code implementation kernel.defn @gemm(%A: tensor, %B: tensor, %C: tensor) { // This could include arbitrary code to implement the GEMM operation @@ -173,6 +194,26 @@ module { } -> tensor kernel.yield %result : tensor } + + //Func that uses simple gemm + func.func @simple_gemm(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i, j, k) -> (i, k)>, // A(i,k) + affine_map<(i, j, k) -> (k, j)>, // B(k,j) + affine_map<(i, j, k) -> (i, j)> // C(i,j) + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } // Mathematical definitions (commented, for reference) // kernel.defn @gemm(...) { diff --git a/include/polygeist/Kernel/KernelOps.td b/include/polygeist/Kernel/KernelOps.td index df118618b1a3..aa5c758cf179 100644 --- a/include/polygeist/Kernel/KernelOps.td +++ b/include/polygeist/Kernel/KernelOps.td @@ -68,7 +68,7 @@ def Kernel_DefnOp : Kernel_Op<"defn", [ OptionalAttr:$arg_attrs, OptionalAttr:$res_attrs ); - + let regions = (region AnyRegion:$body); let builders = [OpBuilder<(ins @@ -99,6 +99,83 @@ def Kernel_DefnOp : Kernel_Op<"defn", [ }]; } +//===----------------------------------------------------------------------===// +// LaunchOp +//===----------------------------------------------------------------------===// + +def Kernel_LaunchOp : Kernel_Op<"launch", + [CallOpInterface, MemRefsNormalizable, + DeclareOpInterfaceMethods]> { + let summary = "kernel launch operation"; + let description = [{ + The `kernel.launch` operation represents a launch of a kernel that is + within the same symbol scope as the launch. The operands and result types of + the launch must match the specified kernel type. The kernel is encoded as a + symbol reference attribute named "kernel". 
+ + Example: + + ```mlir + %result = kernel.launch @custom_gemm(%A, %B, %C, %alpha) : (memref, memref, memref, f32) -> tensor + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$kernel, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "DefnOp":$kernel, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("kernel", SymbolRefAttr::get(kernel)); + $_state.addTypes(kernel.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$kernel, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("kernel", kernel); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$kernel, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(kernel), results, operands); + }]>, + OpBuilder<(ins "StringRef":$kernel, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), kernel), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getKernelType(); + + /// Get the argument operands to the launched kernel. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the kernel of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("kernel"); + } + + /// Set the kernel for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("kernel", callee.get()); + } + }]; + + let assemblyFormat = [{ + $kernel `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + def Kernel_YieldOp : Kernel_Op<"yield", [Pure, HasParent<"DefnOp">, MemRefsNormalizable, ReturnLike, Terminator]> { let summary = "Terminator for kernel.defn operation"; diff --git a/lib/polygeist/Kernel/KernelOps.cpp b/lib/polygeist/Kernel/KernelOps.cpp index 55c91f5804df..8ad84f79e6ea 100644 --- a/lib/polygeist/Kernel/KernelOps.cpp +++ b/lib/polygeist/Kernel/KernelOps.cpp @@ -84,6 +84,64 @@ LogicalResult YieldOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// LaunchOp +//===----------------------------------------------------------------------===// + +FunctionType LaunchOp::getKernelType() { + // Get the kernel symbol reference + auto kernelAttr = (*this)->getAttrOfType("kernel"); + if (!kernelAttr) + return nullptr; + + // Look up the kernel DefnOp in the symbol table + auto *symbolTableOp = (*this)->getParentWithTrait(); + if (!symbolTableOp) + return nullptr; + + auto kernelOp = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(symbolTableOp, kernelAttr)); + if (!kernelOp) + return nullptr; + + return kernelOp.getFunctionType(); +} + +LogicalResult LaunchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the kernel attribute was specified. + auto kernelAttr = (*this)->getAttrOfType("kernel"); + if (!kernelAttr) + return emitOpError("requires a 'kernel' symbol reference attribute"); + + // Check that the kernel symbol exists and is a DefnOp. 
+ auto kernelOp = symbolTable.lookupNearestSymbolFrom(*this, kernelAttr); + if (!kernelOp) + return emitOpError() << "'" << kernelAttr.getValue() + << "' does not reference a valid kernel"; + + // Verify that the operand and result types match the kernel signature. + auto kernelType = kernelOp.getFunctionType(); + if (kernelType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for kernel"); + + for (unsigned i = 0, e = kernelType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != kernelType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << kernelType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (kernelType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for kernel"); + + for (unsigned i = 0, e = kernelType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != kernelType.getResult(i)) + return emitOpError("result type mismatch: expected result type ") + << kernelType.getResult(i) << ", but provided " + << getResult(i).getType() << " for result number " << i; + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op definitions //===----------------------------------------------------------------------===// diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp index 24dab3e6d53a..2f8399a48759 100644 --- a/lib/polygeist/Passes/LinalgToKernel.cpp +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -81,8 +81,8 @@ bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) { return false; for (auto typePair : llvm::zip(firstTypes, secondTypes)) { - auto firstType = std::get<0>(typePair).cast().getValue(); - auto secondType = std::get<1>(typePair).cast().getValue(); + auto firstType = std::get<0>(typePair).cast().getValue(); + auto secondType = std::get<1>(typePair).cast().getValue(); if (firstType != secondType) return false; @@ -102,32 +102,43 @@ FailureOr matchGenericWithDefn( unsigned numInputs = genericOp.getNumDpsInputs(); unsigned numOutputs = genericOp.getNumDpsInits(); + // Variables to capture the match result + StringRef matchedOpName; + + SmallVector defnOps; + + collectionOp.walk([&](kernel::DefnOp defnOp) { + defnOps.push_back(defnOp); + }); + + bool foundMatch = false; + // Walk through each defn in the collection - for (Operation &op : collectionOp.getDefns()) { - auto defnOp = cast(op); - StringRef opName = defnOp.getSymName(); + for (auto defnOp : defnOps) { + StringRef opName = defnOp.getSymName(); // Check for linalg.generic in the defn's body - bool foundMatch = false; - defnOp.getBody().walk([&](GenericOp candidateOp) { - // Skip if already found a match - if (foundMatch) - return; - - // Check if this linalg.generic matches our target - if (candidateOp.getNumDpsInputs() == numInputs && - candidateOp.getNumDpsInits() == numOutputs && - areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) && - areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) && - areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) { - foundMatch = true; - } + GenericOp candidateOp; + + defnOp.walk([&](GenericOp genericOp) { + candidateOp = genericOp; //TODO: Add checks to make sure there is only single linalg.generic in the defn }); - if (foundMatch) - return opName; + // Check if this linalg.generic matches our target + if 
(candidateOp.getNumDpsInputs() == numInputs && + candidateOp.getNumDpsInits() == numOutputs && + areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) && + areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) && + areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) { + foundMatch = true; + matchedOpName = opName; + } + + if (foundMatch) { + return matchedOpName; + } } - + return failure(); } @@ -140,6 +151,15 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { + + auto module = genericOp->getParentOfType(); + //Check if the parent of the generic op is a kernel.defn + if (auto parentOp = genericOp->getParentOp()) { + if (isa(parentOp)) { + return failure(); + } + } + // Try to match with a defn in the collection auto matchResult = matchGenericWithDefn(genericOp, collectionOp); if (failed(matchResult)) @@ -147,12 +167,66 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { StringRef opName = *matchResult; - // For now, just emit a diagnostic indicating we found a match - // In the future, this would create the appropriate kernel operation - genericOp.emitRemark() << "Matched linalg.generic with kernel pattern: " << opName; + // Find the matched kernel.defn operation + kernel::DefnOp matchedDefnOp; + // Use const_cast to work around the const issue + const_cast(collectionOp).walk([&](kernel::DefnOp defnOp) { + if (defnOp.getSymName() == opName) { + matchedDefnOp = defnOp; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (!matchedDefnOp) { + return failure(); + } + + // Check if the kernel.defn already exists in the target module + kernel::DefnOp existingDefn; + module.walk([&](kernel::DefnOp defnOp) { + if (defnOp.getSymName() == opName) { + // Check if this defn is inside a defn_collection (template) or at module level (callable) + if (!defnOp->getParentOfType()) { + existingDefn = defnOp; + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + + // If the kernel.defn doesn't exist in the module, copy it + if (!existingDefn) { + // Clone the matched kernel.defn operation + rewriter.setInsertionPointToStart(module.getBody()); + auto clonedDefn = rewriter.clone(*matchedDefnOp.getOperation()); + (void)clonedDefn; // Suppress unused variable warning + } + + // Create kernel.launch operation to replace the genericOp + Location loc = genericOp.getLoc(); + + // Set insertion point to the genericOp location + rewriter.setInsertionPoint(genericOp); + + // Get operands from the generic operation (inputs and outputs) + SmallVector operands; + operands.append(genericOp.getInputs().begin(), genericOp.getInputs().end()); + operands.append(genericOp.getOutputs().begin(), genericOp.getOutputs().end()); + + // Get result types from the generic operation + TypeRange resultTypes = genericOp.getResultTypes(); + + // Create the kernel.launch operation + auto launchOp = rewriter.create( + loc, + resultTypes, + opName, + operands + ); - // TODO: Create the appropriate kernel operation based on the matched pattern - // This would require implementing kernel operations in the kernel dialect + // Replace the generic operation with the launch operation + rewriter.replaceOp(genericOp, launchOp.getResults()); return success(); } From d765bb90332eb92060fa103a5a0ef2fff53d750e Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 12 Jun 2025 16:07:16 -0700 Subject: [PATCH 57/77] 
Partial changes for different files for kernel and input --- generic_solver/cublas_example.mlir | 82 +++++++++--------- generic_solver/kernel_library_simple.mlir | 101 ++++++++++++++++++++++ generic_solver/test_input_simple.mlir | 71 +++++++++++++++ include/polygeist/Passes/Passes.h | 1 + include/polygeist/Passes/Passes.td | 6 ++ lib/polygeist/Passes/LinalgToKernel.cpp | 99 +++++++++++++++++++-- 6 files changed, 312 insertions(+), 48 deletions(-) create mode 100644 generic_solver/kernel_library_simple.mlir create mode 100644 generic_solver/test_input_simple.mlir diff --git a/generic_solver/cublas_example.mlir b/generic_solver/cublas_example.mlir index 8c6ef4b52e20..84a77dab9544 100644 --- a/generic_solver/cublas_example.mlir +++ b/generic_solver/cublas_example.mlir @@ -3,26 +3,6 @@ module { // Define a collection of kernel operation definitions kernel.defn_collection { - // GEMM operation definition with linalg.generic representation - kernel.defn @simple_gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { - // Implementation using linalg.generic - %result = linalg.generic { - indexing_maps = [ - affine_map<(i, j, k) -> (i, k)>, // A(i,k) - affine_map<(i, j, k) -> (k, j)>, // B(k,j) - affine_map<(i, j, k) -> (i, j)> // C(i,j) - ], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B : tensor, tensor) - outs(%C : tensor) { - ^bb0(%a: f32, %b: f32, %c: f32): - %product = arith.mulf %a, %b : f32 - %result = arith.addf %product, %c : f32 - linalg.yield %result : f32 - } -> tensor - kernel.yield %result : tensor - } - // GEMM operation definition with arbitrary code implementation kernel.defn @gemm(%A: tensor, %B: tensor, %C: tensor) { // This could include arbitrary code to implement the GEMM operation @@ -89,6 +69,27 @@ module { } -> tensor kernel.yield } + + // GEMM operation definition with linalg.generic representation + kernel.defn @simple_gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i, j, k) -> (i, k)>, // A(i,k) + affine_map<(i, j, k) -> (k, j)>, // B(k,j) + affine_map<(i, j, k) -> (i, j)> // C(i,j) + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + // Index of maximum absolute value operation definition with arbitrary code kernel.defn @iamax(%X: tensor) -> tensor { @@ -195,26 +196,6 @@ module { kernel.yield %result : tensor } - //Func that uses simple gemm - func.func @simple_gemm(%A: tensor, %B: tensor, %C: tensor) -> tensor { - // Implementation using linalg.generic - %result = linalg.generic { - indexing_maps = [ - affine_map<(i, j, k) -> (i, k)>, // A(i,k) - affine_map<(i, j, k) -> (k, j)>, // B(k,j) - affine_map<(i, j, k) -> (i, j)> // C(i,j) - ], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B : tensor, tensor) - outs(%C : tensor) { - ^bb0(%a: f32, %b: f32, %c: f32): - %product = arith.mulf %a, %b : f32 - %result = arith.addf %product, %c : f32 - linalg.yield %result : f32 - } -> tensor - return %result : tensor - } - // Mathematical definitions (commented, for reference) // kernel.defn @gemm(...) 
{ // C(i,j) += alpha * A(i,k) * B(k,j); @@ -236,4 +217,25 @@ module { // result = sum_i |x_i|; // } } + + //Func that uses simple gemm + func.func @simple_gemm(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i, j, k) -> (i, k)>, // A(i,k) + affine_map<(i, j, k) -> (k, j)>, // B(k,j) + affine_map<(i, j, k) -> (i, j)> // C(i,j) + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } + } \ No newline at end of file diff --git a/generic_solver/kernel_library_simple.mlir b/generic_solver/kernel_library_simple.mlir new file mode 100644 index 000000000000..7b31faa86aa6 --- /dev/null +++ b/generic_solver/kernel_library_simple.mlir @@ -0,0 +1,101 @@ +// Kernel Library - Reusable kernel definitions +// This file contains a collection of kernel definitions that can be loaded +// by the linalg-to-kernel pass and applied to different MLIR modules. + +module { + // Collection of kernel operation definitions + kernel.defn_collection { + + // Simple GEMM operation definition with linalg.generic representation + kernel.defn @simple_gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Simple matrix multiplication: C = A * B + C + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // Scaled GEMM operation definition with alpha and beta coefficients + kernel.defn @gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { + %alpha = arith.constant 1.0 : f32 + %beta = arith.constant 0.0 : f32 + + // GEMM with scaling: C = alpha * A * B + beta * C + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %scaled = arith.mulf %product, %alpha : f32 + %scaled_c = arith.mulf %c, %beta : f32 + %result = arith.addf %scaled, %scaled_c : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // Sum of absolute values operation (ASUM) + kernel.defn @asum_linalg(%X: tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // Sum of absolute values: result = sum_i |x_i| + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%fill : tensor) { + ^bb0(%in: f32, %out: f32): + %abs_val = math.absf %in : f32 + %result = arith.addf %abs_val, %out : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // Vector dot product + kernel.defn @dot_linalg(%X: tensor, %Y: tensor) -> 
tensor { + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // Dot product: result = sum_i x_i * y_i + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X, %Y : tensor, tensor) + outs(%fill : tensor) { + ^bb0(%x: f32, %y: f32, %out: f32): + %product = arith.mulf %x, %y : f32 + %result = arith.addf %product, %out : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + } +} \ No newline at end of file diff --git a/generic_solver/test_input_simple.mlir b/generic_solver/test_input_simple.mlir new file mode 100644 index 000000000000..8fa0e6df4edf --- /dev/null +++ b/generic_solver/test_input_simple.mlir @@ -0,0 +1,71 @@ +// Test input file - contains linalg.generic operations to be matched +// This file does NOT contain kernel.defn_collection - those will be loaded externally + +module { + // Function that performs simple matrix multiplication + func.func @simple_gemm(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // This linalg.generic should match @simple_gemm_linalg from kernel_library.mlir + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } + + // Function that computes sum of absolute values + func.func @compute_asum(%X: tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // This linalg.generic should match @asum_linalg from kernel_library.mlir + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%fill : tensor) { + ^bb0(%in: f32, %out: f32): + %abs_val = math.absf %in : f32 + %result = arith.addf %abs_val, %out : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } + + // Function that computes dot product + func.func @compute_dot(%X: tensor, %Y: tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // This linalg.generic should match @dot_linalg from kernel_library.mlir + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X, %Y : tensor, tensor) + outs(%fill : tensor) { + ^bb0(%x: f32, %y: f32, %out: f32): + %product = arith.mulf %x, %y : f32 + %result = arith.addf %product, %out : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } +} \ No newline at end of file diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index 829e2d75fbbd..c1cea4c2ec72 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -74,6 +74,7 @@ createGpuSerializeToHsacoPass(StringRef arch, StringRef features, std::string rocmPath, bool outputIntermediate); std::unique_ptr createLinalgToKernelPass(); 
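+// Variant that loads kernel.defn_collection definitions from an external
+// kernel library file instead of the input module.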
+std::unique_ptr createLinalgToKernelPass(const std::string& kernelLibraryPath); void registerGpuSerializeToCubinPass(); void registerGpuSerializeToHsacoPass(); diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index b994c0d20506..4945396d6178 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -294,6 +294,12 @@ def LinalgToKernel : Pass<"linalg-to-kernel", "mlir::ModuleOp"> { "tensor::TensorDialect", "arith::ArithDialect", ]; + let options = [ + Option<"kernelLibraryPath", "kernel-library-path", "std::string", + /*default=*/"\"\"", + "Path to external MLIR file containing kernel.defn_collection definitions. " + "If empty, looks for kernel.defn_collection in the input module."> + ]; } def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp index 2f8399a48759..4b91330801c1 100644 --- a/lib/polygeist/Passes/LinalgToKernel.cpp +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -11,7 +11,11 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Support/FileUtilities.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" #include "polygeist/Kernel/KernelDialect.h" #include "polygeist/Kernel/KernelOps.h" #include "polygeist/Passes/Passes.h" @@ -123,6 +127,10 @@ FailureOr matchGenericWithDefn( defnOp.walk([&](GenericOp genericOp) { candidateOp = genericOp; //TODO: Add checks to make sure there is only single linalg.generic in the defn }); + + if(!candidateOp) { + continue; + } // Check if this linalg.generic matches our target if (candidateOp.getNumDpsInputs() == numInputs && @@ -237,19 +245,86 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { // Pass to apply the rewrite pattern struct LinalgToKernelPass : public LinalgToKernelBase { + using LinalgToKernelBase::LinalgToKernelBase; + + // Constructor that allows setting the kernel library path + LinalgToKernelPass() = default; + LinalgToKernelPass(const std::string& libraryPath) : externalLibraryPath(libraryPath) {} + void runOnOperation() override { ModuleOp module = getOperation(); - // Find the kernel.defn_collection in the module kernel::DefnCollectionOp collectionOp; - module.walk([&](kernel::DefnCollectionOp op) { - collectionOp = op; - return WalkResult::interrupt(); - }); - if (!collectionOp) { - module.emitError("No kernel.defn_collection found in module"); - return signalPassFailure(); + // Determine which path to use for kernel library + std::string effectiveLibraryPath = externalLibraryPath; + // If no external path was provided via constructor, try the command line option + if (effectiveLibraryPath.empty()) { + effectiveLibraryPath = std::string(kernelLibraryPath); + } + + // Debug output + llvm::errs() << "DEBUG: externalLibraryPath = '" << externalLibraryPath << "'\n"; + llvm::errs() << "DEBUG: kernelLibraryPath = '" << std::string(kernelLibraryPath) << "'\n"; + llvm::errs() << "DEBUG: effectiveLibraryPath = '" << effectiveLibraryPath << "'\n"; + + // Check if we should load kernel definitions from an external file + if (!effectiveLibraryPath.empty()) { + //llvm::errs() << "DEBUG: Loading kernel definitions from external file: " << effectiveLibraryPath << "\n"; + // Load kernel definitions from external file + std::string errorMessage; + 
auto memoryBuffer = mlir::openInputFile(effectiveLibraryPath, &errorMessage); + if (!memoryBuffer) { + module.emitError("Failed to open kernel library file: ") << effectiveLibraryPath + << " - " << errorMessage; + return signalPassFailure(); + } + + // Parse the external file + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc()); + + auto externalModule = mlir::parseSourceFile(sourceMgr, &getContext()); + if (!externalModule) { + module.emitError("Failed to parse kernel library file: ") << effectiveLibraryPath; + return signalPassFailure(); + } + + // Debug: Print the loaded external module + //llvm::errs() << "DEBUG: Successfully loaded external module:\n"; + //externalModule->print(llvm::errs()); + //llvm::errs() << "\n"; + + // Find the kernel.defn_collection in the external module + externalModule->walk([&](kernel::DefnCollectionOp op) { + collectionOp = op; + llvm::errs() << "DEBUG: Found kernel.defn_collection in external module\n"; + return WalkResult::interrupt(); + }); + + if (!collectionOp) { + module.emitError("No kernel.defn_collection found in external kernel library: ") + << effectiveLibraryPath; + return signalPassFailure(); + } + + // Debug: Print the found collection + //llvm::errs() << "DEBUG: kernel.defn_collection contents:\n"; + //collectionOp.print(llvm::errs()); + //llvm::errs() << "\n"; + } else { + // Find the kernel.defn_collection in the current module (original behavior) + module.walk([&](kernel::DefnCollectionOp op) { + collectionOp = op; + return WalkResult::interrupt(); + }); + + if (!collectionOp) { + module.emitError("No kernel.defn_collection found in module. " + "Either include one in the input module or specify " + "--kernel-library-path to load from external file."); + return signalPassFailure(); + } } // Apply the rewrite pattern @@ -259,6 +334,9 @@ struct LinalgToKernelPass : public LinalgToKernelBase { if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) return signalPassFailure(); } + +private: + std::string externalLibraryPath; }; } // namespace @@ -270,4 +348,9 @@ std::unique_ptr createLinalgToKernelPass() { return std::make_unique(); } +// Create a pass to convert linalg.generic to kernel with kernel library path +std::unique_ptr createLinalgToKernelPass(const std::string& kernelLibraryPath) { + return std::make_unique(kernelLibraryPath); +} + } // namespace mlir::polygeist \ No newline at end of file From 15ef84eb3a5987b8223b97e5471bd1e93888cfab Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 12 Jun 2025 17:09:11 -0700 Subject: [PATCH 58/77] Crash fix --- lib/polygeist/Passes/LinalgToKernel.cpp | 27 +++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp index 4b91330801c1..420c985df71b 100644 --- a/lib/polygeist/Passes/LinalgToKernel.cpp +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -111,6 +111,10 @@ FailureOr matchGenericWithDefn( SmallVector defnOps; + //llvm::errs() << "DEBUG: kernel.defn_collection contents:\n"; + //llvm::errs() << collectionOp; + //llvm::errs() << collectionOp.getOperation(); + //llvm::errs() << "\n"; collectionOp.walk([&](kernel::DefnOp defnOp) { defnOps.push_back(defnOp); }); @@ -254,8 +258,8 @@ struct LinalgToKernelPass : public LinalgToKernelBase { void runOnOperation() override { ModuleOp module = getOperation(); - kernel::DefnCollectionOp collectionOp; - + kernel::DefnCollectionOp collectionOp = nullptr; + OwningOpRef 
externalModule; // Determine which path to use for kernel library std::string effectiveLibraryPath = externalLibraryPath; // If no external path was provided via constructor, try the command line option @@ -263,10 +267,10 @@ struct LinalgToKernelPass : public LinalgToKernelBase { effectiveLibraryPath = std::string(kernelLibraryPath); } - // Debug output - llvm::errs() << "DEBUG: externalLibraryPath = '" << externalLibraryPath << "'\n"; - llvm::errs() << "DEBUG: kernelLibraryPath = '" << std::string(kernelLibraryPath) << "'\n"; - llvm::errs() << "DEBUG: effectiveLibraryPath = '" << effectiveLibraryPath << "'\n"; + //// Debug output + //llvm::errs() << "DEBUG: externalLibraryPath = '" << externalLibraryPath << "'\n"; + //llvm::errs() << "DEBUG: kernelLibraryPath = '" << std::string(kernelLibraryPath) << "'\n"; + //llvm::errs() << "DEBUG: effectiveLibraryPath = '" << effectiveLibraryPath << "'\n"; // Check if we should load kernel definitions from an external file if (!effectiveLibraryPath.empty()) { @@ -284,7 +288,7 @@ struct LinalgToKernelPass : public LinalgToKernelBase { llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc()); - auto externalModule = mlir::parseSourceFile(sourceMgr, &getContext()); + externalModule = mlir::parseSourceFile(sourceMgr, &getContext()); if (!externalModule) { module.emitError("Failed to parse kernel library file: ") << effectiveLibraryPath; return signalPassFailure(); @@ -310,7 +314,8 @@ struct LinalgToKernelPass : public LinalgToKernelBase { // Debug: Print the found collection //llvm::errs() << "DEBUG: kernel.defn_collection contents:\n"; - //collectionOp.print(llvm::errs()); + //llvm::errs() << collectionOp; + //llvm::errs() << collectionOp.getOperation(); //llvm::errs() << "\n"; } else { // Find the kernel.defn_collection in the current module (original behavior) @@ -330,6 +335,12 @@ struct LinalgToKernelPass : public LinalgToKernelBase { // Apply the rewrite pattern RewritePatternSet patterns(&getContext()); patterns.add(&getContext(), collectionOp); + + //llvm::errs() << "DEBUG: kernel.defn_collection contents:\n"; + //llvm::errs() << collectionOp.getOperation(); + //llvm::errs() << "\n"; + //llvm::errs() << collectionOp; + //llvm::errs() << "\n"; if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) return signalPassFailure(); From 44fed6c461c87a76af14f2d7a657cc4ec44d8ce6 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 26 Jun 2025 08:28:47 -0700 Subject: [PATCH 59/77] Improved lib --- generic_solver/example.mlir | 24 +++ ...ublas_example.mlir => kernel_library.mlir} | 184 ++++++++---------- generic_solver/kernel_library_simple.mlir | 71 +++++-- 3 files changed, 162 insertions(+), 117 deletions(-) create mode 100644 generic_solver/example.mlir rename generic_solver/{cublas_example.mlir => kernel_library.mlir} (76%) diff --git a/generic_solver/example.mlir b/generic_solver/example.mlir new file mode 100644 index 000000000000..68ae2c73a3be --- /dev/null +++ b/generic_solver/example.mlir @@ -0,0 +1,24 @@ +//RUN: polygeist-opt --linalg-to-kernel="kernel-library-path=/home/arjaiswal/Polygeist/generic_solver/kernel_library_simple.mlir" -allow-unregistered-dialect generic_solver/example.mlir +// Example MLIR module demonstrating kernel operations and their linalg.generic representations +module { + //Func that uses simple gemm + func.func @simple_gemm(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + 
affine_map<(i, j, k) -> (i, k)>, // A(i,k) + affine_map<(i, j, k) -> (k, j)>, // B(k,j) + affine_map<(i, j, k) -> (i, j)> // C(i,j) + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } + +} \ No newline at end of file diff --git a/generic_solver/cublas_example.mlir b/generic_solver/kernel_library.mlir similarity index 76% rename from generic_solver/cublas_example.mlir rename to generic_solver/kernel_library.mlir index 84a77dab9544..d8e70186f618 100644 --- a/generic_solver/cublas_example.mlir +++ b/generic_solver/kernel_library.mlir @@ -1,29 +1,42 @@ -// Example MLIR module demonstrating kernel operations and their linalg.generic representations +// Kernel Library - Reusable kernel definitions +// This file contains a collection of kernel definitions that can be loaded +// by the linalg-to-kernel pass and applied to different MLIR modules. + module { - // Define a collection of kernel operation definitions + // Collection of kernel operation definitions kernel.defn_collection { - // GEMM operation definition with arbitrary code implementation - kernel.defn @gemm(%A: tensor, %B: tensor, %C: tensor) { - // This could include arbitrary code to implement the GEMM operation - // For example, calling into the actual kernel library - "some.custom_code"() : () -> () - kernel.yield + // Simple GEMM operation definition with linalg.generic representation + kernel.defn @simple_gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Simple matrix multiplication: C = A * B + C + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor } - // GEMM operation definition with linalg.generic representation + // Scaled GEMM operation definition with alpha and beta coefficients kernel.defn @gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { - //TODO: move to function arg - //TODO: We can do const prop for alpha and beta for simple matmul match %alpha = arith.constant 1.0 : f32 %beta = arith.constant 0.0 : f32 - // Implementation using linalg.generic + // GEMM with scaling: C = alpha * A * B + beta * C %result = linalg.generic { indexing_maps = [ - affine_map<(i, j, k) -> (i, k)>, // A(i,k) - affine_map<(i, j, k) -> (k, j)>, // B(k,j) - affine_map<(i, j, k) -> (i, j)> // C(i,j) + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> ], iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : tensor, tensor) @@ -38,6 +51,61 @@ module { kernel.yield %result : tensor } + // Sum of absolute values operation (ASUM) + kernel.defn @asum_linalg(%X: tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // Sum of absolute values: result = sum_i |x_i| + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types 
= ["reduction"] + } ins(%X : tensor) + outs(%fill : tensor) { + ^bb0(%in: f32, %out: f32): + %abs_val = math.absf %in : f32 + %result = arith.addf %abs_val, %out : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // Vector dot product + kernel.defn @dot_linalg(%X: tensor, %Y: tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // Dot product: result = sum_i x_i * y_i + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X, %Y : tensor, tensor) + outs(%fill : tensor) { + ^bb0(%x: f32, %y: f32, %out: f32): + %product = arith.mulf %x, %y : f32 + %result = arith.addf %product, %out : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // GEMM operation definition with arbitrary code implementation + kernel.defn @gemm(%A: tensor, %B: tensor, %C: tensor) { + // This could include arbitrary code to implement the GEMM operation + // For example, calling into the actual kernel library + "some.custom_code"() : () -> () + kernel.yield + } + // Batched GEMM operation definition with arbitrary code kernel.defn @batched_gemm(%A2: tensor, %B2: tensor, %C2: tensor) { // This could include arbitrary code to implement the batched GEMM operation @@ -70,27 +138,6 @@ module { kernel.yield } - // GEMM operation definition with linalg.generic representation - kernel.defn @simple_gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { - // Implementation using linalg.generic - %result = linalg.generic { - indexing_maps = [ - affine_map<(i, j, k) -> (i, k)>, // A(i,k) - affine_map<(i, j, k) -> (k, j)>, // B(k,j) - affine_map<(i, j, k) -> (i, j)> // C(i,j) - ], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B : tensor, tensor) - outs(%C : tensor) { - ^bb0(%a: f32, %b: f32, %c: f32): - %product = arith.mulf %a, %b : f32 - %result = arith.addf %product, %c : f32 - linalg.yield %result : f32 - } -> tensor - kernel.yield %result : tensor - } - - // Index of maximum absolute value operation definition with arbitrary code kernel.defn @iamax(%X: tensor) -> tensor { // This could include arbitrary code to find the index of max absolute value @@ -172,70 +219,5 @@ module { kernel.yield %result : tensor } - // Sum of absolute values operation definition with linalg.generic representation - kernel.defn @asum_linalg(%X: tensor) -> tensor { - // Create an initial tensor to store the result sum - %c0 = arith.constant 0.0 : f32 - %init = tensor.empty() : tensor - %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor - - // Implementation using linalg.generic - %result = linalg.generic { - indexing_maps = [ - affine_map<(i) -> (i)>, // Input vector - affine_map<(i) -> ()> // Result scalar (sum) - ], - iterator_types = ["reduction"] - } ins(%X : tensor) - outs(%fill : tensor) { - ^bb0(%in: f32, %out: f32): - %abs_val = math.absf %in : f32 - %result = arith.addf %abs_val, %out : f32 - linalg.yield %result : f32 - } -> tensor - kernel.yield %result : tensor - } - - // Mathematical definitions (commented, for reference) - // kernel.defn @gemm(...) { - // C(i,j) += alpha * A(i,k) * B(k,j); - // } - - // kernel.defn @batched_gemm(...) { - // C(b,i,j) += alpha * A(b,i,k) * B(b,k,j); - // } - - // kernel.defn @iamax(...) { - // result = argmax_i |x_i|; - // } - - // kernel.defn @iamin(...) 
{ - // result = argmin_i |x_i|; - // } - - // kernel.defn @asum(...) { - // result = sum_i |x_i|; - // } } - - //Func that uses simple gemm - func.func @simple_gemm(%A: tensor, %B: tensor, %C: tensor) -> tensor { - // Implementation using linalg.generic - %result = linalg.generic { - indexing_maps = [ - affine_map<(i, j, k) -> (i, k)>, // A(i,k) - affine_map<(i, j, k) -> (k, j)>, // B(k,j) - affine_map<(i, j, k) -> (i, j)> // C(i,j) - ], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B : tensor, tensor) - outs(%C : tensor) { - ^bb0(%a: f32, %b: f32, %c: f32): - %product = arith.mulf %a, %b : f32 - %result = arith.addf %product, %c : f32 - linalg.yield %result : f32 - } -> tensor - return %result : tensor - } - } \ No newline at end of file diff --git a/generic_solver/kernel_library_simple.mlir b/generic_solver/kernel_library_simple.mlir index 7b31faa86aa6..dad0c3c7d68e 100644 --- a/generic_solver/kernel_library_simple.mlir +++ b/generic_solver/kernel_library_simple.mlir @@ -27,10 +27,7 @@ module { } // Scaled GEMM operation definition with alpha and beta coefficients - kernel.defn @gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { - %alpha = arith.constant 1.0 : f32 - %beta = arith.constant 0.0 : f32 - + kernel.defn @gemm_linalg(%A: tensor, %B: tensor, %C: tensor, %alpha: f32, %beta: f32) -> tensor { // GEMM with scaling: C = alpha * A * B + beta * C %result = linalg.generic { indexing_maps = [ @@ -52,11 +49,7 @@ module { } // Sum of absolute values operation (ASUM) - kernel.defn @asum_linalg(%X: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %init = tensor.empty() : tensor - %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor - + kernel.defn @asum_linalg(%X: tensor, %init: tensor) -> tensor { // Sum of absolute values: result = sum_i |x_i| %result = linalg.generic { indexing_maps = [ @@ -65,7 +58,7 @@ module { ], iterator_types = ["reduction"] } ins(%X : tensor) - outs(%fill : tensor) { + outs(%init : tensor) { ^bb0(%in: f32, %out: f32): %abs_val = math.absf %in : f32 %result = arith.addf %abs_val, %out : f32 @@ -75,11 +68,7 @@ module { } // Vector dot product - kernel.defn @dot_linalg(%X: tensor, %Y: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %init = tensor.empty() : tensor - %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor - + kernel.defn @dot_linalg(%X: tensor, %Y: tensor, %init: tensor) -> tensor { // Dot product: result = sum_i x_i * y_i %result = linalg.generic { indexing_maps = [ @@ -89,7 +78,7 @@ module { ], iterator_types = ["reduction"] } ins(%X, %Y : tensor, tensor) - outs(%fill : tensor) { + outs(%init : tensor) { ^bb0(%x: f32, %y: f32, %out: f32): %product = arith.mulf %x, %y : f32 %result = arith.addf %product, %out : f32 @@ -97,5 +86,55 @@ module { } -> tensor kernel.yield %result : tensor } + + // Index of maximum absolute value operation definition with linalg.generic representation + kernel.defn @iamax_linalg(%X: tensor, %init: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (index) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%init : tensor) { + ^bb0(%in: f32, %out: i32): + %idx = linalg.index 0 : index + %abs_val = math.absf %in : f32 + %curr_max_idx = arith.index_cast %out : i32 to index + %curr_max = tensor.extract %X[%curr_max_idx] : tensor + %curr_max_abs = math.absf %curr_max : f32 + %cmp = arith.cmpf ogt, %abs_val, 
%curr_max_abs : f32 + %new_idx = arith.select %cmp, %idx, %curr_max_idx : index + %result = arith.index_cast %new_idx : index to i32 + linalg.yield %result : i32 + } -> tensor + kernel.yield %result : tensor + } + + // Index of minimum absolute value operation definition with linalg.generic representation + kernel.defn @iamin_linalg(%X: tensor, %init: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (index) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%init : tensor) { + ^bb0(%in: f32, %out: i32): + %idx = linalg.index 0 : index + %abs_val = math.absf %in : f32 + %curr_min_idx = arith.index_cast %out : i32 to index + %curr_min = tensor.extract %X[%curr_min_idx] : tensor + %curr_min_abs = math.absf %curr_min : f32 + %cmp = arith.cmpf olt, %abs_val, %curr_min_abs : f32 + %new_idx = arith.select %cmp, %idx, %curr_min_idx : index + %result = arith.index_cast %new_idx : index to i32 + linalg.yield %result : i32 + } -> tensor + kernel.yield %result : tensor + } } } \ No newline at end of file From 4a95c7f58f2230c9d84e06572e0bb8adacf30698 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 26 Jun 2025 08:30:30 -0700 Subject: [PATCH 60/77] Removing redundant file --- generic_solver/kernel_library.mlir | 223 ----------------------------- 1 file changed, 223 deletions(-) delete mode 100644 generic_solver/kernel_library.mlir diff --git a/generic_solver/kernel_library.mlir b/generic_solver/kernel_library.mlir deleted file mode 100644 index d8e70186f618..000000000000 --- a/generic_solver/kernel_library.mlir +++ /dev/null @@ -1,223 +0,0 @@ -// Kernel Library - Reusable kernel definitions -// This file contains a collection of kernel definitions that can be loaded -// by the linalg-to-kernel pass and applied to different MLIR modules. 
- -module { - // Collection of kernel operation definitions - kernel.defn_collection { - - // Simple GEMM operation definition with linalg.generic representation - kernel.defn @simple_gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { - // Simple matrix multiplication: C = A * B + C - %result = linalg.generic { - indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B : tensor, tensor) - outs(%C : tensor) { - ^bb0(%a: f32, %b: f32, %c: f32): - %product = arith.mulf %a, %b : f32 - %result = arith.addf %product, %c : f32 - linalg.yield %result : f32 - } -> tensor - kernel.yield %result : tensor - } - - // Scaled GEMM operation definition with alpha and beta coefficients - kernel.defn @gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { - %alpha = arith.constant 1.0 : f32 - %beta = arith.constant 0.0 : f32 - - // GEMM with scaling: C = alpha * A * B + beta * C - %result = linalg.generic { - indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d2, d1)>, - affine_map<(d0, d1, d2) -> (d0, d1)> - ], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%A, %B : tensor, tensor) - outs(%C : tensor) { - ^bb0(%a: f32, %b: f32, %c: f32): - %product = arith.mulf %a, %b : f32 - %scaled = arith.mulf %product, %alpha : f32 - %scaled_c = arith.mulf %c, %beta : f32 - %result = arith.addf %scaled, %scaled_c : f32 - linalg.yield %result : f32 - } -> tensor - kernel.yield %result : tensor - } - - // Sum of absolute values operation (ASUM) - kernel.defn @asum_linalg(%X: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %init = tensor.empty() : tensor - %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor - - // Sum of absolute values: result = sum_i |x_i| - %result = linalg.generic { - indexing_maps = [ - affine_map<(d0) -> (d0)>, - affine_map<(d0) -> ()> - ], - iterator_types = ["reduction"] - } ins(%X : tensor) - outs(%fill : tensor) { - ^bb0(%in: f32, %out: f32): - %abs_val = math.absf %in : f32 - %result = arith.addf %abs_val, %out : f32 - linalg.yield %result : f32 - } -> tensor - kernel.yield %result : tensor - } - - // Vector dot product - kernel.defn @dot_linalg(%X: tensor, %Y: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %init = tensor.empty() : tensor - %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor - - // Dot product: result = sum_i x_i * y_i - %result = linalg.generic { - indexing_maps = [ - affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>, - affine_map<(d0) -> ()> - ], - iterator_types = ["reduction"] - } ins(%X, %Y : tensor, tensor) - outs(%fill : tensor) { - ^bb0(%x: f32, %y: f32, %out: f32): - %product = arith.mulf %x, %y : f32 - %result = arith.addf %product, %out : f32 - linalg.yield %result : f32 - } -> tensor - kernel.yield %result : tensor - } - - // GEMM operation definition with arbitrary code implementation - kernel.defn @gemm(%A: tensor, %B: tensor, %C: tensor) { - // This could include arbitrary code to implement the GEMM operation - // For example, calling into the actual kernel library - "some.custom_code"() : () -> () - kernel.yield - } - - // Batched GEMM operation definition with arbitrary code - kernel.defn @batched_gemm(%A2: tensor, %B2: tensor, %C2: tensor) { - // This could include arbitrary code to implement the batched GEMM operation - "some.custom_code"() : () -> () - kernel.yield - } - - // Batched GEMM 
operation definition with linalg.generic representation - kernel.defn @batched_gemm_linalg(%A2: tensor, %B2: tensor, %C2: tensor) { - %alpha = arith.constant 1.0 : f32 - %beta = arith.constant 0.0 : f32 - - // Implementation using linalg.generic - %result = linalg.generic { - indexing_maps = [ - affine_map<(b, i, j, k) -> (b, i, k)>, // A(b,i,k) - affine_map<(b, i, j, k) -> (b, k, j)>, // B(b,k,j) - affine_map<(b, i, j, k) -> (b, i, j)> // C(b,i,j) - ], - iterator_types = ["parallel", "parallel", "parallel", "reduction"] - } ins(%A2, %B2 : tensor, tensor) - outs(%C2 : tensor) { - ^bb0(%a: f32, %b: f32, %c: f32): - %product = arith.mulf %a, %b : f32 - %scaled = arith.mulf %product, %alpha : f32 - %scaled_c = arith.mulf %c, %beta : f32 - %result = arith.addf %scaled, %scaled_c : f32 - linalg.yield %result : f32 - } -> tensor - kernel.yield - } - - // Index of maximum absolute value operation definition with arbitrary code - kernel.defn @iamax(%X: tensor) -> tensor { - // This could include arbitrary code to find the index of max absolute value - %result = "some.custom_code"() : () -> tensor - kernel.yield %result : tensor - } - - // Index of maximum absolute value operation definition with linalg.generic representation - kernel.defn @iamax_linalg(%X: tensor) -> tensor { - // Create an initial tensor to store the result index - %c0 = arith.constant 0 : i32 - %init = tensor.empty() : tensor - %fill = linalg.fill ins(%c0 : i32) outs(%init : tensor) -> tensor - - // Implementation using linalg.generic - %result = linalg.generic { - indexing_maps = [ - affine_map<(i) -> (i)>, // Input vector - affine_map<(i) -> ()> // Result scalar (index) - ], - iterator_types = ["reduction"] - } ins(%X : tensor) - outs(%fill : tensor) { - ^bb0(%in: f32, %out: i32): - %idx = linalg.index 0 : index - %abs_val = math.absf %in : f32 - %curr_max_idx = arith.index_cast %out : i32 to index - %curr_max = tensor.extract %X[%curr_max_idx] : tensor - %curr_max_abs = math.absf %curr_max : f32 - %cmp = arith.cmpf ogt, %abs_val, %curr_max_abs : f32 - %new_idx = arith.select %cmp, %idx, %curr_max_idx : index - %result = arith.index_cast %new_idx : index to i32 - linalg.yield %result : i32 - } -> tensor - kernel.yield %result : tensor - } - - // Index of minimum absolute value operation definition with arbitrary code - kernel.defn @iamin(%X: tensor) -> tensor { - // This could include arbitrary code to find the index of min absolute value - %result = "some.custom_code"() : () -> tensor - kernel.yield %result : tensor - } - - // Index of minimum absolute value operation definition with linalg.generic representation - kernel.defn @iamin_linalg(%X: tensor) -> tensor { - // Create an initial tensor to store the result index - %c0 = arith.constant 0 : i32 - %init = tensor.empty() : tensor - %fill = linalg.fill ins(%c0 : i32) outs(%init : tensor) -> tensor - - // Implementation using linalg.generic - %result = linalg.generic { - indexing_maps = [ - affine_map<(i) -> (i)>, // Input vector - affine_map<(i) -> ()> // Result scalar (index) - ], - iterator_types = ["reduction"] - } ins(%X : tensor) - outs(%fill : tensor) { - ^bb0(%in: f32, %out: i32): - %idx = linalg.index 0 : index - %abs_val = math.absf %in : f32 - %curr_min_idx = arith.index_cast %out : i32 to index - %curr_min = tensor.extract %X[%curr_min_idx] : tensor - %curr_min_abs = math.absf %curr_min : f32 - %cmp = arith.cmpf olt, %abs_val, %curr_min_abs : f32 - %new_idx = arith.select %cmp, %idx, %curr_min_idx : index - %result = arith.index_cast %new_idx : index to i32 - 
linalg.yield %result : i32 - } -> tensor - kernel.yield %result : tensor - } - - // Sum of absolute values operation definition with arbitrary code - kernel.defn @asum(%X: tensor) -> tensor { - // This could include arbitrary code to compute the sum of absolute values - %result = "some.custom_code"() : () -> tensor - kernel.yield %result : tensor - } - - } -} \ No newline at end of file From f1e5f029ca96e1362fe7f542bd17173197684897 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 26 Jun 2025 08:31:39 -0700 Subject: [PATCH 61/77] Renamed kernel lib --- .../{kernel_library_simple.mlir => kernel_library.mlir} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename generic_solver/{kernel_library_simple.mlir => kernel_library.mlir} (100%) diff --git a/generic_solver/kernel_library_simple.mlir b/generic_solver/kernel_library.mlir similarity index 100% rename from generic_solver/kernel_library_simple.mlir rename to generic_solver/kernel_library.mlir From e941c5efb2e4edcada9b851545305e1a2c7bd989 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 26 Jun 2025 08:34:12 -0700 Subject: [PATCH 62/77] Added min_abs_index test --- generic_solver/example.mlir | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/generic_solver/example.mlir b/generic_solver/example.mlir index 68ae2c73a3be..1dade3ef3afd 100644 --- a/generic_solver/example.mlir +++ b/generic_solver/example.mlir @@ -1,4 +1,4 @@ -//RUN: polygeist-opt --linalg-to-kernel="kernel-library-path=/home/arjaiswal/Polygeist/generic_solver/kernel_library_simple.mlir" -allow-unregistered-dialect generic_solver/example.mlir +//RUN: polygeist-opt --linalg-to-kernel="kernel-library-path=/home/arjaiswal/Polygeist/generic_solver/kernel_library.mlir" -allow-unregistered-dialect generic_solver/example.mlir // Example MLIR module demonstrating kernel operations and their linalg.generic representations module { //Func that uses simple gemm @@ -21,4 +21,29 @@ module { return %result : tensor } + // Function that uses iamin (index of minimum absolute value) + func.func @find_min_abs_index(%X: tensor, %init: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (index) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%init : tensor) { + ^bb0(%in: f32, %out: i32): + %idx = linalg.index 0 : index + %abs_val = math.absf %in : f32 + %curr_min_idx = arith.index_cast %out : i32 to index + %curr_min = tensor.extract %X[%curr_min_idx] : tensor + %curr_min_abs = math.absf %curr_min : f32 + %cmp = arith.cmpf olt, %abs_val, %curr_min_abs : f32 + %new_idx = arith.select %cmp, %idx, %curr_min_idx : index + %result = arith.index_cast %new_idx : index to i32 + linalg.yield %result : i32 + } -> tensor + return %result : tensor + } + } \ No newline at end of file From a99fad96637b369f3ce1973141c45aa15f4a394f Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 26 Jun 2025 17:51:20 -0700 Subject: [PATCH 63/77] Fixed a bunch of bugs in raiseToLinalg while raising polybench --- lib/polygeist/Passes/RaiseToLinalg.cpp | 159 +++++++++++++++++++------ 1 file changed, 122 insertions(+), 37 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index f638a26c9cd1..b8e4b232a6d2 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -100,6 +100,82 @@ AffineMap shiftDimsDown1(AffineMap expr, 
unsigned numDim, unsigned offset) {
     expr.getContext());
 }
+// Helper function to check if an operation dominates the target region
+bool dominatesTarget(Operation* op, Region* targetRegion) {
+  return op->getParentRegion()->isAncestor(targetRegion);
+}
+
+Value recursiveCloneWithDominanceCheck(
+    OpBuilder& builder,
+    Value value,
+    Region* targetRegion,
+    IRMapping& mapping,
+    DenseSet<Operation*>& processedOps) {
+
+  // If value is already mapped, return the mapped value
+  if (mapping.contains(value)) {
+    return mapping.lookup(value);
+  }
+
+  // Handle block arguments
+  if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+    if (blockArg.getParentBlock()->getParent()->isAncestor(targetRegion)) {
+      mapping.map(value, value);
+      return value;
+    } else {
+      llvm::errs() << "Non-dominating block argument encountered\n";
+      return nullptr;
+    }
+  }
+
+  Operation* defOp = value.getDefiningOp();
+  if (!defOp) {
+    return value;
+  }
+
+  // Check if this operation dominates the target region
+  if (dominatesTarget(defOp, targetRegion)) {
+    // Operation dominates, use it directly
+    mapping.map(value, value);
+    return value;
+  }
+
+  // Avoid processing the same operation multiple times
+  if (processedOps.contains(defOp)) {
+    // Operation was already processed, should be in mapping
+    auto resultNum = cast<OpResult>(value).getResultNumber();
+    auto mappedOp = mapping.lookup(defOp->getResult(0)).getDefiningOp();
+    auto clonedValue = mappedOp->getResult(resultNum);
+    mapping.map(value, clonedValue);
+    return clonedValue;
+  }
+
+  // Check if operation is safe to clone
+  if (!isReadOnly(defOp)) {
+    llvm::errs() << "Cannot clone non-read-only operation: " << *defOp << "\n";
+    return nullptr;
+  }
+
+  processedOps.insert(defOp);
+
+  // Recursively process ALL operands first to populate the mapping
+  for (Value operand : defOp->getOperands()) {
+    Value clonedOperand = recursiveCloneWithDominanceCheck(
+        builder, operand, targetRegion, mapping, processedOps);
+    if (!clonedOperand) {
+      return nullptr;
+    }
+    // clonedOperand is automatically added to mapping by recursive call
+  }
+
+  // Now clone the operation using the populated mapping
+  Operation* clonedOp = builder.clone(*defOp, mapping);
+
+  // The clone automatically maps all results, so we can just return what we need
+  auto resultNum = cast<OpResult>(value).getResultNumber();
+  return clonedOp->getResult(resultNum);
+}
+
 // Given an affine map `oldmap`, memref `val`, and corresponding input values
 // (which are a list of indicies, then symbols), and a set of loop indices
 //
- if (auto ba = dyn_cast(sz)) { - // check if it dominates the current scope - if (ba.getParentBlock()->getParent()->isAncestor( - builder.getBlock()->getParent())) - operands_without_indices.push_back(sz); - else { - llvm::errs() << " value is a non-dominating block arg: " << sz << "\n"; - legal = false; - assert(false); - return nullptr; - } - } else { - auto op = sz.getDefiningOp(); - // check if this dominates the current scope - if (op->getParentRegion()->isAncestor(builder.getBlock()->getParent())) { - operands_without_indices.push_back(sz); - } else if (isReadOnly(op)) { - // if not, check if it is readnone - // Technically this isn't quite sufficient yet, and does require that - // the operands to this op are also able to be hoisted, but for now we - // will assume this - auto op2 = builder.clone(*op); - operands_without_indices.push_back( - op2->getResult(cast(sz).getResultNumber())); - } else { - llvm::errs() << " op is not readonly: " << *op << "\n"; - // if so clone it in the right scope - // otherwise set illegal and don't continue - legal = false; - assert(false); - return nullptr; - } - } + DenseSet processedOps; + IRMapping mapping; + auto clonedOp = recursiveCloneWithDominanceCheck(builder, sz, builder.getBlock()->getParent(), mapping, processedOps); + operands_without_indices.push_back(clonedOp); } + + //for (auto sz : idx_sizes) { + // // Check if the symbol value is read-only or defined in a scope where it is + // // always visible. + // if (auto ba = dyn_cast(sz)) { + // // check if it dominates the current scope + // if (ba.getParentBlock()->getParent()->isAncestor( + // builder.getBlock()->getParent())) + // operands_without_indices.push_back(sz); + // else { + // llvm::errs() << " value is a non-dominating block arg: " << sz << "\n"; + // legal = false; + // assert(false); + // return nullptr; + // } + // } else { + // auto op = sz.getDefiningOp(); + // // check if this dominates the current scope + // if (op->getParentRegion()->isAncestor(builder.getBlock()->getParent())) { + // operands_without_indices.push_back(sz); + // } else if (isReadOnly(op)) { + // // if not, check if it is readnone + // // Technically this isn't quite sufficient yet, and does require that + // // the operands to this op are also able to be hoisted, but for now we + // // will assume this + // // We need to clone the op along and check if it's operands are dominating or not, else do a recursive clone + // auto op2 = builder.clone(*op); + // operands_without_indices.push_back( + // op2->getResult(cast(sz).getResultNumber())); + // } else { + // llvm::errs() << " op is not readonly: " << *op << "\n"; + // // if so clone it in the right scope + // // otherwise set illegal and don't continue + // legal = false; + // assert(false); + // return nullptr; + // } + // } + //} auto ty = MemRefType::get( sizes, cast(memref_val.getType()).getElementType()); @@ -871,7 +955,6 @@ struct AffineForOpRaising : public OpRewritePattern { // TODO presently if linalg generic exists, assert there are no load/stores if ((linalgGenerics.size() > 0) && ((loads.size() != 0) || (stores.size() != 0))) { - assert(false); return failure(); } @@ -953,6 +1036,8 @@ struct AffineForOpRaising : public OpRewritePattern { for (auto genPair : linalgGenerics) { auto genOp = genPair.second; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(genOp); auto &genBlock = genOp->getRegion(0).front(); auto term = genBlock.getTerminator(); mlir::IRMapping map; @@ -964,7 +1049,7 @@ struct AffineForOpRaising : public 
OpRewritePattern { rewriter.clone(op, map); } for (auto op : term->getOperands()) { - toreturn.push_back(map.lookup(op)); + toreturn.push_back(map.lookupOrDefault(op)); } // llvm::errs() << genOp->getParentOfType() << "\n"; rewriter.eraseOp(genOp); From 4e782d58db4a4be5bc7ae952f74acafad6c28bd6 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 27 Jun 2025 22:56:27 -0700 Subject: [PATCH 64/77] Fixed raise to linalg and canonicalizer to generate subview --- lib/polygeist/Ops.cpp | 224 ++++++++++++++++++++++++- lib/polygeist/Passes/RaiseToLinalg.cpp | 60 ++++++- 2 files changed, 269 insertions(+), 15 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 07f0cab20f0c..d91668145ab3 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -4508,7 +4508,6 @@ struct MergeNestedAffineParallelIf return success(); } }; - struct MergeParallelInductions : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -5805,6 +5804,156 @@ struct SubMapOpCanonicalize : public OpRewritePattern { } }; +struct StrideAndBound { + int64_t stride; + int64_t lowerBound; + unsigned dimOrSymbol; // Which dimension/symbol this applies to + bool isDimension; // true if dimension, false if symbol + + StrideAndBound(int64_t s, int64_t lb, unsigned idx, bool isDim) + : stride(s), lowerBound(lb), dimOrSymbol(idx), isDimension(isDim) {} +}; + +struct ExpressionAnalysis { + SmallVector coefficients; // Coefficients for dims/symbols + int64_t constantTerm = 0; // Pure constant term + + void addDimCoeff(unsigned dim, int64_t coeff) { + coefficients.emplace_back(coeff, 0, dim, true); + } + + void addSymCoeff(unsigned sym, int64_t coeff) { + coefficients.emplace_back(coeff, 0, sym, false); + } +}; + +// Recursively analyze an affine expression to extract coefficients and constants +static ExpressionAnalysis analyzeAffineExpression(AffineExpr expr) { + ExpressionAnalysis result; + + if (auto constExpr = expr.dyn_cast()) { + // Pure constant + result.constantTerm = constExpr.getValue(); + + } else if (auto dimExpr = expr.dyn_cast()) { + // Single dimension with coefficient 1 + result.addDimCoeff(dimExpr.getPosition(), 1); + + } else if (auto symExpr = expr.dyn_cast()) { + // Single symbol with coefficient 1 + result.addSymCoeff(symExpr.getPosition(), 1); + + } else if (auto binaryExpr = expr.dyn_cast()) { + auto lhs = binaryExpr.getLHS(); + auto rhs = binaryExpr.getRHS(); + + if (binaryExpr.getKind() == AffineExprKind::Add) { + // Addition: combine results from both sides + auto lhsAnalysis = analyzeAffineExpression(lhs); + auto rhsAnalysis = analyzeAffineExpression(rhs); + + result.coefficients.append(lhsAnalysis.coefficients); + result.coefficients.append(rhsAnalysis.coefficients); + result.constantTerm = lhsAnalysis.constantTerm + rhsAnalysis.constantTerm; + + } else if (binaryExpr.getKind() == AffineExprKind::Mul) { + // Multiplication: one side should be constant, other should be dim/symbol + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); + + if (lhsConst && !rhsConst) { + // Constant * expr + auto rhsAnalysis = analyzeAffineExpression(rhs); + for (auto &coeff : rhsAnalysis.coefficients) { + coeff.stride *= lhsConst.getValue(); + } + result.coefficients = std::move(rhsAnalysis.coefficients); + result.constantTerm = rhsAnalysis.constantTerm * lhsConst.getValue(); + + } else if (rhsConst && !lhsConst) { + // expr * Constant + auto lhsAnalysis = analyzeAffineExpression(lhs); + for (auto &coeff : lhsAnalysis.coefficients) { + coeff.stride *= rhsConst.getValue(); 
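// Annotation (not part of the patch): worked example for the coefficient
// extraction these helpers implement. For affine_map<(d0) -> (3 * d0 + 5)>,
// the Mul branch scales the d0 coefficient by 3 and the Add branch sums the
// constant terms, so analyzing the single result yields one entry
// {stride = 3, dimOrSymbol = 0, isDimension = true} with constantTerm = 5,
// and extractStridesAndBounds() defined below would return strides = [3]
// and lowerBounds = [5].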
+ } + result.coefficients = std::move(lhsAnalysis.coefficients); + result.constantTerm = lhsAnalysis.constantTerm * rhsConst.getValue(); + + } else if (lhsConst && rhsConst) { + // Constant * Constant + result.constantTerm = lhsConst.getValue() * rhsConst.getValue(); + } + // Note: expr * expr is not affine, so we don't handle it + + } else if (binaryExpr.getKind() == AffineExprKind::Mod) { + // Modulo: more complex, for now just mark as having the base expression + auto lhsAnalysis = analyzeAffineExpression(lhs); + result.coefficients = std::move(lhsAnalysis.coefficients); + result.constantTerm = lhsAnalysis.constantTerm; + + } else if (binaryExpr.getKind() == AffineExprKind::FloorDiv || + binaryExpr.getKind() == AffineExprKind::CeilDiv) { + // Division: handle simple cases where RHS is constant + if (auto rhsConst = rhs.dyn_cast()) { + auto lhsAnalysis = analyzeAffineExpression(lhs); + for (auto &coeff : lhsAnalysis.coefficients) { + coeff.stride = coeff.stride / rhsConst.getValue(); + } + result.coefficients = std::move(lhsAnalysis.coefficients); + result.constantTerm = lhsAnalysis.constantTerm / rhsConst.getValue(); + } + } + } + + return result; +} + +struct MapAnalysis { + SmallVector outputAnalyses; + + // Get all unique strides from all outputs + SmallVector getAllStrides() const { + SmallVector strides; + llvm::DenseSet seen; + + for (const auto &analysis : outputAnalyses) { + for (const auto &coeff : analysis.coefficients) { + // TODO: Need to add a check that if more than one coeffs in an outputAnalysis + // then we need to return failure. + strides.push_back(coeff.stride); + } + } + return strides; + } + + // Get all lower bounds (constant terms) from all outputs + SmallVector getAllLowerBounds() const { + SmallVector bounds; + for (const auto &analysis : outputAnalyses) { + bounds.push_back(analysis.constantTerm); + } + return bounds; + } +}; + +// Main function to analyze an affine map +static MapAnalysis analyzeAffineMap(AffineMap map) { + MapAnalysis result; + + for (auto expr : map.getResults()) { + result.outputAnalyses.push_back(analyzeAffineExpression(expr)); + } + + return result; +} + +// Extract both strides and bounds +std::pair, SmallVector> +extractStridesAndBounds(AffineMap map) { + auto analysis = analyzeAffineMap(map); + return {analysis.getAllStrides(), analysis.getAllLowerBounds()}; +} + struct LinalgOfSubmap : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::GenericOp genericOp, @@ -5819,6 +5968,7 @@ struct LinalgOfSubmap : public OpRewritePattern { SmallVector listOfAllocas; SmallVector listOfNewMaps; SmallVector listOfNewInputs, listOfNewOutputs; + // auto mapAttrsArr = genericOp.getIndexingMaps(); // for(auto mapAttr: mapAttrsArr) { // AffineMap map = mapAttr.cast().getValue(); @@ -5831,13 +5981,46 @@ struct LinalgOfSubmap : public OpRewritePattern { } else if (auto subMap = dyn_cast(inp.getDefiningOp())) { auto source_memref = subMap.getMemref(); - // if (auto blockArg = dyn_cast_or_null(op)) { + + //Create a new memref.subview op from the given submap and sizes + Value stride = rewriter.create(source_memref.getLoc(), 1); + + //sizesauto blockArg = dyn_cast_or_null(op)) { // if(auto source_alloca = // dyn_cast(source_memref.getDefiningOp())) //{ auto map = subMap.getMap(); - listOfNewMaps.push_back(map); - listOfNewInputs.push_back(source_memref); + + ////Create sizes from the submap + auto sizes = subMap.getSizes(); + + // Create a subview op using lower bound, stride and size + // Convert 
AffineApplyOp to its result Value and wrap in ValueRange + auto [strides, lowerBounds] = extractStridesAndBounds(map); + SmallVector offsetValues, sizeValues, strideValues; + for (int64_t offset : lowerBounds) { + offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); + } + for (int64_t stride : strides) { + strideValues.push_back(rewriter.getI64IntegerAttr(stride)); + } + for (Value size : sizes) { + sizeValues.push_back(size); + } + auto subViewOp = rewriter.create( + source_memref.getLoc(), // Location + source_memref, // Source memref + offsetValues, // Offsets (array) + sizeValues, // Sizes (array) + strideValues // Strides (array) + ); + auto subViewType = subViewOp.getType().cast(); + unsigned rank = subViewType.getRank(); + auto identityMap = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext()); + + listOfNewMaps.push_back(identityMap); + listOfNewInputs.push_back(subViewOp); + //} // else { // assert(false && "Only expect allocaOp as source for submap @@ -5855,8 +6038,36 @@ struct LinalgOfSubmap : public OpRewritePattern { dyn_cast(out.getDefiningOp())) { auto source_memref = subMap.getMemref(); auto map = subMap.getMap(); - listOfNewMaps.push_back(map); - listOfNewOutputs.push_back(source_memref); + + //Create sizes from the submap + auto sizes = subMap.getSizes(); + + // Create a subview op using lower bound, stride and size + // Convert AffineApplyOp to its result Value and wrap in ValueRange + auto [strides, lowerBounds] = extractStridesAndBounds(map); + + SmallVector offsetValues, sizeValues, strideValues; + for (int64_t offset : lowerBounds) { + offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); + } + for (int64_t stride : strides) { + strideValues.push_back(rewriter.getI64IntegerAttr(stride)); + } + for (Value size : sizes) { + sizeValues.push_back(size); + } + auto subViewOp = rewriter.create( + source_memref.getLoc(), // Location + source_memref, // Source memref + offsetValues, // Offsets (array) + sizeValues, // Sizes (array) + strideValues // Strides (array) + ); + auto subViewType = subViewOp.getType().cast(); + unsigned rank = subViewType.getRank(); + auto identityMap = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext()); + listOfNewMaps.push_back(identityMap); + listOfNewOutputs.push_back(subViewOp); } else { listOfNewOutputs.push_back(out); } @@ -6433,3 +6644,4 @@ void polygeist::SubmapOp::getCanonicalizationPatterns( results.insert(context); // results.insert(context); } + diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index b8e4b232a6d2..f670edd02c25 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -176,6 +176,29 @@ Value recursiveCloneWithDominanceCheck( return clonedOp->getResult(resultNum); } +// Check if the affine apply is a constant and return the constant value +std::optional getConstantFromAffineApply(AffineApplyOp applyOp) { + AffineMap map = applyOp.getAffineMap(); + + // Must have no dimensions and no symbols + if (map.getNumDims() != 0 || map.getNumSymbols() != 0) { + return std::nullopt; + } + + // Must have exactly one result that is a constant + if (map.getNumResults() != 1) { + return std::nullopt; + } + + // Check if the single result is a constant expression + AffineExpr result = map.getResult(0); + if (auto constExpr = result.dyn_cast()) { + return constExpr.getValue(); + } + + return std::nullopt; +} + // Given an affine map `oldmap`, memref `val`, and corresponding input values // (which are a list of indicies, 
then symbols), and a set of loop indices // `indices` produce the following: @@ -190,9 +213,12 @@ Value recursiveCloneWithDominanceCheck( // variable. And it is returned true, only if index was not encountered in // oldmap operands and check_reduction was set true. Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, - Value memref_val, Value index, Value bound, + Value memref_val, Value index, Value bound, AffineApplyOp lower_bound, int firstNDims, ValueRange oldmap_operands, Value origmemref, bool &check_reduction) { + + int lower_bound_val = getConstantFromAffineApply(lower_bound).value_or(0); + assert(oldmap_operands.size() == oldmap.getNumSymbols() + oldmap.getNumDims()); // Operands which don't correspond to indices @@ -256,7 +282,7 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, dimReplacements.push_back(builder.getAffineDimExpr(validDims)); validDims++; } else if (i == dimidx) { - dimReplacements.push_back(builder.getAffineDimExpr(validDims)); + dimReplacements.push_back(builder.getAffineDimExpr(validDims) + builder.getAffineConstantExpr(lower_bound_val)); validDims++; } else { // TODO: Why are we using symbol here instead of dim? @@ -268,7 +294,7 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, SmallVector symReplacements; for (int i = 0; i < oldmap.getNumSymbols(); i++) { if (i + oldmap.getNumDims() == dimidx) { - symReplacements.push_back(builder.getAffineDimExpr(validDims)); + symReplacements.push_back(builder.getAffineDimExpr(validDims) + builder.getAffineConstantExpr(lower_bound_val)); validDims++; } else { symReplacements.push_back(builder.getAffineSymbolExpr(validSims)); @@ -299,8 +325,8 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, } assert(validSims == operands_without_indices.size()); auto map2 = oldmap.replaceDimsAndSymbols(dimReplacements, symReplacements, - firstNDims + 1, - operands_without_indices.size()); + firstNDims + 1/*Number of dims in new map*/, + operands_without_indices.size() /*Number of symbols in new map*/); SmallVector idx_sizes; for (size_t i = 0; i < firstNDims; i++) { @@ -364,6 +390,22 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, auto ty = MemRefType::get( sizes, cast(memref_val.getType()).getElementType()); + ////TODO: Can we have a case where stride is not 1? 
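// Annotation (not part of the patch): effect of folding the constant lower bound
// into the remapped access map. Assuming the loop
//   affine.for %i = 4 to %N {
//     %v = affine.load %A[2 * %i] : memref<?xf32>
//   }
// is raised over a zero-based iteration space of size %N - 4, the induction-var
// dimension d0 is substituted by (d0 + 4) as built above, so the old access map
// (d0) -> (2 * d0) becomes (d0) -> (2 * (d0 + 4)), i.e. (d0) -> (2 * d0 + 8),
// in the submap fed to the linalg.generic.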
+ //Value stride = builder.create(memref_val.getLoc(), 1); + + //// Create a subview op using lower bound, stride and size + //// Convert AffineApplyOp to its result Value and wrap in ValueRange + //Value lowerBoundValue = lower_bound.getResult(); + //auto subViewOp = builder.create( + // memref_val.getLoc(), // Location + // memref_val, // Source memref + // ValueRange{lowerBoundValue}, // Offsets (array) + // ValueRange{bound}, // Sizes (array) + // ValueRange{stride} // Strides (array) + //); + + //Value subview = subViewOp.getResult(); + return builder.create( memref_val.getLoc(), ty, memref_val, operands_without_indices, map2); } @@ -843,7 +885,7 @@ struct AffineForOpRaising : public OpRewritePattern { size_t firstNDims = lgMap.getNumDims(); check_reduction = false; auto newMemref = remap_in_affine_dim( - legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, lbValue, firstNDims, ValueRange(lgOperands), input, check_reduction); if (!legal) return failure(); @@ -882,7 +924,7 @@ struct AffineForOpRaising : public OpRewritePattern { size_t firstNDims = lgMap.getNumDims(); check_reduction = true; auto newMemref = remap_in_affine_dim( - legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, lbValue, firstNDims, ValueRange(lgOperands), output, check_reduction); if (!legal) return failure(); @@ -911,7 +953,7 @@ struct AffineForOpRaising : public OpRewritePattern { check_reduction = false; auto newMemref = remap_in_affine_dim( legal, rewriter, load.getAffineMap(), load.getMemref(), - loop.getInductionVar(), loopSize, firstNDims, load.getMapOperands(), + loop.getInductionVar(), loopSize, lbValue, firstNDims, load.getMapOperands(), load.getMemref(), check_reduction); if (!legal) @@ -939,7 +981,7 @@ struct AffineForOpRaising : public OpRewritePattern { check_reduction = true; auto newMemref = remap_in_affine_dim( legal, rewriter, store.getAffineMap(), store.getMemref(), - loop.getInductionVar(), loopSize, firstNDims, store.getMapOperands(), + loop.getInductionVar(), loopSize, lbValue, firstNDims, store.getMapOperands(), store.getMemref(), check_reduction); if (!legal) { From bd15b6dd29ff0615e8e23a1fd292ce288e82bd43 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 31 Jul 2025 15:37:47 -0700 Subject: [PATCH 65/77] Fixed submap simplification, improved raisedToLinalg to work with non constant bounds, added debufferization to work with allocs --- lib/polygeist/Ops.cpp | 464 ++++++++++++++------- lib/polygeist/Passes/LinalgDebufferize.cpp | 36 +- lib/polygeist/Passes/RaiseToLinalg.cpp | 26 +- 3 files changed, 354 insertions(+), 172 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index d91668145ab3..f9a607b66db2 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -4517,7 +4517,7 @@ struct MergeParallelInductions // Reductions are not supported yet. 
if (!op.getReductions().empty()) return failure(); - + auto getIndUsage = [&op](AffineExpr cst, ValueRange operands, std::map &indUsage, bool &legal) -> AffineExpr { @@ -5954,167 +5954,331 @@ extractStridesAndBounds(AffineMap map) { return {analysis.getAllStrides(), analysis.getAllLowerBounds()}; } -struct LinalgOfSubmap : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(linalg::GenericOp genericOp, - PatternRewriter &rewriter) const override { - // Check body content - auto module = genericOp->getParentOfType(); - Region &genericBody = genericOp.getRegion(); - Block &entryBlock = genericBody.front(); - ValueRange blockArgs = entryBlock.getArguments(); - auto inputs = genericOp.getInputs(); - auto outputs = genericOp.getOutputs(); - SmallVector listOfAllocas; - SmallVector listOfNewMaps; - SmallVector listOfNewInputs, listOfNewOutputs; - - // auto mapAttrsArr = genericOp.getIndexingMaps(); - // for(auto mapAttr: mapAttrsArr) { - // AffineMap map = mapAttr.cast().getValue(); - // if(map == convMap[0] && !mapped[0]) { - // } - // } - for (auto inp : inputs) { - if (auto blkArg = dyn_cast(inp)) { - listOfNewInputs.push_back(inp); - } else if (auto subMap = - dyn_cast(inp.getDefiningOp())) { - auto source_memref = subMap.getMemref(); +// Helper function to check if an expression is a simple offset + stride pattern +static bool isSimpleOffsetStride(AffineExpr expr) { + // Check if expression is of the form: d0 + constant, d0 * constant + constant, etc. + if (auto dimExpr = expr.dyn_cast()) { + return true; // Simple dimension access + } + + if (auto constExpr = expr.dyn_cast()) { + return true; // Constant offset + } + + if (auto binaryExpr = expr.dyn_cast()) { + auto kind = binaryExpr.getKind(); - //Create a new memref.subview op from the given submap and sizes - Value stride = rewriter.create(source_memref.getLoc(), 1); - - //sizesauto blockArg = dyn_cast_or_null(op)) { - // if(auto source_alloca = - // dyn_cast(source_memref.getDefiningOp())) - //{ - auto map = subMap.getMap(); - - ////Create sizes from the submap - auto sizes = subMap.getSizes(); - - // Create a subview op using lower bound, stride and size - // Convert AffineApplyOp to its result Value and wrap in ValueRange - auto [strides, lowerBounds] = extractStridesAndBounds(map); - SmallVector offsetValues, sizeValues, strideValues; - for (int64_t offset : lowerBounds) { - offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); - } - for (int64_t stride : strides) { - strideValues.push_back(rewriter.getI64IntegerAttr(stride)); - } - for (Value size : sizes) { - sizeValues.push_back(size); - } - auto subViewOp = rewriter.create( - source_memref.getLoc(), // Location - source_memref, // Source memref - offsetValues, // Offsets (array) - sizeValues, // Sizes (array) - strideValues // Strides (array) - ); - auto subViewType = subViewOp.getType().cast(); - unsigned rank = subViewType.getRank(); - auto identityMap = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext()); - - listOfNewMaps.push_back(identityMap); - listOfNewInputs.push_back(subViewOp); - - //} - // else { - // assert(false && "Only expect allocaOp as source for submap - // canonicalization right now"); return failure(); - //} - } else { - listOfNewInputs.push_back(inp); + // Allow simple addition and multiplication patterns + if (kind == AffineExprKind::Add || kind == AffineExprKind::Mul) { + return isSimpleOffsetStride(binaryExpr.getLHS()) && + isSimpleOffsetStride(binaryExpr.getRHS()); + } + + // Allow 
simple division by constants (for stride calculation) + if (kind == AffineExprKind::FloorDiv || kind == AffineExprKind::CeilDiv) { + if (auto rhsConst = binaryExpr.getRHS().dyn_cast()) { + return rhsConst.getValue() > 0 && isSimpleOffsetStride(binaryExpr.getLHS()); } } + } + + return false; +} - for (auto out : outputs) { - if (auto blkArg = dyn_cast(out)) { - listOfNewOutputs.push_back(out); - } else if (auto subMap = - dyn_cast(out.getDefiningOp())) { - auto source_memref = subMap.getMemref(); - auto map = subMap.getMap(); - - //Create sizes from the submap - auto sizes = subMap.getSizes(); - - // Create a subview op using lower bound, stride and size - // Convert AffineApplyOp to its result Value and wrap in ValueRange - auto [strides, lowerBounds] = extractStridesAndBounds(map); +// Main function to check if SubmapOp can be converted to SubViewOp +static bool canConvertSubmapToSubView(polygeist::SubmapOp submapOp) { + auto map = submapOp.getMap(); + auto sizes = submapOp.getSizes(); + auto symbols = submapOp.getSymbols(); + auto source_memref = submapOp.getMemref(); + + // 1. Identity maps are always valid + if (map.isIdentity()) { + return true; + } + + // 2. Check if we can extract meaningful strides and bounds + auto [strides, lowerBounds] = extractStridesAndBounds(map); + if (strides.empty() || lowerBounds.empty()) { + return false; + } + + // 3. Ensure the number of results matches expected dimensions + if (map.getNumResults() != sizes.size()) { + return false; + } + + // 4. Check each expression in the map for complexity + for (auto expr : map.getResults()) { + if (!isSimpleOffsetStride(expr)) { + return false; + } + } + + // 5. Check for unsupported complex transformations + for (auto expr : map.getResults()) { + // Reject expressions that involve multiple dimensions in complex ways + if (auto binaryExpr = expr.dyn_cast()) { + // For now, reject modulo operations as they're hard to represent in SubView + if (binaryExpr.getKind() == AffineExprKind::Mod) { + return false; + } + + // Reject complex multi-dimensional expressions + if (binaryExpr.getKind() == AffineExprKind::Mul) { + auto lhs = binaryExpr.getLHS(); + auto rhs = binaryExpr.getRHS(); - SmallVector offsetValues, sizeValues, strideValues; - for (int64_t offset : lowerBounds) { - offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); - } - for (int64_t stride : strides) { - strideValues.push_back(rewriter.getI64IntegerAttr(stride)); + // Both sides are dimensions = complex interaction + if (lhs.isa() && rhs.isa()) { + return false; } - for (Value size : sizes) { - sizeValues.push_back(size); + + // Multiplication by symbols might be too complex for simple SubView + if (lhs.isa() || rhs.isa()) { + // Allow simple symbol multiplication, but check it's not too complex + if (!lhs.isa() && !rhs.isa()) { + return false; + } } - auto subViewOp = rewriter.create( - source_memref.getLoc(), // Location - source_memref, // Source memref - offsetValues, // Offsets (array) - sizeValues, // Sizes (array) - strideValues // Strides (array) - ); - auto subViewType = subViewOp.getType().cast(); - unsigned rank = subViewType.getRank(); - auto identityMap = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext()); - listOfNewMaps.push_back(identityMap); - listOfNewOutputs.push_back(subViewOp); - } else { - listOfNewOutputs.push_back(out); } } - ArrayRef maps(listOfNewMaps); - // No submap ops detected - if (maps.size() == 0) + } + + // 6. 
Check for rank-changing transformations that SubView can't handle + auto sourceType = source_memref.getType().cast(); + auto resultType = submapOp.getType().cast(); + + // SubView can do rank-reduction, but not rank-expansion + if (resultType.getRank() > sourceType.getRank()) { + return false; + } + + return true; +} + +// Convenience function to check and extract conversion info +struct SubmapToSubViewConversionInfo { + bool isValid; + SmallVector strides; + SmallVector offsets; + SmallVector sizes; + SmallVector dynamicOffsets; // For symbol-based offsets + + SubmapToSubViewConversionInfo() : isValid(false) {} +}; + +static SubmapToSubViewConversionInfo +analyzeSubmapToSubViewConversion(polygeist::SubmapOp submapOp) { + SubmapToSubViewConversionInfo info; + + if (!canConvertSubmapToSubView(submapOp)) { + return info; // isValid = false + } + + auto map = submapOp.getMap(); + auto [strides, lowerBounds] = extractStridesAndBounds(map); + + info.isValid = true; + info.strides = strides; + info.offsets = lowerBounds; + info.sizes.append(submapOp.getSizes().begin(), submapOp.getSizes().end()); + + return info; +} + + +struct SubmapToSubviewOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(polygeist::SubmapOp submapOp, + PatternRewriter &rewriter) const override { + auto conversionInfo = analyzeSubmapToSubViewConversion(submapOp); + if (!conversionInfo.isValid) return failure(); - // If inverse permutation exists, then we can canonicalize the linalg of - // submap to linalg - // TODO: Fails for: - // 1. Maps with symbols - // 2. Maps which are not resolvable 1 to 1 with memref for all dims - if (inversePermutation(concatAffineMaps(maps))) { - StringAttr empty = StringAttr::get(genericOp.getContext()); - auto newGenericOp = rewriter.create( - genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, - listOfNewMaps, genericOp.getIteratorTypesArray(), empty, empty); - rewriter.inlineRegionBefore(genericOp.getRegion(), - newGenericOp.getRegion(), - newGenericOp.getRegion().end()); - - // auto &block = newGenericOp.getRegion().front(); - // block.addArguments(newGenericOp.getOperandTypes(), - // SmallVector(newGenericOp.getNumOperands(), - // genericOp.getLoc())); - - rewriter.replaceOp(genericOp, newGenericOp.getResults()); - return success(); - } - // for(iterate over inputs) - //{ - // gather maps - // gather submaps - // Gather affine maps from submaps - // Check over 2 iterations if all the indexes can be solved. - // Use the same logic as linalg.generic to do this. 
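// Annotation (not part of the patch): intended end-to-end effect of the
// SubmapToSubviewOp rewrite, written schematically since the textual form of
// polygeist.submap is not shown here. A submap of %A : memref<?xf32> whose map
// is (d0) -> (2 * d0 + 8) with size %n has strides = [2] and lowerBounds = [8],
// so it would be replaced by roughly
//   memref.subview %A[8] [%n] [2]
//       : memref<?xf32> to memref<?xf32, strided<[2], offset: 8>>
// i.e. constant offsets and strides become static subview attributes while the
// sizes stay dynamic.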
- // if success in getting vars - // replace affine map from submap to linalg.generic - // replace input memref as direct input to linalg.generic - // } - // assert(false && "inversePermutation doesn't exists for the given linalg - // generic"); - return failure(); + + SmallVector offsetValues, sizeValues, strideValues; + for (int64_t offset : conversionInfo.offsets) { + offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); + } + for (int64_t stride : conversionInfo.strides) { + strideValues.push_back(rewriter.getI64IntegerAttr(stride)); + } + for (Value size : conversionInfo.sizes) { + sizeValues.push_back(size); + } + //auto subViewOp = rewriter.create( + // source_memref.getLoc(), // Location + // source_memref, // Source memref + // offsetValues, // Offsets (array) + // sizeValues, // Sizes (array) + // strideValues // Strides (array) + //); + rewriter.replaceOpWithNewOp(submapOp, submapOp.getType(), submapOp.getMemref(), offsetValues, sizeValues, strideValues); + return success(); } }; +//struct LinalgOfSubmap : public OpRewritePattern { +// using OpRewritePattern::OpRewritePattern; +// LogicalResult matchAndRewrite(linalg::GenericOp genericOp, +// PatternRewriter &rewriter) const override { +// // Check body content +// auto module = genericOp->getParentOfType(); +// Region &genericBody = genericOp.getRegion(); +// Block &entryBlock = genericBody.front(); +// ValueRange blockArgs = entryBlock.getArguments(); +// auto inputs = genericOp.getInputs(); +// auto outputs = genericOp.getOutputs(); +// SmallVector listOfAllocas; +// SmallVector listOfNewMaps; +// SmallVector listOfNewInputs, listOfNewOutputs; +// +// // auto mapAttrsArr = genericOp.getIndexingMaps(); +// // for(auto mapAttr: mapAttrsArr) { +// // AffineMap map = mapAttr.cast().getValue(); +// // if(map == convMap[0] && !mapped[0]) { +// // } +// // } +// for (auto inp : inputs) { +// if (auto blkArg = dyn_cast(inp)) { +// listOfNewInputs.push_back(inp); +// } else if (auto subMap = +// dyn_cast(inp.getDefiningOp())) { +// auto source_memref = subMap.getMemref(); +// +// //Create a new memref.subview op from the given submap and sizes +// Value stride = rewriter.create(source_memref.getLoc(), 1); +// +// //sizesauto blockArg = dyn_cast_or_null(op)) { +// // if(auto source_alloca = +// // dyn_cast(source_memref.getDefiningOp())) +// //{ +// auto map = subMap.getMap(); +// +// ////Create sizes from the submap +// auto sizes = subMap.getSizes(); +// +// // Create a subview op using lower bound, stride and size +// // Convert AffineApplyOp to its result Value and wrap in ValueRange +// auto [strides, lowerBounds] = extractStridesAndBounds(map); +// SmallVector offsetValues, sizeValues, strideValues; +// for (int64_t offset : lowerBounds) { +// offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); +// } +// for (int64_t stride : strides) { +// strideValues.push_back(rewriter.getI64IntegerAttr(stride)); +// } +// for (Value size : sizes) { +// sizeValues.push_back(size); +// } +// auto subViewOp = rewriter.create( +// source_memref.getLoc(), // Location +// source_memref, // Source memref +// offsetValues, // Offsets (array) +// sizeValues, // Sizes (array) +// strideValues // Strides (array) +// ); +// auto subViewType = subViewOp.getType().cast(); +// unsigned rank = subViewType.getRank(); +// auto identityMap = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext()); +// +// listOfNewMaps.push_back(identityMap); +// listOfNewInputs.push_back(subViewOp); +// +// //} +// // else { +// // assert(false && "Only 
expect allocaOp as source for submap +// // canonicalization right now"); return failure(); +// //} +// } else { +// listOfNewInputs.push_back(inp); +// } +// } +// +// for (auto out : outputs) { +// if (auto blkArg = dyn_cast(out)) { +// listOfNewOutputs.push_back(out); +// } else if (auto subMap = +// dyn_cast(out.getDefiningOp())) { +// auto source_memref = subMap.getMemref(); +// auto map = subMap.getMap(); +// +// //Create sizes from the submap +// auto sizes = subMap.getSizes(); +// +// // Create a subview op using lower bound, stride and size +// // Convert AffineApplyOp to its result Value and wrap in ValueRange +// auto [strides, lowerBounds] = extractStridesAndBounds(map); +// +// SmallVector offsetValues, sizeValues, strideValues; +// for (int64_t offset : lowerBounds) { +// offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); +// } +// for (int64_t stride : strides) { +// strideValues.push_back(rewriter.getI64IntegerAttr(stride)); +// } +// for (Value size : sizes) { +// sizeValues.push_back(size); +// } +// auto subViewOp = rewriter.create( +// source_memref.getLoc(), // Location +// source_memref, // Source memref +// offsetValues, // Offsets (array) +// sizeValues, // Sizes (array) +// strideValues // Strides (array) +// ); +// auto subViewType = subViewOp.getType().cast(); +// unsigned rank = subViewType.getRank(); +// auto identityMap = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext()); +// listOfNewMaps.push_back(identityMap); +// listOfNewOutputs.push_back(subViewOp); +// } else { +// listOfNewOutputs.push_back(out); +// } +// } +// ArrayRef maps(listOfNewMaps); +// // No submap ops detected +// if (maps.size() == 0) +// return failure(); +// // If inverse permutation exists, then we can canonicalize the linalg of +// // submap to linalg +// // TODO: Fails for: +// // 1. Maps with symbols +// // 2. Maps which are not resolvable 1 to 1 with memref for all dims +// if (inversePermutation(concatAffineMaps(maps))) { +// StringAttr empty = StringAttr::get(genericOp.getContext()); +// auto newGenericOp = rewriter.create( +// genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, +// listOfNewMaps, genericOp.getIteratorTypesArray(), empty, empty); +// rewriter.inlineRegionBefore(genericOp.getRegion(), +// newGenericOp.getRegion(), +// newGenericOp.getRegion().end()); +// +// // auto &block = newGenericOp.getRegion().front(); +// // block.addArguments(newGenericOp.getOperandTypes(), +// // SmallVector(newGenericOp.getNumOperands(), +// // genericOp.getLoc())); +// +// rewriter.replaceOp(genericOp, newGenericOp.getResults()); +// return success(); +// } +// // for(iterate over inputs) +// //{ +// // gather maps +// // gather submaps +// // Gather affine maps from submaps +// // Check over 2 iterations if all the indexes can be solved. +// // Use the same logic as linalg.generic to do this. 
+// // if success in getting vars +// // replace affine map from submap to linalg.generic +// // replace input memref as direct input to linalg.generic +// // } +// // assert(false && "inversePermutation doesn't exists for the given linalg +// // generic"); +// return failure(); +// } +//}; + // struct LinalgOfSubmap : public OpRewritePattern { // using OpRewritePattern::OpRewritePattern; // LogicalResult matchAndRewrite(linalg::GenericOp gen, @@ -6641,7 +6805,7 @@ void polygeist::SubmapOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { // results.insert(context); - results.insert(context); + results.insert(context); // results.insert(context); } diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index ce4154a6e6ae..cb39efc1011a 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -503,15 +503,18 @@ struct LinalgDebufferization : public OpRewritePattern { // if we are no alias we can just look at all users of the value // if we are not noalias, or we are captured, then we have to look at all users that // could read or write - if ((!isNoalias) || isCaptured(memVal)) { - return failure(); - } + //TODO: skipping noalias for now + //if ((!isNoalias) || isCaptured(memVal)) { + // return failure(); + //} MemRefType memrefType; if (auto blockArg = memVal.dyn_cast()) { memrefType = blockArg.getType().dyn_cast(); } else if (auto allocaOp = memVal.getDefiningOp()) { memrefType = allocaOp.getType(); + } else if (auto allocOp = memVal.getDefiningOp()) { + memrefType = allocOp.getType(); } else { return failure(); } @@ -522,12 +525,12 @@ struct LinalgDebufferization : public OpRewritePattern { memrefType.getShape(), memrefType.getElementType()); // Check to see if only linalg.generic are users of the Value op for now. 
- // TODO: Extend this - if (!llvm::all_of(memVal.getUsers(), [](Operation *op) { - return isa(op); - })) { - return failure(); - } + //// TODO: Extend this + //if (!llvm::all_of(memVal.getUsers(), [](Operation *op) { + // return isa(op) || isa(op); + // })) { + // return failure(); + //} // auto emptyTensor = // rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), @@ -633,6 +636,12 @@ struct LinalgDebufferization : public OpRewritePattern { userIdx++; expandedUserList.erase(expandedUserList.begin() + userIdx); } + else if (auto subviewOp = dyn_cast(user)) { + rewriter.setInsertionPointAfter(subviewOp); + auto newSubviewOp = rewriter.create( + subviewOp.getLoc(), subviewOp.getType(), subviewOp.getSource(), subviewOp.getOffsets(), subviewOp.getSizes(), subviewOp.getStrides()); + rewriter.replaceOp(subviewOp, newSubviewOp.getResult()); + } } //For adding yields for the last use all the way to the outer most region @@ -666,14 +675,23 @@ struct LinalgDebufferization : public OpRewritePattern { bool changed; //Fix instead of walk, just get the list of allocaOp users, so that you can easily delete ops inside SmallVector listOfAllocaOps; + SmallVector listOfAllocOps; funcOp.walk([&](memref::AllocaOp alloca) { listOfAllocaOps.push_back(alloca); }); + //TODO: Adding allocOp for now, without alias check + funcOp.walk([&](memref::AllocOp alloc) { + listOfAllocOps.push_back(alloc); + }); for (auto alloca : listOfAllocaOps) { handleMemref(alloca); } + + for (auto alloc : listOfAllocOps) { + handleMemref(alloc); + } for(auto arg: funcOp.getArguments()){ handleMemref(arg); diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index f670edd02c25..968edd70b693 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -785,17 +785,17 @@ struct AffineForOpRaising : public OpRewritePattern { if (!lbMap || lbMap.getNumResults() != 1) return failure(); - auto ub = loop.getSingleUpperBound(); - if (!ub) - return failure(); + //auto ub = loop.getSingleUpperBound(); + //if (!ub) + // return failure(); - auto lb = loop.getSingleLowerBound(); - if (!lb) - return failure(); + //auto lb = loop.getSingleLowerBound(); + //if (!lb) + // return failure(); - if (!loop.hasConstantUpperBound()) { - return failure(); - } + //if (!loop.hasConstantUpperBound()) { + // return failure(); + //} // Retrieve the step size int64_t step = loop.getStep(); @@ -810,10 +810,10 @@ struct AffineForOpRaising : public OpRewritePattern { rewriter.create(loop.getLoc(), lbMap, lbOperands); //// Ensure the bounds are constant expressions - auto ubConst = ubExpr.dyn_cast(); - auto lbConst = lbExpr.dyn_cast(); - if (!ubConst || !lbConst) - return failure(); + //auto ubConst = ubExpr.dyn_cast(); + //auto lbConst = lbExpr.dyn_cast(); + //if (!ubConst || !lbConst) + // return failure(); // Compute the loop size // int64_t loopSize = ubConst.getValue() - lbConst.getValue(); From cb34836f62a81baaff64258494a9cdf86ef0bcf9 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 31 Jul 2025 17:00:53 -0700 Subject: [PATCH 66/77] Added parallel fission pass --- lib/polygeist/Passes/RaiseToLinalg.cpp | 113 +++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 968edd70b693..a18e1cf42e98 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -1113,6 +1113,114 @@ struct AffineForOpRaising : public OpRewritePattern { } }; 
+struct AffineParallelFission : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineParallelOp parallelOp, + PatternRewriter &rewriter) const override { + + auto module = parallelOp->getParentOfType(); + // Collect all top-level nested loops (affine.parallel or affine.for) + SmallVector nestedLoops; + Block *body = parallelOp.getBody(); + + for (auto &op : body->without_terminator()) { + if (isa(op)) { + nestedLoops.push_back(&op); + } else if (!isMemoryOrControlFlowNeutral(&op)) { + // If there are non-trivial operations at the top level, + // we can't safely perform fission + return failure(); + } + } + + // Need at least 2 nested loops to perform fission + if (nestedLoops.size() < 2) { + return failure(); + } + + // Convert reductions ArrayAttr to ArrayRef + SmallVector reductionKinds; + for (auto attr : parallelOp.getReductions()) { + auto enumAttr = cast(attr); + reductionKinds.push_back(enumAttr.getValue()); + } + + // Convert steps to ArrayRef + SmallVector stepValues; + for (auto step : parallelOp.getSteps()) { + stepValues.push_back(step); + } + + for (Operation *nestedLoop : nestedLoops) { + + // Create new parallel loops for each nested loop + rewriter.setInsertionPoint(parallelOp); + + // Create a new outer parallel loop with same bounds + auto newParallelOp = rewriter.create( + parallelOp.getLoc(), + parallelOp.getResultTypes(), + reductionKinds, + SmallVector{parallelOp.getLowerBoundsMap()}, + parallelOp.getLowerBoundsOperands(), + SmallVector{parallelOp.getUpperBoundsMap()}, + parallelOp.getUpperBoundsOperands(), + stepValues + ); + + // Move the nested loop into the new outer loop + Block *newBody = newParallelOp.getBody(); + // Remove the existing terminator + rewriter.eraseOp(newBody->getTerminator()); + + // Set insertion point to the new body before cloning + rewriter.setInsertionPointToEnd(newBody); + + // Clone the nested loop into the new body + IRMapping mapping; + // Map the induction variables (use getIVs() instead of getInductionVars()) + for (auto [oldIV, newIV] : llvm::zip(parallelOp.getIVs(), + newParallelOp.getIVs())) { + mapping.map(oldIV, newIV); + } + + // Clone the operation (it will be automatically inserted at the current insertion point) + rewriter.clone(*nestedLoop, mapping); + + // Ensure insertion point is at the end of the outer parallel loop's body + rewriter.setInsertionPointToEnd(newBody); + + // Add the terminator back + rewriter.create(parallelOp.getLoc()); + } + + // Remove the original parallel loop + rewriter.eraseOp(parallelOp); + + return success(); + } + +private: + // Helper to check if an operation has no side effects that would + // prevent loop fission + bool isMemoryOrControlFlowNeutral(Operation *op) const { + // Allow constants, arithmetic, and other side-effect-free ops + if (isa(op)) return true; + if (op->hasTrait()) return true; + + // Check if it's a pure operation (no memory effects) + if (auto effectInterface = dyn_cast(op)) { + SmallVector effects; + effectInterface.getEffects(effects); + return effects.empty(); + } + + // Conservative: if we can't prove it's safe, assume it's not + return false; + } +}; + // namespace { // struct RaiseAffineToLinalg // : public AffineRaiseToLinalgBase { @@ -1150,6 +1258,11 @@ void RaiseAffineToLinalg::runOnOperation() { RewritePatternSet patterns(&getContext()); // TODO add the existing canonicalization patterns // + subview of an affine apply -> subview + + // Add the fission pattern first (preprocessing step) + 
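// Annotation (not part of the patch): sketch of the fission performed by
// AffineParallelFission above, assuming the two nested loops are independent:
//
//   affine.parallel (%i) = (0) to (%n) {              // before
//     affine.for %j = 0 to 10 { ... uses %i, %j ... }
//     affine.for %k = 0 to 20 { ... uses %i, %k ... }
//   }
//
//   affine.parallel (%i) = (0) to (%n) {              // after
//     affine.for %j = 0 to 10 { ... uses %i, %j ... }
//   }
//   affine.parallel (%i) = (0) to (%n) {
//     affine.for %k = 0 to 20 { ... uses %i, %k ... }
//   }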
patterns.insert(&getContext()); + + // Then add the main raising pattern patterns.insert(&getContext()); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), From 53c5d144caf5c857c224c76d622477adf928981a Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 31 Jul 2025 17:22:00 -0700 Subject: [PATCH 67/77] Added pattern for parallel to seq for loops --- lib/polygeist/Passes/RaiseToLinalg.cpp | 119 ++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 10 deletions(-) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index a18e1cf42e98..da1471eb3182 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -1221,6 +1221,86 @@ struct AffineParallelFission : public OpRewritePattern { } }; +struct AffineParallelToFor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineParallelOp parallelOp, + PatternRewriter &rewriter) const override { + + // Skip if there are reductions - they need special handling + if (!parallelOp.getReductions().empty()) { + return failure(); + } + + // Skip if there are result types - parallel loops with returns need special handling + if (!parallelOp.getResultTypes().empty()) { + return failure(); + } + + Location loc = parallelOp.getLoc(); + + // Get the bounds and steps + auto lowerBounds = parallelOp.getLowerBoundsMap(); + auto upperBounds = parallelOp.getUpperBoundsMap(); + auto steps = parallelOp.getSteps(); + auto lowerOperands = parallelOp.getLowerBoundsOperands(); + auto upperOperands = parallelOp.getUpperBoundsOperands(); + auto ivs = parallelOp.getIVs(); + + // Start building nested for loops from outermost to innermost + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(parallelOp); + + // Create nested affine.for loops + SmallVector forOps; + SmallVector newIVs; + + for (unsigned i = 0; i < ivs.size(); ++i) { + // Extract bounds for this dimension + auto lbMap = lowerBounds.getSliceMap(i, 1); + auto ubMap = upperBounds.getSliceMap(i, 1); + int64_t step = steps[i]; + + auto forOp = rewriter.create( + loc, + lowerOperands, lbMap, + upperOperands, ubMap, + step + ); + + forOps.push_back(forOp); + newIVs.push_back(forOp.getInductionVar()); + + // Set insertion point for next loop or body + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Move the body content from parallel to innermost for loop + Block *parallelBody = parallelOp.getBody(); + Block *targetBody = forOps.empty() ? 
nullptr : forOps.back().getBody(); + + if (!targetBody) { + return failure(); + } + + // Create mapping for induction variables + IRMapping mapping; + for (auto [parallelIV, newIV] : llvm::zip(ivs, newIVs)) { + mapping.map(parallelIV, newIV); + } + + // Clone operations from parallel body to for body (excluding terminator) + for (auto &op : parallelBody->without_terminator()) { + rewriter.clone(op, mapping); + } + + // Remove the original parallel loop + rewriter.eraseOp(parallelOp); + + return success(); + } +}; + // namespace { // struct RaiseAffineToLinalg // : public AffineRaiseToLinalgBase { @@ -1255,18 +1335,37 @@ struct RaiseAffineToLinalg } // namespace void RaiseAffineToLinalg::runOnOperation() { - RewritePatternSet patterns(&getContext()); - // TODO add the existing canonicalization patterns - // + subview of an affine apply -> subview + GreedyRewriteConfig config; - // Add the fission pattern first (preprocessing step) - patterns.insert(&getContext()); + // Step 1: Apply fission pattern first + { + RewritePatternSet fissionPatterns(&getContext()); + fissionPatterns.insert(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(fissionPatterns), config))) { + signalPassFailure(); + return; + } + } - // Then add the main raising pattern - patterns.insert(&getContext()); - GreedyRewriteConfig config; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + // Step 2: Apply parallel-to-for conversion + { + RewritePatternSet parallelToForPatterns(&getContext()); + parallelToForPatterns.insert(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(parallelToForPatterns), config))) { + signalPassFailure(); + return; + } + } + + // Step 3: Apply raising pattern + { + RewritePatternSet raisingPatterns(&getContext()); + raisingPatterns.insert(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(raisingPatterns), config))) { + signalPassFailure(); + return; + } + } } namespace mlir { From 60b81d20bb1ffd0103cc73a3ffd2a10c19f366f5 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 31 Jul 2025 17:38:45 -0700 Subject: [PATCH 68/77] Added raise-to-linalg-pipeline --- include/polygeist/Passes/Passes.h | 1 + include/polygeist/Passes/Passes.td | 10 ++++++++ lib/polygeist/Passes/RaiseToLinalg.cpp | 32 ++++++++++++++++++++++++++ 3 files changed, 43 insertions(+) diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index c1cea4c2ec72..e70660153540 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -32,6 +32,7 @@ std::unique_ptr createOpenMPOptPass(); std::unique_ptr createCanonicalizeForPass(); std::unique_ptr createRaiseSCFToAffinePass(); std::unique_ptr createRaiseAffineToLinalgPass(); +std::unique_ptr createRaiseAffineToLinalgPipelinePass(); std::unique_ptr createLinalgDebufferizePass(); std::unique_ptr createRemoveIterArgsPass(); std::unique_ptr createCPUifyPass(StringRef method = ""); diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 4945396d6178..eef142f6dbef 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -181,6 +181,16 @@ def AffineRaiseToLinalg : Pass<"raise-affine-to-linalg"> { ]; } +def AffineRaiseToLinalgPipeline : Pass<"raise-affine-to-linalg-pipeline"> { + let summary = "Pipeline: affine-parallelize followed by raise-affine-to-linalg"; + let constructor = "mlir::polygeist::createRaiseAffineToLinalgPipelinePass()"; 
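// Annotation (not part of the patch): with the pass registered under the mnemonic
// declared above, the pipeline could be exercised on the existing tests with
// something like
//   polygeist-opt --raise-affine-to-linalg-pipeline --split-input-file input.mlir
// which is expected to run affine-parallelize and then raise-affine-to-linalg on
// every func.func, mirroring RaiseAffineToLinalgPipeline::runOnOperation() below.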
+ let dependentDialects = [ + "affine::AffineDialect", + "linalg::LinalgDialect", + "polygeist::PolygeistDialect", + ]; +} + def SCFCanonicalizeFor : Pass<"canonicalize-scf-for"> { let summary = "Run some additional canonicalization for scf::for"; let constructor = "mlir::polygeist::createCanonicalizeForPass()"; diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index da1471eb3182..c5f57ca01cc0 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -1,6 +1,7 @@ #include "PassDetails.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -12,6 +13,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "polygeist/Passes/Passes.h" @@ -1327,6 +1329,32 @@ struct AffineParallelToFor : public OpRewritePattern { // }; // } // namespace +namespace { +struct RaiseAffineToLinalgPipeline + : public AffineRaiseToLinalgPipelineBase { + void runOnOperation() override; +}; +} // namespace + +void RaiseAffineToLinalgPipeline::runOnOperation() { + // Create a nested pass manager to run the pipeline on functions + OpPassManager pm(getOperation()->getName()); + + // Create a nested pass manager for function operations + OpPassManager &funcPM = pm.nest(); + + // Add affine-parallelize pass first (runs on func.func) + funcPM.addPass(mlir::affine::createAffineParallelizePass()); + + // Add our raise-affine-to-linalg pass second (also runs on func.func) + funcPM.addPass(createRaiseAffineToLinalgPass()); + + // Run the pipeline + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } +} + namespace { struct RaiseAffineToLinalg : public AffineRaiseToLinalgBase { @@ -1373,5 +1401,9 @@ namespace polygeist { std::unique_ptr createRaiseAffineToLinalgPass() { return std::make_unique(); } + +std::unique_ptr createRaiseAffineToLinalgPipelinePass() { + return std::make_unique(); +} } // namespace polygeist } // namespace mlir From 7b2f5d9cba3bcb3c1003d8cf6a3cac0eb52d4a13 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Thu, 31 Jul 2025 19:58:04 -0700 Subject: [PATCH 69/77] Added linalgGenericEliminateSubmaps and commented out submapToSubviewOp --- lib/polygeist/Ops.cpp | 552 +++++++----------------------------------- 1 file changed, 84 insertions(+), 468 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index f9a607b66db2..a0168917e59b 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -6088,480 +6088,29 @@ analyzeSubmapToSubViewConversion(polygeist::SubmapOp submapOp) { } -struct SubmapToSubviewOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(polygeist::SubmapOp submapOp, - PatternRewriter &rewriter) const override { - auto conversionInfo = analyzeSubmapToSubViewConversion(submapOp); - if (!conversionInfo.isValid) - return failure(); - - SmallVector offsetValues, sizeValues, strideValues; - for (int64_t offset : conversionInfo.offsets) { - offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); - } - for (int64_t stride : conversionInfo.strides) { - strideValues.push_back(rewriter.getI64IntegerAttr(stride)); - } - for (Value size : conversionInfo.sizes) { - 
sizeValues.push_back(size); - } - //auto subViewOp = rewriter.create( - // source_memref.getLoc(), // Location - // source_memref, // Source memref - // offsetValues, // Offsets (array) - // sizeValues, // Sizes (array) - // strideValues // Strides (array) - //); - rewriter.replaceOpWithNewOp(submapOp, submapOp.getType(), submapOp.getMemref(), offsetValues, sizeValues, strideValues); - return success(); - } -}; - -//struct LinalgOfSubmap : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(linalg::GenericOp genericOp, +//struct SubmapToSubviewOp : public OpRewritePattern { +// using OpRewritePattern::OpRewritePattern; +// LogicalResult matchAndRewrite(polygeist::SubmapOp submapOp, // PatternRewriter &rewriter) const override { -// // Check body content -// auto module = genericOp->getParentOfType(); -// Region &genericBody = genericOp.getRegion(); -// Block &entryBlock = genericBody.front(); -// ValueRange blockArgs = entryBlock.getArguments(); -// auto inputs = genericOp.getInputs(); -// auto outputs = genericOp.getOutputs(); -// SmallVector listOfAllocas; -// SmallVector listOfNewMaps; -// SmallVector listOfNewInputs, listOfNewOutputs; -// -// // auto mapAttrsArr = genericOp.getIndexingMaps(); -// // for(auto mapAttr: mapAttrsArr) { -// // AffineMap map = mapAttr.cast().getValue(); -// // if(map == convMap[0] && !mapped[0]) { -// // } -// // } -// for (auto inp : inputs) { -// if (auto blkArg = dyn_cast(inp)) { -// listOfNewInputs.push_back(inp); -// } else if (auto subMap = -// dyn_cast(inp.getDefiningOp())) { -// auto source_memref = subMap.getMemref(); +// auto conversionInfo = analyzeSubmapToSubViewConversion(submapOp); +// if (!conversionInfo.isValid) +// return failure(); // -// //Create a new memref.subview op from the given submap and sizes -// Value stride = rewriter.create(source_memref.getLoc(), 1); -// -// //sizesauto blockArg = dyn_cast_or_null(op)) { -// // if(auto source_alloca = -// // dyn_cast(source_memref.getDefiningOp())) -// //{ -// auto map = subMap.getMap(); -// -// ////Create sizes from the submap -// auto sizes = subMap.getSizes(); -// -// // Create a subview op using lower bound, stride and size -// // Convert AffineApplyOp to its result Value and wrap in ValueRange -// auto [strides, lowerBounds] = extractStridesAndBounds(map); -// SmallVector offsetValues, sizeValues, strideValues; -// for (int64_t offset : lowerBounds) { -// offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); -// } -// for (int64_t stride : strides) { -// strideValues.push_back(rewriter.getI64IntegerAttr(stride)); -// } -// for (Value size : sizes) { -// sizeValues.push_back(size); -// } -// auto subViewOp = rewriter.create( -// source_memref.getLoc(), // Location -// source_memref, // Source memref -// offsetValues, // Offsets (array) -// sizeValues, // Sizes (array) -// strideValues // Strides (array) -// ); -// auto subViewType = subViewOp.getType().cast(); -// unsigned rank = subViewType.getRank(); -// auto identityMap = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext()); -// -// listOfNewMaps.push_back(identityMap); -// listOfNewInputs.push_back(subViewOp); -// -// //} -// // else { -// // assert(false && "Only expect allocaOp as source for submap -// // canonicalization right now"); return failure(); -// //} -// } else { -// listOfNewInputs.push_back(inp); -// } +// SmallVector offsetValues, sizeValues, strideValues; +// for (int64_t offset : conversionInfo.offsets) { +// 
offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); // } -// -// for (auto out : outputs) { -// if (auto blkArg = dyn_cast(out)) { -// listOfNewOutputs.push_back(out); -// } else if (auto subMap = -// dyn_cast(out.getDefiningOp())) { -// auto source_memref = subMap.getMemref(); -// auto map = subMap.getMap(); -// -// //Create sizes from the submap -// auto sizes = subMap.getSizes(); -// -// // Create a subview op using lower bound, stride and size -// // Convert AffineApplyOp to its result Value and wrap in ValueRange -// auto [strides, lowerBounds] = extractStridesAndBounds(map); -// -// SmallVector offsetValues, sizeValues, strideValues; -// for (int64_t offset : lowerBounds) { -// offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); -// } -// for (int64_t stride : strides) { -// strideValues.push_back(rewriter.getI64IntegerAttr(stride)); -// } -// for (Value size : sizes) { -// sizeValues.push_back(size); -// } -// auto subViewOp = rewriter.create( -// source_memref.getLoc(), // Location -// source_memref, // Source memref -// offsetValues, // Offsets (array) -// sizeValues, // Sizes (array) -// strideValues // Strides (array) -// ); -// auto subViewType = subViewOp.getType().cast(); -// unsigned rank = subViewType.getRank(); -// auto identityMap = AffineMap::getMultiDimIdentityMap(rank, rewriter.getContext()); -// listOfNewMaps.push_back(identityMap); -// listOfNewOutputs.push_back(subViewOp); -// } else { -// listOfNewOutputs.push_back(out); -// } +// for (int64_t stride : conversionInfo.strides) { +// strideValues.push_back(rewriter.getI64IntegerAttr(stride)); // } -// ArrayRef maps(listOfNewMaps); -// // No submap ops detected -// if (maps.size() == 0) -// return failure(); -// // If inverse permutation exists, then we can canonicalize the linalg of -// // submap to linalg -// // TODO: Fails for: -// // 1. Maps with symbols -// // 2. Maps which are not resolvable 1 to 1 with memref for all dims -// if (inversePermutation(concatAffineMaps(maps))) { -// StringAttr empty = StringAttr::get(genericOp.getContext()); -// auto newGenericOp = rewriter.create( -// genericOp.getLoc(), TypeRange(), listOfNewInputs, listOfNewOutputs, -// listOfNewMaps, genericOp.getIteratorTypesArray(), empty, empty); -// rewriter.inlineRegionBefore(genericOp.getRegion(), -// newGenericOp.getRegion(), -// newGenericOp.getRegion().end()); -// -// // auto &block = newGenericOp.getRegion().front(); -// // block.addArguments(newGenericOp.getOperandTypes(), -// // SmallVector(newGenericOp.getNumOperands(), -// // genericOp.getLoc())); -// -// rewriter.replaceOp(genericOp, newGenericOp.getResults()); -// return success(); +// for (Value size : conversionInfo.sizes) { +// sizeValues.push_back(size); // } -// // for(iterate over inputs) -// //{ -// // gather maps -// // gather submaps -// // Gather affine maps from submaps -// // Check over 2 iterations if all the indexes can be solved. -// // Use the same logic as linalg.generic to do this. 
-// // if success in getting vars -// // replace affine map from submap to linalg.generic -// // replace input memref as direct input to linalg.generic -// // } -// // assert(false && "inversePermutation doesn't exists for the given linalg -// // generic"); -// return failure(); +// rewriter.replaceOpWithNewOp(submapOp, submapOp.getType(), submapOp.getMemref(), offsetValues, sizeValues, strideValues); +// return success(); // } //}; -// struct LinalgOfSubmap : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(linalg::GenericOp gen, -// PatternRewriter &rewriter) const override { - -// // Canonicalization 1 linalg.generic of map of submap. -> linalg.generic -// of map of submap -// //. iff the submap's affine map != identity -// //. replace inner affine map with composition - -// // Canonicalizeation 3: submap which only sets bounds, of an input memref -// with the same bounds -> noop / cast - -// // Canonicalization 1.5 (mix of 1/2) -// //. linalg_map = identity a[i,j,x,y] -> u[i+x][j+y] -// //. linalg_map = [i,j,x,y]->(i+x,j+y) a[i,j] -> u[i,j]. # but still -// keeping the upper loop limit -// //. 1 - -// // a[i] -> x[] - -// // a[1] -> x[] -// // a[2] -> x[] - -// // a[i,j] = x[map(i,j)]. ; the subbmap op -// //a[i+x][j+y] : submap defines iteration var 0 goes from 0 ... A0. and -// var 1 goes from 0 ... A1 -// //b[x][y] -// //c[i+x][j+y] -// // here we have 4 iteration variables that linalg is doing i, j, x, y -// // for (i : ...) -// //. for (j : ...) -// //. for (x : ...) -// //. for (y : ...) -// // c[i+x][j+y] += a[i+x][j+y] * b[x][y] - -// // a[i+x][j+y] -// // c[i+x][j+y] -// // for (i : ...) -// //. for (j : ...) -// //. for (x : ...) -// //. for (y : ...) -// // c[i+x][j+y] += a[i+x][j+y] - -// //x[map(i+x,j+y)] pass in the outermost one with correspondidng composed -// maps -// //b[x][y] -// //c[i+x][j+y] - -// // requirement here, is that all linalg.generic loop bounds must be -// solvable after replacement -// // for example, this would not be permissible -// // a[i] -> x[]. ; a = submap memref -> memref<100xf32> -// // out[] - -// // This cannot be replaced since now the linalg generic iteration variable -// i cannot be solved for - -// for (auto &&[op, opmap] : gen.getInputsAndMaps()) { -// if (auto submap = op.getDefiningOp()) { -// bool solvable = false; - -// /// Cannoicalization 2: index removal -// //. x[i, j] -> v[i]. can we get rid of j? -// //. Are input indices defined by other ops, and if so, can we -// simplify -// //. 1) Take all other input memrefs -// // 2) Determine all solvable indices from those input memrefs -// //. For each index which is solvable from 2) -// // if it can either be removed from the submap, or combined -// with another index in the submap, -// // remove it from the submap - -// SmallVector exprs; -// for (auto [op2, map] : gen.getInputAndMaps()) { -// if (op != op2) { -// for (auto expr : map.getAffineExprs()) { -// exprs.push_back(expr); -// } -// } -// } -// for (auto [op2, map] : gen.getOutputAndMaps()) { -// if (op != op2) { -// for (auto expr : map.getAffineExprs()) { -// exprs.push_back(expr); -// } -// } -// } -// SmallSet solvable; -// linalg.determineSolvableIndices(solvable, exprs); - -// SmallSet notsolvable = allvariables - solvable; - -// //a[map(i+x,j+y)] pass in the outermost one with correspondidng -// composed maps -// //b[x][y] -// //c[i+x][j+y] -// // Supose we're solving for a -// // Here exprs would contain all the affineexprs from b and c. 
(aka -// inputs - {x}) - -// // {x, y, i+x, j+y} -// // Running a solver allows us to uniquely solve for all of, x, y, i, -// and j with these expressoin -// // In this case we can attempt to remove dependence on x, y, i, j - -// // If however we had -// //a[map(i+x,j+y)] pass in the outermost one with correspondidng -// composed maps -// //b[x][y] -// //c[i+x][y] -// // we would solve with {x, y, i+x, y} -// // Running a solver we would be able to sole for {x, y, i} but not -// solve for j -// // In this case we can attempt to remove dependence on x, y, i, but -// not on j - -// // let's take easiest one where a is just broadcasting a constant to -// all input indices -// // a = submap (m,n) -> u[] -// // a[i+x, j+y] For all input indices which are uniquely solvable, here -// that is both -// //. index 0 = i + x -// //. and index 1 = j + y -// // set the input map to compose with the submap's affine map - -// /// Easy special case -// if (notsolvable.size() == 0) { - -// replace opmap with submap.compose(opmap) taking into account the the -// ConstantIntRanges -// // Easy case -// } - -// // We now have two maps with different meanings -// // Let |N| be the number of loop variables in the linalg.generic -// // Let |M| be length(submap.getType().getShape()) -// // Let |Q| be length(submap.getInput().getType().getShape()), number -// of dimensions of input operand to the submap - -// // opmap from the linalg.generic which takes linalg.generic loop -// indices |N| -> inputs to the submap op. |M| - -// // submap.map. submap op. which takes input indices |M|. -// -> indices for the corresponing base memref |Q| - -// // Example - -// //a[map(i+x,j+y)] pass in the outermost one with correspondidng -// composed maps -// //b[x][y] -// //c[i+x][j+y] - -// // a = submap (w,p) -> u[c + 2 * p] - -// // %c = myop.constant() -// // %a = submap a[w, p] -> u[%c + 2 * p] -// //. linalg.generic %a %b %c a.map (x,y,i,j) -> a[x+i,y+j] { -// // } - -// // N = 4 = |{i,j,x,u}| -// // M = 2 = dim(a) . a is 2 dims -// // Q = 1. dim(u) - -// SmallVector newLinalgExprs; -// SmallVector newSubmapExprs; - -// SmallVector legalIndices; -// // We iterate for all |M| expressions of the opmap -// for (auto &&[i, linalgexpr] : llvm::enumerate(opmap.getExprs())) { -// // We must retain the indexing for variables which are functions -// // of the inputs which have a defining index. -// bool legal = true; -// for (auto var : notsolvable) { -// if (linalgexpr.isFunctionOf(var)) { -// legal = false; -// // we can pop this from the not solvable since now this index -// will define -// // the value of var for future iterations. -// // But doing so requires proving it is not a linear -// combination of previously -// // visited linalgexpr's, so we'll defer this for a later -// optimization -// // notsolvable.pop(var); -// } -// } - -// if (legal) -// legalIndices.push_back(i); -// } - -// // The non-special case version -// // j is not solvable -// //a[map(i+x,j+y)] pass in the outermost one with correspondidng -// composed maps -// //b[x][y] -// //c[i+x][y] - -// // because j is not solvable we cannot move any expressions depending -// on j (in this case p depends on j) -// //. and the underlying sub expressions depending j, in this case via -// p are: -// // a[1] = w + 4 and a[2] = w + 7 -// // define a(w,p) -> u[c + 2 * p, w + 4, w + 7] - -// // with the general case optimization v0. 
[just moving expressions up] - -// //a2[map(i+x, j+y), i + x + 4, i + x + 7] pass in the outermost one -// with correspondidng composed maps -// //b[x][y] -// //c[i+x][y] - -// // define a2(w, p) -> u[c + 2 * p] - -// // with the general case optimization v1. [just eliminating -// unnecessary indices] - -// //a2[map(j+y), i + x + 4, i + x + 7] pass in the outermost one with -// correspondidng composed maps -// //b[x][y] -// //c[i+x][y] - -// // define a2(p) -> u[c + 2 * p] - -// // So this optimization generally moves expression from the submap -// into the linalg map -// // and it it also removes unnecessary indices into the submap - -// // If the entire submap is legal to inline, the solution is simple, -// replace the linalg -// // map with itself composed with the submap, and replace the original -// submap with the identity op if (legalIndices.size() == -// opmap.getExprs().size()) { -// // Note, it isn't 100% as simple as below since we still need to -// retain any constant op's in the -// // new submap op below, since linalg.generic doesn't support -// constant value's for the indexing, as far -// // as I (wmoses) know? -// newLinalgExprs = opmap.compose(submap.getMap()).getExprs(); -// newSubmapExprs = -// Affine::getIdentityMap(submap.getOperand().getShape().size()).getExprs(); -// } else { -// SmallVector illegalIndices = allIndices - legalIndices; - -// // We can alternatively re-index maps which are solely functions of -// legal indices. for (auto &&[i, submapexpr] : -// llvm::enumerate(submap.getAffineMap().getExprs())) { -// if (submapexpr is a function of any illegal indicies) { -// // we need to keep this as a submap expr (though re-indexed on -// the new number of exprs) -// newSubmapExprs.push_back(submapexpr.reindex()); -// } else { -// // this index can be completely solved for with other inputs, -// let's move the expression from -// // a submap expression into a linalg.generic map expression. 
-// newLinalgExprs.push_back(opmap.compose(submapexpr)); -// newSubmapExprs.push_back(Affine::getIdentity()); -// } -// } -// } - -// if (solvable) { -// // replace the input to the generic with the input to the submap, -// and the new map return success(); -// } -// } -// } - -// for (auto op : gen.getOutputs()) { -// if (auto submap = op.getDefiningOp()) { -// bool solvable = false; -// if (solvable) { -// do the thing -// // replace the input to the generic with the input to the submap, -// and the new map return success(); -// } -// } -// } - -// return failure(); -// } -// }; - static llvm::cl::opt BufferElim("enable-buffer-elim", llvm::cl::init(true), llvm::cl::desc("Enable buffer elimination")); @@ -6593,7 +6142,6 @@ void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results, SimplifyDeadAllocV2, SimplifyDeadAllocV2, MulDivMul, MergeParallelInductions, - // RankReduction, AggressiveAllocaScopeInliner, InductiveVarRemoval>(context); } @@ -6801,11 +6349,79 @@ class DimSubMap final : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// LinalgGenericEliminateSubmaps Pattern +//===----------------------------------------------------------------------===// + +struct LinalgGenericEliminateSubmaps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { + bool hasSubmaps = false; + SmallVector newInputs; + SmallVector newOutputs; + SmallVector newIndexingMaps; + + // Get the indexing maps as AffineMap array + auto indexingMaps = genericOp.getIndexingMapsArray(); + + // Check inputs for submaps + for (auto [input, map] : llvm::zip(genericOp.getInputs(), indexingMaps)) { + if (auto submapOp = input.getDefiningOp()) { + hasSubmaps = true; + newInputs.push_back(submapOp.getViewSource()); + // Compose: submap_map.compose(linalg_map) → f(g(x)) + AffineMap composedMap = submapOp.getMap().compose(map); + newIndexingMaps.push_back(composedMap); + } else { + newInputs.push_back(input); + newIndexingMaps.push_back(map); + } + } + + // Check outputs for submaps + auto outputMaps = ArrayRef(indexingMaps).drop_front(genericOp.getInputs().size()); + for (auto [output, map] : llvm::zip(genericOp.getOutputs(), outputMaps)) { + if (auto submapOp = output.getDefiningOp()) { + hasSubmaps = true; + newOutputs.push_back(submapOp.getViewSource()); + // Compose: submap_map.compose(linalg_map) → f(g(x)) + AffineMap composedMap = submapOp.getMap().compose(map); + newIndexingMaps.push_back(composedMap); + } else { + newOutputs.push_back(output); + newIndexingMaps.push_back(map); + } + } + + if (!hasSubmaps) { + return failure(); + } + + // Create new linalg.generic with composed maps + auto newGenericOp = rewriter.create( + genericOp.getLoc(), + genericOp.getResultTypes(), + newInputs, + newOutputs, + newIndexingMaps, + genericOp.getIteratorTypesArray(), + /*bodyBuild=*/nullptr); + + // Clone the region + IRMapping mapping; + genericOp.getRegion().cloneInto(&newGenericOp.getRegion(), mapping); + + rewriter.replaceOp(genericOp, newGenericOp.getResults()); + return success(); + } +}; + void polygeist::SubmapOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { // results.insert(context); - results.insert(context); + results.insert(context); // results.insert(context); } From 71e441f561a8fb7344bf1a3428674368d224867d Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 1 Aug 2025 11:23:59 
-0700 Subject: [PATCH 70/77] Canonicalization fix --- lib/polygeist/Ops.cpp | 330 +++++++++++++++++++++++-- lib/polygeist/Passes/RaiseToLinalg.cpp | 21 +- 2 files changed, 316 insertions(+), 35 deletions(-) diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index a0168917e59b..b203bdcce137 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -818,7 +818,7 @@ bool mayAlias(Value v, Value v2) { isAlloca[1] = isStackAlloca(v2); isGlobal[1] = v2.getDefiningOp() || - v2.getDefiningOp(); + v2.getDefiningOp(); // Non-equivalent allocas/global's cannot conflict with each other if ((isAlloca[0] || isGlobal[0]) && (isAlloca[1] || isGlobal[1])) @@ -5991,7 +5991,12 @@ static bool canConvertSubmapToSubView(polygeist::SubmapOp submapOp) { auto sizes = submapOp.getSizes(); auto symbols = submapOp.getSymbols(); auto source_memref = submapOp.getMemref(); - + + // 0. Only convert if map has symbols + if (submapOp.getMap().getNumSymbols() == 0) { + return false; + } + // 1. Identity maps are always valid if (map.isIdentity()) { return true; @@ -6088,28 +6093,288 @@ analyzeSubmapToSubViewConversion(polygeist::SubmapOp submapOp) { } -//struct SubmapToSubviewOp : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(polygeist::SubmapOp submapOp, -// PatternRewriter &rewriter) const override { -// auto conversionInfo = analyzeSubmapToSubViewConversion(submapOp); -// if (!conversionInfo.isValid) -// return failure(); -// -// SmallVector offsetValues, sizeValues, strideValues; -// for (int64_t offset : conversionInfo.offsets) { -// offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); -// } -// for (int64_t stride : conversionInfo.strides) { -// strideValues.push_back(rewriter.getI64IntegerAttr(stride)); -// } -// for (Value size : conversionInfo.sizes) { -// sizeValues.push_back(size); -// } -// rewriter.replaceOpWithNewOp(submapOp, submapOp.getType(), submapOp.getMemref(), offsetValues, sizeValues, strideValues); -// return success(); -// } -//}; +struct SubmapToSubviewOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(polygeist::SubmapOp submapOp, + PatternRewriter &rewriter) const override { + auto conversionInfo = analyzeSubmapToSubViewConversion(submapOp); + if (!conversionInfo.isValid) + return failure(); + + SmallVector offsetValues, sizeValues, strideValues; + for (int64_t offset : conversionInfo.offsets) { + offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); + } + for (int64_t stride : conversionInfo.strides) { + strideValues.push_back(rewriter.getI64IntegerAttr(stride)); + } + for (Value size : conversionInfo.sizes) { + sizeValues.push_back(size); + } + rewriter.replaceOpWithNewOp(submapOp, submapOp.getType(), submapOp.getMemref(), offsetValues, sizeValues, strideValues); + return success(); + } +}; + +// Enhanced analysis structure to handle symbols and transposes +struct EnhancedSubmapAnalysis { + bool isValid = false; + bool needsTranspose = false; + SmallVector permutation; // For transpose: [1,0] means swap dims + SmallVector offsets; // Mix of constants and symbol values + SmallVector strides; // Mix of constants and symbol values + SmallVector sizes; // From submapOp.getSizes() +}; + +// Helper to analyze affine expressions with symbol support +static bool analyzeExpressionWithSymbols(AffineExpr expr, unsigned expectedDim, + ValueRange symbolValues, + OpFoldResult &offset, OpFoldResult &stride, + unsigned &actualDim, OpBuilder &builder) { + offset = 
builder.getI64IntegerAttr(0); // Default offset = 0 + stride = builder.getI64IntegerAttr(1); // Default stride = 1 + actualDim = expectedDim; + + // Case 1: Simple dimension access: d0, d1, etc. + if (auto dimExpr = expr.dyn_cast()) { + actualDim = dimExpr.getPosition(); + return true; + } + + // Case 2: Constant (pure offset) + if (auto constExpr = expr.dyn_cast()) { + offset = builder.getI64IntegerAttr(constExpr.getValue()); + actualDim = 0; // Degenerate case + return true; + } + + // Case 3: Symbol (pure offset from symbol) + if (auto symbolExpr = expr.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + offset = symbolValues[symbolExpr.getPosition()]; + actualDim = 0; // Degenerate case + return true; + } + return false; + } + + // Case 4: Binary operations + if (auto binaryExpr = expr.dyn_cast()) { + auto lhs = binaryExpr.getLHS(); + auto rhs = binaryExpr.getRHS(); + + if (binaryExpr.getKind() == AffineExprKind::Add) { + // d0 + constant, d0 + symbol, constant + symbol, etc. + if (auto dimExpr = lhs.dyn_cast()) { + actualDim = dimExpr.getPosition(); + if (auto constExpr = rhs.dyn_cast()) { + offset = builder.getI64IntegerAttr(constExpr.getValue()); + return true; + } + if (auto symbolExpr = rhs.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + offset = symbolValues[symbolExpr.getPosition()]; + return true; + } + } + } + // Try reverse: constant + d0, symbol + d0 + if (auto dimExpr = rhs.dyn_cast()) { + actualDim = dimExpr.getPosition(); + if (auto constExpr = lhs.dyn_cast()) { + offset = builder.getI64IntegerAttr(constExpr.getValue()); + return true; + } + if (auto symbolExpr = lhs.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + offset = symbolValues[symbolExpr.getPosition()]; + return true; + } + } + } + } + + if (binaryExpr.getKind() == AffineExprKind::Mul) { + // d0 * constant, d0 * symbol + if (auto dimExpr = lhs.dyn_cast()) { + actualDim = dimExpr.getPosition(); + if (auto constExpr = rhs.dyn_cast()) { + stride = builder.getI64IntegerAttr(constExpr.getValue()); + return true; + } + if (auto symbolExpr = rhs.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + stride = symbolValues[symbolExpr.getPosition()]; + return true; + } + } + } + // Try reverse: constant * d0, symbol * d0 + if (auto dimExpr = rhs.dyn_cast()) { + actualDim = dimExpr.getPosition(); + if (auto constExpr = lhs.dyn_cast()) { + stride = builder.getI64IntegerAttr(constExpr.getValue()); + return true; + } + if (auto symbolExpr = lhs.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + stride = symbolValues[symbolExpr.getPosition()]; + return true; + } + } + } + } + } + + return false; +} + +// Enhanced analysis function +static EnhancedSubmapAnalysis analyzeEnhancedSubmap(polygeist::SubmapOp submapOp, + OpBuilder &builder) { + EnhancedSubmapAnalysis analysis; + auto map = submapOp.getMap(); + auto symbolValues = submapOp.getSymbols(); + auto sizes = submapOp.getSizes(); + auto sourceType = submapOp.getViewSource().getType().cast(); + int64_t sourceRank = sourceType.getRank(); + + // Only handle maps with reasonable complexity + if (map.getNumResults() == 0 || map.getNumResults() > 4) { + return analysis; + } + + // Initialize arrays with default values for all dimensions of source memref + SmallVector offsets(sourceRank, builder.getI64IntegerAttr(0)); + SmallVector strides(sourceRank, builder.getI64IntegerAttr(1)); + SmallVector resultSizes; + SmallVector actualDims; + + // Build default sizes from source memref 
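  // (Editorial sketch, not part of the patch: how the expression forms accepted
  //  by analyzeExpressionWithSymbols above decompose into subview parameters,
  //  shown in pseudo-IR.) For a submap map (d0, d1)[s0] -> (d0 + s0, d1 * 2)
  //  with symbol operand %off:
  //    result 0: d0 + s0  ->  source dim 0, offset = %off, stride = 1
  //    result 1: d1 * 2   ->  source dim 1, offset = 0,    stride = 2
  //  Compound expressions such as d0 * s0 + s1 are not matched, the analysis
  //  stays invalid, and the polygeist.submap is left untouched.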
shape + for (int64_t i = 0; i < sourceRank; ++i) { + int64_t dimSize = sourceType.getDimSize(i); + if (dimSize == ShapedType::kDynamic) { + // For dynamic dimensions, we need to use the actual size + Value dimSizeValue = builder.create( + submapOp.getLoc(), submapOp.getViewSource(), i); + resultSizes.push_back(dimSizeValue); + } else { + resultSizes.push_back(builder.getI64IntegerAttr(dimSize)); + } + } + + // Analyze each result expression and update corresponding dimension + for (unsigned i = 0; i < map.getNumResults(); ++i) { + auto expr = map.getResult(i); + OpFoldResult offset, stride; + unsigned actualDim; + + if (!analyzeExpressionWithSymbols(expr, i, symbolValues, offset, stride, + actualDim, builder)) { + return analysis; // Failed to analyze + } + + // Make sure actualDim is within bounds + if (actualDim >= sourceRank) { + return analysis; // Invalid dimension + } + + // Update the arrays for this dimension + offsets[actualDim] = offset; + strides[actualDim] = stride; + actualDims.push_back(actualDim); + } + + analysis.isValid = true; + analysis.offsets = std::move(offsets); + analysis.strides = std::move(strides); + + // Copy sizes - use provided sizes if available, otherwise use computed ones + if (sizes.size() == map.getNumResults()) { + for (auto size : sizes) { + analysis.sizes.push_back(size); + } + } else { + // Use default sizes for all dimensions + analysis.sizes = std::move(resultSizes); + } + + return analysis; +} + +// Enhanced pattern implementation +struct EnhancedSubmapToSubviewOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(polygeist::SubmapOp submapOp, + PatternRewriter &rewriter) const override { + auto analysis = analyzeEnhancedSubmap(submapOp, rewriter); + if (!analysis.isValid) { + return failure(); + } + + Value currentMemref = submapOp.getViewSource(); + Location loc = submapOp.getLoc(); + + // Step 1: Apply subview if we have non-trivial offsets/strides + bool hasNonTrivialSubview = false; + for (auto offset : analysis.offsets) { + if (auto attr = offset.dyn_cast()) { + if (auto intAttr = attr.dyn_cast()) { + if (intAttr.getInt() != 0) { + hasNonTrivialSubview = true; + break; + } + } + } else { + hasNonTrivialSubview = true; // Non-constant offset + break; + } + } + + for (auto stride : analysis.strides) { + if (auto attr = stride.dyn_cast()) { + if (auto intAttr = attr.dyn_cast()) { + if (intAttr.getInt() != 1) { + hasNonTrivialSubview = true; + break; + } + } + } else { + hasNonTrivialSubview = true; // Non-constant stride + break; + } + } + + if (hasNonTrivialSubview) { + // Create subview operation + auto subviewOp = rewriter.create( + loc, currentMemref, analysis.offsets, analysis.sizes, analysis.strides); + currentMemref = subviewOp.getResult(); + } + + // Step 2: Apply transpose if needed + if (analysis.needsTranspose) { + // Create transpose using linalg.transpose or memref.transpose + // For now, let's use a simple approach with linalg + SmallVector permutation = analysis.permutation; + + // Create transpose using linalg.transpose (if available) + // This is a simplified version - you might need to adjust based on available ops + auto transposeType = MemRefType::get( + submapOp.getType().cast().getShape(), + submapOp.getType().cast().getElementType()); + + // For simplicity, let's create an identity operation for now + // In practice, you'd want to create the actual transpose operation + currentMemref = currentMemref; // TODO: Implement actual transpose + } + + // Replace the original 
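  // (Editorial sketch of the rewrite performed by this pattern, with syntax
  //  abbreviated; the submap shown is hypothetical.)
  //    %v = polygeist.submap %src [(d0, d1) -> (d0 + 4, d1)] sizes(%m, %n)
  //  becomes
  //    %v = memref.subview %src[4, 0] [%m, %n] [1, 1]
  //  Note that needsTranspose is never set in analyzeEnhancedSubmap and the
  //  transpose branch above leaves currentMemref unchanged, so permuting maps
  //  are not yet materialized.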
submap + rewriter.replaceOp(submapOp, currentMemref); + return success(); + } +}; static llvm::cl::opt BufferElim("enable-buffer-elim", llvm::cl::init(true), @@ -6368,6 +6633,13 @@ struct LinalgGenericEliminateSubmaps : public OpRewritePattern()) { + // Skip submaps with symbols for now to avoid invalid map composition + if (submapOp.getMap().getNumSymbols() > 0) { + newInputs.push_back(input); + newIndexingMaps.push_back(map); + continue; + } + hasSubmaps = true; newInputs.push_back(submapOp.getViewSource()); // Compose: submap_map.compose(linalg_map) → f(g(x)) @@ -6383,6 +6655,13 @@ struct LinalgGenericEliminateSubmaps : public OpRewritePattern(indexingMaps).drop_front(genericOp.getInputs().size()); for (auto [output, map] : llvm::zip(genericOp.getOutputs(), outputMaps)) { if (auto submapOp = output.getDefiningOp()) { + // Skip submaps with symbols for now to avoid invalid map composition + if (submapOp.getMap().getNumSymbols() > 0) { + newOutputs.push_back(output); + newIndexingMaps.push_back(map); + continue; + } + hasSubmaps = true; newOutputs.push_back(submapOp.getViewSource()); // Compose: submap_map.compose(linalg_map) → f(g(x)) @@ -6421,7 +6700,8 @@ void polygeist::SubmapOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { // results.insert(context); - results.insert(context); + results.insert(context); // results.insert(context); } diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index c5f57ca01cc0..6698fb2831e0 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -16,6 +16,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" #include "polygeist/Passes/Passes.h" #include "llvm/Support/Debug.h" @@ -1129,9 +1130,8 @@ struct AffineParallelFission : public OpRewritePattern { for (auto &op : body->without_terminator()) { if (isa(op)) { nestedLoops.push_back(&op); - } else if (!isMemoryOrControlFlowNeutral(&op)) { - // If there are non-trivial operations at the top level, - // we can't safely perform fission + } else { + // Only allow pure nested loops - reject any other operations return failure(); } } @@ -1349,9 +1349,13 @@ void RaiseAffineToLinalgPipeline::runOnOperation() { // Add our raise-affine-to-linalg pass second (also runs on func.func) funcPM.addPass(createRaiseAffineToLinalgPass()); + // Canonicalize after raise-to-linalg to eliminate submaps and other patterns + funcPM.addPass(createCanonicalizerPass()); + // Run the pipeline if (failed(runPipeline(pm, getOperation()))) { - signalPassFailure(); + // Warn but don't fail the pass - convergence issues shouldn't kill output + getOperation()->emitWarning("Pipeline didn't converge completely, but continuing anyway"); } } @@ -1370,8 +1374,7 @@ void RaiseAffineToLinalg::runOnOperation() { RewritePatternSet fissionPatterns(&getContext()); fissionPatterns.insert(&getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(fissionPatterns), config))) { - signalPassFailure(); - return; + getOperation()->emitWarning("AffineParallelFission didn't converge, continuing anyway"); } } @@ -1380,8 +1383,7 @@ void RaiseAffineToLinalg::runOnOperation() { RewritePatternSet parallelToForPatterns(&getContext()); parallelToForPatterns.insert(&getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(parallelToForPatterns), config))) { - signalPassFailure(); - 
return; + getOperation()->emitWarning("AffineParallelToFor didn't converge, continuing anyway"); } } @@ -1390,8 +1392,7 @@ void RaiseAffineToLinalg::runOnOperation() { RewritePatternSet raisingPatterns(&getContext()); raisingPatterns.insert(&getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(raisingPatterns), config))) { - signalPassFailure(); - return; + getOperation()->emitWarning("AffineForOpRaising didn't converge, continuing anyway"); } } } From e421a866a58c07e23860e9e5996099585cf64994 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 1 Aug 2025 12:57:06 -0700 Subject: [PATCH 71/77] bug fix for non nullptr in submap creation --- lib/polygeist/Passes/RaiseToLinalg.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 6698fb2831e0..7182cb4b0cca 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -349,6 +349,10 @@ Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, DenseSet processedOps; IRMapping mapping; auto clonedOp = recursiveCloneWithDominanceCheck(builder, sz, builder.getBlock()->getParent(), mapping, processedOps); + if (!clonedOp) { + legal = false; + return nullptr; + } operands_without_indices.push_back(clonedOp); } From 56724a52cef2b9f2d01d984ea9416a60436ddc9a Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 1 Aug 2025 15:30:01 -0700 Subject: [PATCH 72/77] Fix in linalg debufferizer - failure return and only insert memref.copy if current!=totensor --- lib/polygeist/Passes/LinalgDebufferize.cpp | 48 +++++++++++++++------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index cb39efc1011a..1a4e22e39dec 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -535,11 +535,17 @@ struct LinalgDebufferization : public OpRewritePattern { // auto emptyTensor = // rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), // allocaOp.getType().getElementType()); + auto sortedUsers = getSortedUsers(memVal); + + // If the first user is already a to_tensor op, don't try to debufferize + if (!sortedUsers.empty() && isa(sortedUsers[0])) { + return failure(); + } + auto toTensorOp = rewriter.create( memVal.getLoc(), tensorType, memVal); Value currentTensor = toTensorOp; - auto sortedUsers = getSortedUsers(memVal); //Other algorithm: // 1. 
Walk over all ops @@ -635,12 +641,22 @@ struct LinalgDebufferization : public OpRewritePattern { expandedUserList.insert(expandedUserList.begin() + userIdx, newGenericOp); userIdx++; expandedUserList.erase(expandedUserList.begin() + userIdx); + } else if (auto subviewOp = dyn_cast(user)) { - rewriter.setInsertionPointAfter(subviewOp); - auto newSubviewOp = rewriter.create( - subviewOp.getLoc(), subviewOp.getType(), subviewOp.getSource(), subviewOp.getOffsets(), subviewOp.getSizes(), subviewOp.getStrides()); - rewriter.replaceOp(subviewOp, newSubviewOp.getResult()); + if (subviewOp.getSource() == memVal) { + // Convert memref.subview to tensor.extract_slice + rewriter.setInsertionPointAfter(subviewOp); + auto extractSliceOp = rewriter.create( + subviewOp.getLoc(), + currentTensor, // Use the tensor version + subviewOp.getOffsets(), + subviewOp.getSizes(), + subviewOp.getStrides()); + + // This creates a new tensor that can be used by subsequent operations + // Need to handle this tensor in the debufferization chain + } } } @@ -662,17 +678,21 @@ struct LinalgDebufferization : public OpRewritePattern { // rewriter.setInsertionPointAfter(parentOp); //} //if(currentTensor != prevTensor) { - rewriter.setInsertionPointAfter(currentTensor.getDefiningOp()); - auto toMemrefOp = rewriter.create( - memVal.getLoc(), memrefType, currentTensor); - rewriter.create(memVal.getLoc(), toMemrefOp, memVal); + + // Only insert to_memref and copy if currentTensor was actually transformed + if (currentTensor != toTensorOp) { + rewriter.setInsertionPointAfter(currentTensor.getDefiningOp()); + auto toMemrefOp = rewriter.create( + memVal.getLoc(), memrefType, currentTensor); + rewriter.create(memVal.getLoc(), toMemrefOp, memVal); + } //} // opsToDelete.push_back(allocaOp.getOperation()); return success(); }; - bool changed; + bool anySuccess = false; //Fix instead of walk, just get the list of allocaOp users, so that you can easily delete ops inside SmallVector listOfAllocaOps; SmallVector listOfAllocOps; @@ -686,18 +706,18 @@ struct LinalgDebufferization : public OpRewritePattern { }); for (auto alloca : listOfAllocaOps) { - handleMemref(alloca); + anySuccess |= succeeded(handleMemref(alloca)); } for (auto alloc : listOfAllocOps) { - handleMemref(alloc); + anySuccess |= succeeded(handleMemref(alloc)); } for(auto arg: funcOp.getArguments()){ - handleMemref(arg); + anySuccess |= succeeded(handleMemref(arg)); } - passResult = success(); + passResult = anySuccess ? 
success() : failure(); //for (Operation *op : opsToDelete) { // op->erase(); //} From c3c27004489f1fc93a496c1d460c2c030be90e5b Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 1 Aug 2025 16:54:48 -0700 Subject: [PATCH 73/77] improved matcher to create a dependency graph and use it for matching --- lib/polygeist/Passes/LinalgToKernel.cpp | 223 +++++++++++++++++++++--- tools/polygeist-opt/polygeist-opt.cpp | 2 + 2 files changed, 205 insertions(+), 20 deletions(-) diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp index 420c985df71b..f891fac9cacd 100644 --- a/lib/polygeist/Passes/LinalgToKernel.cpp +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -7,6 +7,7 @@ #include "PassDetails.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" @@ -14,6 +15,9 @@ #include "mlir/Parser/Parser.h" #include "mlir/Support/FileUtilities.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "polygeist/Kernel/KernelDialect.h" @@ -22,6 +26,7 @@ #include #include +#include using namespace mlir; using namespace mlir::linalg; @@ -30,36 +35,214 @@ using namespace mlir::polygeist::kernel; namespace { -// Helper function to check if two regions are structurally equivalent +// Structure to represent an operation node in the dependency graph +struct OpNode { + Operation *op; + StringRef opName; + SmallVector operandTypes; + SmallVector resultTypes; + SmallVector dependencies; // Operations this depends on + SmallVector dependents; // Operations that depend on this + + OpNode(Operation *operation) : op(operation) { + if (operation) { + // Regular operation node + opName = operation->getName().getStringRef(); + for (Value operand : operation->getOperands()) { + operandTypes.push_back(operand.getType()); + } + for (Value result : operation->getResults()) { + resultTypes.push_back(result.getType()); + } + } else { + // Special node for block arguments - will be set later + opName = "block_arg"; + } + } + + // Check if two nodes are structurally equivalent (same operation type and types) + bool isEquivalentTo(const OpNode &other) const { + return opName == other.opName && + operandTypes == other.operandTypes && + resultTypes == other.resultTypes; + } +}; + +// Structure to represent a dependency graph for a region +struct DependencyGraph { + SmallVector> nodes; + DenseMap opToNode; + SmallVector blockArgNodes; // Special nodes for block arguments + + void buildFromRegion(Region ®ion) { + // Process each block in the region + for (Block &block : region.getBlocks()) { + + // Create pseudo-nodes for block arguments + for (BlockArgument arg : block.getArguments()) { + // Block arguments are represented as special nodes + auto argNode = std::make_unique(nullptr); + argNode->resultTypes.push_back(arg.getType()); + blockArgNodes.push_back(argNode.get()); + + // Map the block argument value to this node for dependency tracking + // We'll use a separate map for this + nodes.push_back(std::move(argNode)); + } + + // Create nodes for each operation + for (Operation &op : block.getOperations()) { + auto node = std::make_unique(&op); + OpNode *nodePtr = node.get(); + opToNode[&op] = nodePtr; + nodes.push_back(std::move(node)); + } + + // Build dependency edges + for (Operation &op : 
block.getOperations()) { + OpNode *currentNode = opToNode[&op]; + + // For each operand, find what it depends on + for (Value operand : op.getOperands()) { + if (auto blockArg = dyn_cast(operand)) { + // Depends on a block argument + size_t argIndex = blockArg.getArgNumber(); + if (argIndex < blockArgNodes.size()) { + OpNode *argNode = blockArgNodes[argIndex]; + currentNode->dependencies.push_back(argNode); + argNode->dependents.push_back(currentNode); + } + } else if (Operation *definingOp = operand.getDefiningOp()) { + // Depends on another operation + if (opToNode.count(definingOp)) { + OpNode *depNode = opToNode[definingOp]; + currentNode->dependencies.push_back(depNode); + depNode->dependents.push_back(currentNode); + } + } + } + } + } + } + + // Get nodes in topological order (dependencies first) + SmallVector getTopologicalOrder() const { + SmallVector result; + DenseSet visited; + + std::function dfs = [&](OpNode* node) { + if (visited.contains(node)) return; + visited.insert(node); + + // Visit all dependencies first + for (OpNode* dep : node->dependencies) { + dfs(dep); + } + + result.push_back(node); + }; + + // Start DFS from all nodes + for (const auto &node : nodes) { + dfs(node.get()); + } + + return result; + } +}; + +// Enhanced region equivalence check using dependency graphs bool areRegionsEquivalent(Region &first, Region &second) { - // Compare number of blocks - if (first.getBlocks().size() != second.getBlocks().size()) + // Fast early checks before expensive graph construction + + // Check number of blocks + if (first.getBlocks().size() != second.getBlocks().size()) { return false; - - // Compare corresponding blocks + } + + // Check each block's basic properties for (auto blockPair : llvm::zip(first.getBlocks(), second.getBlocks())) { Block &firstBlock = std::get<0>(blockPair); Block &secondBlock = std::get<1>(blockPair); - - // Compare number of arguments - if (firstBlock.getNumArguments() != secondBlock.getNumArguments()) + + // Check number of arguments + if (firstBlock.getNumArguments() != secondBlock.getNumArguments()) { return false; - - // Compare argument types - for (auto argPair : llvm::zip(firstBlock.getArguments(), - secondBlock.getArguments())) { - if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) + } + + // Check argument types + for (auto argPair : llvm::zip(firstBlock.getArguments(), secondBlock.getArguments())) { + if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) { return false; + } } - - // Compare operations (simplified - real implementation would be more complex) - if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) + + // Check number of operations + if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) { return false; - - // For a full implementation, you'd need more sophisticated operation comparison - // based on operands, attributes, and result types + } } - + + // If basic checks pass, proceed with detailed graph-based analysis + // Build dependency graphs for both regions + DependencyGraph firstGraph, secondGraph; + firstGraph.buildFromRegion(first); + secondGraph.buildFromRegion(second); + + // Quick structural checks + if (firstGraph.nodes.size() != secondGraph.nodes.size()) { + return false; + } + + if (firstGraph.blockArgNodes.size() != secondGraph.blockArgNodes.size()) { + return false; + } + + // Get topological orderings + auto firstOrder = firstGraph.getTopologicalOrder(); + auto secondOrder = secondGraph.getTopologicalOrder(); + + if 
(firstOrder.size() != secondOrder.size()) { + return false; + } + + // Compare nodes in topological order + DenseMap nodeMapping; + + for (size_t i = 0; i < firstOrder.size(); ++i) { + OpNode *firstNode = firstOrder[i]; + OpNode *secondNode = secondOrder[i]; + + // Check if the nodes are structurally equivalent + if (!firstNode->isEquivalentTo(*secondNode)) { + return false; + } + + // Check if dependency structure matches + if (firstNode->dependencies.size() != secondNode->dependencies.size()) { + return false; + } + + // Verify that dependencies map correctly + for (size_t j = 0; j < firstNode->dependencies.size(); ++j) { + OpNode *firstDep = firstNode->dependencies[j]; + OpNode *secondDep = secondNode->dependencies[j]; + + // Check if we've established a mapping for these dependencies + auto it = nodeMapping.find(firstDep); + if (it != nodeMapping.end()) { + if (it->second != secondDep) { + return false; // Inconsistent mapping + } + } else { + nodeMapping[firstDep] = secondDep; + } + } + + // Establish mapping for current nodes + nodeMapping[firstNode] = secondNode; + } + return true; } diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 2a8eada21811..d653d835ab45 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -63,6 +64,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); From ca12291beccb6564bc28411a743d96cd71a3ca46 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Sun, 3 Aug 2025 16:16:55 -0700 Subject: [PATCH 74/77] Runtime failure but match happening correctly to kernel dialect --- generic_solver/kernel_library.mlir | 58 ++++++ include/polygeist/Passes/Passes.td | 1 + lib/polygeist/Passes/LinalgToKernel.cpp | 240 +++++++++++++++++++++--- 3 files changed, 277 insertions(+), 22 deletions(-) diff --git a/generic_solver/kernel_library.mlir b/generic_solver/kernel_library.mlir index dad0c3c7d68e..033f3958ecd8 100644 --- a/generic_solver/kernel_library.mlir +++ b/generic_solver/kernel_library.mlir @@ -48,6 +48,64 @@ module { kernel.yield %result : tensor } + // Alpha-scaled GEMM accumulation (matches the second operation in the user's pattern) + kernel.defn @alpha_gemm_accumulate(%A: tensor, %B: tensor, %C: tensor, %alpha: f64) -> tensor { + // Matrix multiplication with alpha scaling: C += alpha * A * B + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %6 = arith.mulf %alpha, %in : f64 + %7 = arith.mulf %6, %in_0 : f64 + %8 = arith.addf %out, %7 : f64 + linalg.yield %8 : f64 + } -> tensor + kernel.yield %result : tensor + } + + // Element-wise beta scaling (matches the first operation in the user's pattern) + kernel.defn @beta_scale(%C: tensor, %beta: f64) -> tensor { + // Element-wise scaling: C = beta * C + %result = linalg.generic { + indexing_maps = [ + 
affine_map<(d0, d1) -> (d1, d0)> + ], + iterator_types = ["parallel", "parallel"] + } outs(%C : tensor) { + ^bb0(%out: f64): + %6 = arith.mulf %out, %beta : f64 + linalg.yield %6 : f64 + } -> tensor + kernel.yield %result : tensor + } + + // Matrix multiplication with alpha scaling (second operation standalone) + kernel.defn @gemm_alpha_only(%A: tensor, %B: tensor, %C: tensor, %alpha: f64) -> tensor { + // Matrix multiplication: C += alpha * A * B + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d1, d0)>, + affine_map<(d0, d1, d2) -> (d2, d0)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %6 = arith.mulf %alpha, %in : f64 + %7 = arith.mulf %6, %in_0 : f64 + %8 = arith.addf %out, %7 : f64 + linalg.yield %8 : f64 + } -> tensor + kernel.yield %result : tensor + } + // Sum of absolute values operation (ASUM) kernel.defn @asum_linalg(%X: tensor, %init: tensor) -> tensor { // Sum of absolute values: result = sum_i |x_i| diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index eef142f6dbef..368eb59d28ab 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -303,6 +303,7 @@ def LinalgToKernel : Pass<"linalg-to-kernel", "mlir::ModuleOp"> { "polygeist::kernel::KernelDialect", "tensor::TensorDialect", "arith::ArithDialect", + "bufferization::BufferizationDialect", ]; let options = [ Option<"kernelLibraryPath", "kernel-library-path", "std::string", diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp index f891fac9cacd..5bf05e87d8fc 100644 --- a/lib/polygeist/Passes/LinalgToKernel.cpp +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -152,7 +152,12 @@ struct DependencyGraph { }; // Enhanced region equivalence check using dependency graphs -bool areRegionsEquivalent(Region &first, Region &second) { +bool areRegionsEquivalent(Region &first, Region &second, DenseMap &nodeMapping, + DenseMap &operationMapping) { + // Clear the output mappings + nodeMapping.clear(); + operationMapping.clear(); + // Fast early checks before expensive graph construction // Check number of blocks @@ -206,9 +211,7 @@ bool areRegionsEquivalent(Region &first, Region &second) { return false; } - // Compare nodes in topological order - DenseMap nodeMapping; - + // Compare nodes in topological order and build mapping for (size_t i = 0; i < firstOrder.size(); ++i) { OpNode *firstNode = firstOrder[i]; OpNode *secondNode = secondOrder[i]; @@ -241,8 +244,13 @@ bool areRegionsEquivalent(Region &first, Region &second) { // Establish mapping for current nodes nodeMapping[firstNode] = secondNode; + + // Build the operation mapping directly from OpNode data while still valid + if (firstNode->op && secondNode->op) { + operationMapping[firstNode->op] = secondNode->op; + } } - + return true; } @@ -278,8 +286,153 @@ bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) { return true; } +// Helper function to find the corresponding value in actual IR for a kernel block argument +Value findCorrespondingValue(BlockArgument kernelArg, + const DenseMap &operationMapping, + GenericOp genericOp) { + + llvm::errs() << "DEBUG: Finding corresponding value for kernel arg #" << kernelArg.getArgNumber() + << " with type " << kernelArg.getType() << "\n"; + + // First, check if this kernel argument is used as an operand to the linalg.generic 
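  // (Editorial example of the mapping this helper establishes, using the
  //  @gemm_alpha_only entry added to kernel_library.mlir above.) Kernel args
  //  %A, %B, %C are operands of the defn's linalg.generic, so they map
  //  positionally onto the matched generic's ins/outs; the scalar %alpha is
  //  only used inside the region (by an arith.mulf), so it is recovered via
  //  operationMapping from the corresponding arith.mulf in the actual IR.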
itself + // This handles block arguments that become ins/outs operands + for (Operation *kernelUser : kernelArg.getUsers()) { + llvm::errs() << "DEBUG: Kernel arg used by: " << *kernelUser << "\n"; + + // Check if the user is a linalg.generic operation + if (auto kernelGeneric = dyn_cast(kernelUser)) { + llvm::errs() << "DEBUG: Kernel arg is used by linalg.generic as operand\n"; + + // Find which operand position kernelArg occupies in the kernel's linalg.generic + size_t operandIndex = 0; + for (Value operand : kernelGeneric->getOperands()) { + if (operand == kernelArg) { + llvm::errs() << "DEBUG: Kernel arg is at operand index " << operandIndex + << " of kernel linalg.generic\n"; + + // The corresponding operand in the actual linalg.generic should be at the same position + if (operandIndex < genericOp->getNumOperands()) { + Value actualOperand = genericOp->getOperand(operandIndex); + llvm::errs() << "DEBUG: Found corresponding actual operand: " << actualOperand << "\n"; + return actualOperand; + } else { + llvm::errs() << "DEBUG: ERROR - operand index out of bounds in actual generic\n"; + } + break; + } + operandIndex++; + } + } else { + // This is the original logic for operations inside the region + // Find the corresponding operation in actual IR using reverse mapping + auto it = std::find_if(operationMapping.begin(), operationMapping.end(), + [kernelUser](const auto& pair) { + return pair.second == kernelUser; + }); + + if (it != operationMapping.end()) { + Operation *actualUser = it->first; // The actual IR operation + llvm::errs() << "DEBUG: Found corresponding actual operation: " << *actualUser << "\n"; + + // Find which operand position kernelArg occupies in kernelUser + size_t operandIndex = 0; + for (Value operand : kernelUser->getOperands()) { + if (operand == kernelArg) { + break; + } + operandIndex++; + } + + llvm::errs() << "DEBUG: Kernel arg is at operand index " << operandIndex << "\n"; + + // Ensure we don't go out of bounds + if (operandIndex < actualUser->getNumOperands()) { + // Get the corresponding operand from actual IR + Value actualOperand = actualUser->getOperand(operandIndex); + llvm::errs() << "DEBUG: Found corresponding actual operand: " << actualOperand << "\n"; + return actualOperand; + } else { + llvm::errs() << "DEBUG: ERROR - operand index out of bounds\n"; + } + } else { + llvm::errs() << "DEBUG: Could not find corresponding operation in operationMapping\n"; + } + } + } + + // If we reach here, this might be a scalar argument used inside the region + // For scalar arguments like %arg3, %arg4, use operation mapping to trace usage + llvm::errs() << "DEBUG: Checking if kernel arg is a scalar used inside region\n"; + + for (Operation *kernelUser : kernelArg.getUsers()) { + // Skip if this is the linalg.generic itself (already handled above) + if (isa(kernelUser)) continue; + + llvm::errs() << "DEBUG: Kernel arg used by operation: " << *kernelUser << "\n"; + + // Find the corresponding operation in actual IR using the fixed mapping + // Note: operationMapping is actualOp -> kernelOp, so we need to reverse-search + auto it = std::find_if(operationMapping.begin(), operationMapping.end(), + [kernelUser](const auto& pair) { + return pair.second == kernelUser; + }); + if (it != operationMapping.end()) { + Operation *actualUser = it->first; // The actual IR operation + llvm::errs() << "DEBUG: Found corresponding actual operation: " << *actualUser << "\n"; + + // Find which operand position kernelArg occupies in kernelUser + size_t operandIndex = 0; + for (Value 
operand : kernelUser->getOperands()) { + if (operand == kernelArg) { + llvm::errs() << "DEBUG: Kernel arg is at operand index " << operandIndex << "\n"; + + // Get the corresponding operand from actual IR + if (operandIndex < actualUser->getNumOperands()) { + Value actualOperand = actualUser->getOperand(operandIndex); + llvm::errs() << "DEBUG: Found corresponding actual operand: " << actualOperand << "\n"; + return actualOperand; + } else { + llvm::errs() << "DEBUG: ERROR - operand index out of bounds\n"; + } + break; + } + operandIndex++; + } + } else { + llvm::errs() << "DEBUG: Could not find corresponding operation in operationMapping\n"; + } + } + + // Fallback: if operation mapping fails, try type matching as last resort + llvm::errs() << "DEBUG: Fallback to type matching for function arguments\n"; + + auto func = genericOp->getParentOfType(); + if (func) { + llvm::errs() << "DEBUG: Found parent function with " << func.getNumArguments() << " arguments\n"; + + // Look for function arguments with matching type + for (auto funcArg : func.getArguments()) { + if (funcArg.getType() == kernelArg.getType()) { + llvm::errs() << "DEBUG: Found function argument with matching type: " << funcArg << "\n"; + // TODO: This is still not ideal - should be improved with better analysis + return funcArg; + } + } + } + + llvm::errs() << "DEBUG: ERROR - Could not find corresponding value for kernel arg\n"; + return nullptr; +} + +// Structure to hold the result of matching a generic operation with a kernel definition +struct KernelMatchResult { + StringRef kernelName; + DenseMap operationMapping; // actual op -> kernel op + kernel::DefnOp matchedDefnOp; +}; + // Check if a linalg.generic operation matches a kernel.defn in a collection -FailureOr matchGenericWithDefn( +FailureOr matchGenericWithDefn( GenericOp genericOp, kernel::DefnCollectionOp collectionOp) { @@ -291,6 +444,8 @@ FailureOr matchGenericWithDefn( // Variables to capture the match result StringRef matchedOpName; + DenseMap matchedOperationMapping; + kernel::DefnOp matchedDefnOp; SmallVector defnOps; @@ -308,6 +463,8 @@ FailureOr matchGenericWithDefn( for (auto defnOp : defnOps) { StringRef opName = defnOp.getSymName(); + llvm::errs() << "DEBUG: Checking kernel defn: " << opName << "\n"; + // Check for linalg.generic in the defn's body GenericOp candidateOp; @@ -316,21 +473,43 @@ FailureOr matchGenericWithDefn( }); if(!candidateOp) { + llvm::errs() << "DEBUG: No linalg.generic found in defn " << opName << "\n"; continue; } + llvm::errs() << "DEBUG: Found linalg.generic in defn " << opName << "\n"; + llvm::errs() << "DEBUG: Candidate numInputs=" << candidateOp.getNumDpsInputs() + << ", target numInputs=" << numInputs << "\n"; + llvm::errs() << "DEBUG: Candidate numOutputs=" << candidateOp.getNumDpsInits() + << ", target numOutputs=" << numOutputs << "\n"; + // Check if this linalg.generic matches our target + DenseMap nodeMapping; + DenseMap operationMapping; // Added for findCorrespondingValue if (candidateOp.getNumDpsInputs() == numInputs && candidateOp.getNumDpsInits() == numOutputs && areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) && areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) && - areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) { + areRegionsEquivalent(genericOp.getRegion(), candidateOp.getRegion(), nodeMapping, operationMapping)) { + llvm::errs() << "DEBUG: MATCH FOUND for defn " << opName << "\n"; foundMatch = true; matchedOpName = opName; + 
matchedOperationMapping = operationMapping; // Store the mapping + matchedDefnOp = defnOp; // Store the matched defnOp + } else { + llvm::errs() << "DEBUG: No match for defn " << opName << "\n"; + llvm::errs() << "DEBUG: Input/output check: " + << (candidateOp.getNumDpsInputs() == numInputs) << "\n"; + llvm::errs() << "DEBUG: Maps check: " + << areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) << "\n"; + llvm::errs() << "DEBUG: Iterator types check: " + << areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) << "\n"; + llvm::errs() << "DEBUG: Regions check: " + << areRegionsEquivalent(genericOp.getRegion(), candidateOp.getRegion(), nodeMapping, operationMapping) << "\n"; } if (foundMatch) { - return matchedOpName; + return KernelMatchResult{matchedOpName, matchedOperationMapping, matchedDefnOp}; } } @@ -347,31 +526,30 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { + llvm::errs() << "DEBUG: matchAndRewrite called for genericOp:\n"; + llvm::errs() << genericOp << "\n"; + auto module = genericOp->getParentOfType(); //Check if the parent of the generic op is a kernel.defn if (auto parentOp = genericOp->getParentOp()) { if (isa(parentOp)) { + llvm::errs() << "DEBUG: Skipping genericOp inside kernel.defn\n"; return failure(); } } // Try to match with a defn in the collection auto matchResult = matchGenericWithDefn(genericOp, collectionOp); - if (failed(matchResult)) + if (failed(matchResult)) { + llvm::errs() << "DEBUG: No match found in collection\n"; return failure(); + } - StringRef opName = *matchResult; + StringRef opName = matchResult->kernelName; + llvm::errs() << "DEBUG: Match found with kernel: " << opName << "\n"; // Find the matched kernel.defn operation - kernel::DefnOp matchedDefnOp; - // Use const_cast to work around the const issue - const_cast(collectionOp).walk([&](kernel::DefnOp defnOp) { - if (defnOp.getSymName() == opName) { - matchedDefnOp = defnOp; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); + kernel::DefnOp matchedDefnOp = matchResult->matchedDefnOp; if (!matchedDefnOp) { return failure(); @@ -404,10 +582,28 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { // Set insertion point to the genericOp location rewriter.setInsertionPoint(genericOp); - // Get operands from the generic operation (inputs and outputs) + // Get the kernel function signature to map all arguments + Block &kernelBlock = matchedDefnOp.getRegion().front(); + auto kernelArgs = kernelBlock.getArguments(); + + // Use the operationMapping from the match result (no need to call areRegionsEquivalent again) + const DenseMap &operationMapping = matchResult->operationMapping; + + // Use unified approach: map ALL kernel arguments to their corresponding actual values SmallVector operands; - operands.append(genericOp.getInputs().begin(), genericOp.getInputs().end()); - operands.append(genericOp.getOutputs().begin(), genericOp.getOutputs().end()); + llvm::errs() << "DEBUG: Starting to map " << kernelArgs.size() << " kernel arguments\n"; + + for (BlockArgument kernelArg : kernelArgs) { + Value actualValue = findCorrespondingValue(kernelArg, operationMapping, genericOp); + if (!actualValue) { + llvm::errs() << "DEBUG: Failed to find corresponding value for kernel arg #" + << kernelArg.getArgNumber() << " - returning failure\n"; + return failure(); // Could not find corresponding value + } + 
operands.push_back(actualValue); + } + + llvm::errs() << "DEBUG: Successfully mapped all kernel arguments, creating kernel.launch\n"; // Get result types from the generic operation TypeRange resultTypes = genericOp.getResultTypes(); From 7c204f28a5ae17bd9787059a25ddd43db7635ddf Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Sun, 3 Aug 2025 16:22:32 -0700 Subject: [PATCH 75/77] Working match for linalg kernel match for gemm --- lib/polygeist/Passes/LinalgToKernel.cpp | 60 ++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp index 5bf05e87d8fc..929013ce10de 100644 --- a/lib/polygeist/Passes/LinalgToKernel.cpp +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -605,19 +605,67 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { llvm::errs() << "DEBUG: Successfully mapped all kernel arguments, creating kernel.launch\n"; + // Get kernel function signature types for casting + auto kernelFuncType = matchedDefnOp.getFunctionType(); + auto kernelInputTypes = kernelFuncType.getInputs(); + auto kernelResultTypes = kernelFuncType.getResults(); + + // Cast operands to match kernel signature types if needed + SmallVector castedOperands; + for (size_t i = 0; i < operands.size(); ++i) { + Value operand = operands[i]; + Type expectedType = (i < kernelInputTypes.size()) ? kernelInputTypes[i] : operand.getType(); + + if (operand.getType() != expectedType) { + // Insert tensor.cast for type conversion + if (isa(operand.getType()) && isa(expectedType)) { + llvm::errs() << "DEBUG: Casting operand " << i << " from " << operand.getType() + << " to " << expectedType << "\n"; + auto castOp = rewriter.create(loc, expectedType, operand); + castedOperands.push_back(castOp.getResult()); + } else { + // For non-tensor types, use the operand as-is + castedOperands.push_back(operand); + } + } else { + castedOperands.push_back(operand); + } + } + // Get result types from the generic operation - TypeRange resultTypes = genericOp.getResultTypes(); + TypeRange originalResultTypes = genericOp.getResultTypes(); - // Create the kernel.launch operation + // Create the kernel.launch operation with casted operands and kernel result types auto launchOp = rewriter.create( loc, - resultTypes, + kernelResultTypes, // Use kernel result types for the launch op opName, - operands + castedOperands // Use casted operands ); - // Replace the generic operation with the launch operation - rewriter.replaceOp(genericOp, launchOp.getResults()); + // Cast results back to original types if needed + SmallVector finalResults; + for (size_t i = 0; i < launchOp.getResults().size(); ++i) { + Value result = launchOp.getResult(i); + Type originalType = (i < originalResultTypes.size()) ? 
originalResultTypes[i] : result.getType(); + + if (result.getType() != originalType) { + // Insert tensor.cast to convert back to original type + if (isa(result.getType()) && isa(originalType)) { + llvm::errs() << "DEBUG: Casting result " << i << " from " << result.getType() + << " to " << originalType << "\n"; + auto castOp = rewriter.create(loc, originalType, result); + finalResults.push_back(castOp.getResult()); + } else { + finalResults.push_back(result); + } + } else { + finalResults.push_back(result); + } + } + + // Replace the generic operation with the final results + rewriter.replaceOp(genericOp, finalResults); return success(); } From 37dd847dcb55c0a0991df4c0b9591acaa80253c4 Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Sun, 3 Aug 2025 16:32:30 -0700 Subject: [PATCH 76/77] Added debug prints --- lib/polygeist/Passes/LinalgToKernel.cpp | 143 ++++++++++-------------- 1 file changed, 57 insertions(+), 86 deletions(-) diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp index 929013ce10de..3563c0ae4731 100644 --- a/lib/polygeist/Passes/LinalgToKernel.cpp +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/Debug.h" #include "polygeist/Kernel/KernelDialect.h" #include "polygeist/Kernel/KernelOps.h" #include "polygeist/Passes/Passes.h" @@ -28,6 +29,8 @@ #include #include +#define DEBUG_TYPE "linalg-to-kernel" + using namespace mlir; using namespace mlir::linalg; using namespace mlir::polygeist; @@ -291,84 +294,52 @@ Value findCorrespondingValue(BlockArgument kernelArg, const DenseMap &operationMapping, GenericOp genericOp) { - llvm::errs() << "DEBUG: Finding corresponding value for kernel arg #" << kernelArg.getArgNumber() - << " with type " << kernelArg.getType() << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Finding corresponding value for kernel arg #" << kernelArg.getArgNumber() + << " with type " << kernelArg.getType() << "\n"); // First, check if this kernel argument is used as an operand to the linalg.generic itself - // This handles block arguments that become ins/outs operands + // This handles tensor arguments that become ins/outs operands for (Operation *kernelUser : kernelArg.getUsers()) { - llvm::errs() << "DEBUG: Kernel arg used by: " << *kernelUser << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Kernel arg used by: " << *kernelUser << "\n"); // Check if the user is a linalg.generic operation if (auto kernelGeneric = dyn_cast(kernelUser)) { - llvm::errs() << "DEBUG: Kernel arg is used by linalg.generic as operand\n"; + LLVM_DEBUG(llvm::dbgs() << "Kernel arg is used by linalg.generic as operand\n"); // Find which operand position kernelArg occupies in the kernel's linalg.generic size_t operandIndex = 0; for (Value operand : kernelGeneric->getOperands()) { if (operand == kernelArg) { - llvm::errs() << "DEBUG: Kernel arg is at operand index " << operandIndex - << " of kernel linalg.generic\n"; + LLVM_DEBUG(llvm::dbgs() << "Kernel arg is at operand index " << operandIndex + << " of kernel linalg.generic\n"); // The corresponding operand in the actual linalg.generic should be at the same position if (operandIndex < genericOp->getNumOperands()) { Value actualOperand = genericOp->getOperand(operandIndex); - llvm::errs() << "DEBUG: Found corresponding actual operand: " << actualOperand << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Found corresponding actual operand: " << actualOperand << "\n"); return actualOperand; 
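        // The positional lookup above assumes the i-th ins/outs operand of the
        // kernel.defn's linalg.generic corresponds to the i-th operand of the user's
        // linalg.generic; the earlier structural match (operand counts, indexing maps,
        // iterator types, region equivalence) is what makes that assumption reasonable.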
} else { - llvm::errs() << "DEBUG: ERROR - operand index out of bounds in actual generic\n"; + LLVM_DEBUG(llvm::dbgs() << "ERROR - operand index out of bounds in actual generic\n"); } break; } operandIndex++; } - } else { - // This is the original logic for operations inside the region - // Find the corresponding operation in actual IR using reverse mapping - auto it = std::find_if(operationMapping.begin(), operationMapping.end(), - [kernelUser](const auto& pair) { - return pair.second == kernelUser; - }); - if (it != operationMapping.end()) { - Operation *actualUser = it->first; // The actual IR operation - llvm::errs() << "DEBUG: Found corresponding actual operation: " << *actualUser << "\n"; - - // Find which operand position kernelArg occupies in kernelUser - size_t operandIndex = 0; - for (Value operand : kernelUser->getOperands()) { - if (operand == kernelArg) { - break; - } - operandIndex++; - } - - llvm::errs() << "DEBUG: Kernel arg is at operand index " << operandIndex << "\n"; - - // Ensure we don't go out of bounds - if (operandIndex < actualUser->getNumOperands()) { - // Get the corresponding operand from actual IR - Value actualOperand = actualUser->getOperand(operandIndex); - llvm::errs() << "DEBUG: Found corresponding actual operand: " << actualOperand << "\n"; - return actualOperand; - } else { - llvm::errs() << "DEBUG: ERROR - operand index out of bounds\n"; - } - } else { - llvm::errs() << "DEBUG: Could not find corresponding operation in operationMapping\n"; - } + // If we found a linalg.generic usage, we're done with this user + break; } } // If we reach here, this might be a scalar argument used inside the region // For scalar arguments like %arg3, %arg4, use operation mapping to trace usage - llvm::errs() << "DEBUG: Checking if kernel arg is a scalar used inside region\n"; + LLVM_DEBUG(llvm::dbgs() << "Checking if kernel arg is a scalar used inside region\n"); for (Operation *kernelUser : kernelArg.getUsers()) { // Skip if this is the linalg.generic itself (already handled above) if (isa(kernelUser)) continue; - llvm::errs() << "DEBUG: Kernel arg used by operation: " << *kernelUser << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Kernel arg used by operation: " << *kernelUser << "\n"); // Find the corresponding operation in actual IR using the fixed mapping // Note: operationMapping is actualOp -> kernelOp, so we need to reverse-search @@ -378,49 +349,49 @@ Value findCorrespondingValue(BlockArgument kernelArg, }); if (it != operationMapping.end()) { Operation *actualUser = it->first; // The actual IR operation - llvm::errs() << "DEBUG: Found corresponding actual operation: " << *actualUser << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Found corresponding actual operation: " << *actualUser << "\n"); // Find which operand position kernelArg occupies in kernelUser size_t operandIndex = 0; for (Value operand : kernelUser->getOperands()) { if (operand == kernelArg) { - llvm::errs() << "DEBUG: Kernel arg is at operand index " << operandIndex << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Kernel arg is at operand index " << operandIndex << "\n"); // Get the corresponding operand from actual IR if (operandIndex < actualUser->getNumOperands()) { Value actualOperand = actualUser->getOperand(operandIndex); - llvm::errs() << "DEBUG: Found corresponding actual operand: " << actualOperand << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Found corresponding actual operand: " << actualOperand << "\n"); return actualOperand; } else { - llvm::errs() << "DEBUG: ERROR - operand index out of bounds\n"; + 
LLVM_DEBUG(llvm::dbgs() << "ERROR - operand index out of bounds\n"); } break; } operandIndex++; } } else { - llvm::errs() << "DEBUG: Could not find corresponding operation in operationMapping\n"; + LLVM_DEBUG(llvm::dbgs() << "Could not find corresponding operation in operationMapping\n"); } } // Fallback: if operation mapping fails, try type matching as last resort - llvm::errs() << "DEBUG: Fallback to type matching for function arguments\n"; + LLVM_DEBUG(llvm::dbgs() << "Fallback to type matching for function arguments\n"); auto func = genericOp->getParentOfType(); if (func) { - llvm::errs() << "DEBUG: Found parent function with " << func.getNumArguments() << " arguments\n"; + LLVM_DEBUG(llvm::dbgs() << "Found parent function with " << func.getNumArguments() << " arguments\n"); // Look for function arguments with matching type for (auto funcArg : func.getArguments()) { if (funcArg.getType() == kernelArg.getType()) { - llvm::errs() << "DEBUG: Found function argument with matching type: " << funcArg << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Found function argument with matching type: " << funcArg << "\n"); // TODO: This is still not ideal - should be improved with better analysis return funcArg; } } } - llvm::errs() << "DEBUG: ERROR - Could not find corresponding value for kernel arg\n"; + LLVM_DEBUG(llvm::dbgs() << "ERROR - Could not find corresponding value for kernel arg\n"); return nullptr; } @@ -463,7 +434,7 @@ FailureOr matchGenericWithDefn( for (auto defnOp : defnOps) { StringRef opName = defnOp.getSymName(); - llvm::errs() << "DEBUG: Checking kernel defn: " << opName << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Checking kernel defn: " << opName << "\n"); // Check for linalg.generic in the defn's body GenericOp candidateOp; @@ -473,15 +444,15 @@ FailureOr matchGenericWithDefn( }); if(!candidateOp) { - llvm::errs() << "DEBUG: No linalg.generic found in defn " << opName << "\n"; + LLVM_DEBUG(llvm::dbgs() << "No linalg.generic found in defn " << opName << "\n"); continue; } - llvm::errs() << "DEBUG: Found linalg.generic in defn " << opName << "\n"; - llvm::errs() << "DEBUG: Candidate numInputs=" << candidateOp.getNumDpsInputs() - << ", target numInputs=" << numInputs << "\n"; - llvm::errs() << "DEBUG: Candidate numOutputs=" << candidateOp.getNumDpsInits() - << ", target numOutputs=" << numOutputs << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Found linalg.generic in defn " << opName << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Candidate numInputs=" << candidateOp.getNumDpsInputs() + << ", target numInputs=" << numInputs << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Candidate numOutputs=" << candidateOp.getNumDpsInits() + << ", target numOutputs=" << numOutputs << "\n"); // Check if this linalg.generic matches our target DenseMap nodeMapping; @@ -491,21 +462,21 @@ FailureOr matchGenericWithDefn( areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) && areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) && areRegionsEquivalent(genericOp.getRegion(), candidateOp.getRegion(), nodeMapping, operationMapping)) { - llvm::errs() << "DEBUG: MATCH FOUND for defn " << opName << "\n"; + LLVM_DEBUG(llvm::dbgs() << "MATCH FOUND for defn " << opName << "\n"); foundMatch = true; matchedOpName = opName; - matchedOperationMapping = operationMapping; // Store the mapping + matchedOperationMapping = operationMapping; // Store the operation mapping matchedDefnOp = defnOp; // Store the matched defnOp } else { - llvm::errs() << "DEBUG: No match for defn " << opName << "\n"; - llvm::errs() << 
"DEBUG: Input/output check: " - << (candidateOp.getNumDpsInputs() == numInputs) << "\n"; - llvm::errs() << "DEBUG: Maps check: " - << areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) << "\n"; - llvm::errs() << "DEBUG: Iterator types check: " - << areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) << "\n"; - llvm::errs() << "DEBUG: Regions check: " - << areRegionsEquivalent(genericOp.getRegion(), candidateOp.getRegion(), nodeMapping, operationMapping) << "\n"; + LLVM_DEBUG(llvm::dbgs() << "No match for defn " << opName << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Input/output check: " + << (candidateOp.getNumDpsInputs() == numInputs) << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Maps check: " + << areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Iterator types check: " + << areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Regions check: " + << areRegionsEquivalent(genericOp.getRegion(), candidateOp.getRegion(), nodeMapping, operationMapping) << "\n"); } if (foundMatch) { @@ -526,14 +497,14 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - llvm::errs() << "DEBUG: matchAndRewrite called for genericOp:\n"; - llvm::errs() << genericOp << "\n"; + LLVM_DEBUG(llvm::dbgs() << "matchAndRewrite called for genericOp:\n"); + LLVM_DEBUG(llvm::dbgs() << genericOp << "\n"); auto module = genericOp->getParentOfType(); //Check if the parent of the generic op is a kernel.defn if (auto parentOp = genericOp->getParentOp()) { if (isa(parentOp)) { - llvm::errs() << "DEBUG: Skipping genericOp inside kernel.defn\n"; + LLVM_DEBUG(llvm::dbgs() << "Skipping genericOp inside kernel.defn\n"); return failure(); } } @@ -541,12 +512,12 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { // Try to match with a defn in the collection auto matchResult = matchGenericWithDefn(genericOp, collectionOp); if (failed(matchResult)) { - llvm::errs() << "DEBUG: No match found in collection\n"; + LLVM_DEBUG(llvm::dbgs() << "No match found in collection\n"); return failure(); } StringRef opName = matchResult->kernelName; - llvm::errs() << "DEBUG: Match found with kernel: " << opName << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Match found with kernel: " << opName << "\n"); // Find the matched kernel.defn operation kernel::DefnOp matchedDefnOp = matchResult->matchedDefnOp; @@ -591,19 +562,19 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { // Use unified approach: map ALL kernel arguments to their corresponding actual values SmallVector operands; - llvm::errs() << "DEBUG: Starting to map " << kernelArgs.size() << " kernel arguments\n"; + LLVM_DEBUG(llvm::dbgs() << "Starting to map " << kernelArgs.size() << " kernel arguments\n"); for (BlockArgument kernelArg : kernelArgs) { Value actualValue = findCorrespondingValue(kernelArg, operationMapping, genericOp); if (!actualValue) { - llvm::errs() << "DEBUG: Failed to find corresponding value for kernel arg #" - << kernelArg.getArgNumber() << " - returning failure\n"; + LLVM_DEBUG(llvm::dbgs() << "Failed to find corresponding value for kernel arg #" + << kernelArg.getArgNumber() << " - returning failure\n"); return failure(); // Could not find corresponding value } operands.push_back(actualValue); } - llvm::errs() << "DEBUG: Successfully mapped all kernel arguments, creating 
kernel.launch\n"; + LLVM_DEBUG(llvm::dbgs() << "Successfully mapped all kernel arguments, creating kernel.launch\n"); // Get kernel function signature types for casting auto kernelFuncType = matchedDefnOp.getFunctionType(); @@ -619,8 +590,8 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { if (operand.getType() != expectedType) { // Insert tensor.cast for type conversion if (isa(operand.getType()) && isa(expectedType)) { - llvm::errs() << "DEBUG: Casting operand " << i << " from " << operand.getType() - << " to " << expectedType << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Casting operand " << i << " from " << operand.getType() + << " to " << expectedType << "\n"); auto castOp = rewriter.create(loc, expectedType, operand); castedOperands.push_back(castOp.getResult()); } else { @@ -652,8 +623,8 @@ class LinalgGenericToKernelPattern : public OpRewritePattern { if (result.getType() != originalType) { // Insert tensor.cast to convert back to original type if (isa(result.getType()) && isa(originalType)) { - llvm::errs() << "DEBUG: Casting result " << i << " from " << result.getType() - << " to " << originalType << "\n"; + LLVM_DEBUG(llvm::dbgs() << "Casting result " << i << " from " << result.getType() + << " to " << originalType << "\n"); auto castOp = rewriter.create(loc, originalType, result); finalResults.push_back(castOp.getResult()); } else { @@ -729,7 +700,7 @@ struct LinalgToKernelPass : public LinalgToKernelBase { // Find the kernel.defn_collection in the external module externalModule->walk([&](kernel::DefnCollectionOp op) { collectionOp = op; - llvm::errs() << "DEBUG: Found kernel.defn_collection in external module\n"; + LLVM_DEBUG(llvm::dbgs() << "Found kernel.defn_collection in external module\n"); return WalkResult::interrupt(); }); From 7e3f0d02cfe5ed3289f544dc22b8195fee5a14fd Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Sun, 3 Aug 2025 21:00:15 -0700 Subject: [PATCH 77/77] Able to raise gemv --- generic_solver/kernel_library.mlir | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/generic_solver/kernel_library.mlir b/generic_solver/kernel_library.mlir index 033f3958ecd8..fd4fd6a48a70 100644 --- a/generic_solver/kernel_library.mlir +++ b/generic_solver/kernel_library.mlir @@ -170,6 +170,26 @@ module { kernel.yield %result : tensor } + // General Matrix-Vector Multiply (GEMV) + kernel.defn @gemv_simple(%A: tensor, %x: tensor, %y: tensor) -> tensor { + // Simple matrix-vector multiplication: y += A * x + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d1, d0)>, // Matrix A[d0, d1] + affine_map<(d0, d1) -> (d0)>, // Vector x[d1] + affine_map<(d0, d1) -> (d1)> // Vector y[d0] + ], + iterator_types = ["parallel", "reduction"] + } ins(%A, %x : tensor, tensor) + outs(%y : tensor) { + ^bb0(%a: f64, %x_val: f64, %y_val: f64): + %product = arith.mulf %a, %x_val : f64 + %result = arith.addf %y_val, %product : f64 + linalg.yield %result : f64 + } -> tensor + kernel.yield %result : tensor + } + // Index of minimum absolute value operation definition with linalg.generic representation kernel.defn @iamin_linalg(%X: tensor, %init: tensor) -> tensor { // Implementation using linalg.generic