diff --git a/generic_solver/CublasDefnPattern.cpp b/generic_solver/CublasDefnPattern.cpp
new file mode 100644
index 000000000000..4a62fb8345da
--- /dev/null
+++ b/generic_solver/CublasDefnPattern.cpp
@@ -0,0 +1,360 @@
+//===- CublasDefnPattern.cpp - Pattern to match linalg.generic with kernel.defn ------===//
+//
+// This file implements a pattern to rewrite linalg.generic operations to kernel
+// operations by matching against patterns defined in kernel.defn_collection.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "KernelOps.h"
+
+#include <set>
+#include <stack>
+#include <utility>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+// Cases:
+// 1. What if they do a*(b+c) as a*b+a*c ?
+// 2. What if they do (a+b)/c as a/c+b/c ?
+//    - The required best form can vary based on a cost model for a given architecture
+//    - The expectation is that kernel.defn is the best form an op is expected to take
+//    - The generic solver will employ heuristics to match the best form
+//    - Heuristics can be as simple as "is the op a commutative operation?",
+//      "is the op an associative operation?", "is the op distributive?", etc.
+// 3. What if the order of operands is different? add(a,b) as add(b,a)
+//    - This requires a commutative check for operations, i.e. in commutative ops
+//      we don't need to match positions
+// 4. What if the order of uses is different for an op? E.g.
+//      a1 = ...    | a2 = ...
+//      b1 = a1/c1  | d2 = a2*c2
+//      d1 = a1*c1  | b2 = a2/c2
+//    - In this case, we need to find the corresponding uses of the operands
+// 5.
+
+// Non-recursive traversal of the use-def chain using an explicit stack
+bool compareUseDefChains(Value firstValue, Value secondValue) {
+  // Use a std::stack to track value pairs we still need to visit
+  std::stack<std::pair<Value, Value>> workList;
+  std::set<std::pair<void *, void *>> visited;
+
+  // Start with the initial values
+  workList.push({firstValue, secondValue});
+
+  while (!workList.empty()) {
+    auto [value1, value2] = workList.top();
+    workList.pop();
+
+    // Skip if we've already processed this pair
+    auto valuePtrPair = std::make_pair(value1.getImpl(), value2.getImpl());
+    if (visited.count(valuePtrPair))
+      continue;
+    visited.insert(valuePtrPair);
+
+    // Compare the values themselves
+    if (value1.getType() != value2.getType())
+      return false;
+
+    // Compare all uses
+    auto uses1 = value1.getUses();
+    auto uses2 = value2.getUses();
+
+    // Process each use
+    for (auto &use1 : uses1) {
+      Operation *op1 = use1.getOwner();
+
+      // Find the corresponding use of the second value
+      bool foundMatch = false;
+      for (auto &use2 : uses2) {
+        Operation *op2 = use2.getOwner();
+
+        // Compare operations (customize based on your definition of equivalence)
+        if (op1->getName() == op2->getName() &&
+            // This requires a commutative check
+            use1.getOperandNumber() == use2.getOperandNumber()) {
+          foundMatch = true;
+
+          // Add results to the worklist to continue the traversal
+          for (unsigned i = 0; i < op1->getNumResults(); ++i) {
+            if (i < op2->getNumResults())
+              workList.push({op1->getResult(i), op2->getResult(i)});
+          }
+          break;
+        }
+      }
+
+      if (!foundMatch)
+        return false;
+    }
+  }
+
+  return true;
+}
+
+// Helper function to check if two regions are structurally equivalent
+bool areRegionsEquivalent(Region &first, Region &second) {
+  // Compare number of blocks
+  if (first.getBlocks().size() != second.getBlocks().size())
+    return false;
+
+  // Compare corresponding blocks
+  for (auto blockPair : llvm::zip(first.getBlocks(), second.getBlocks())) {
+    Block &firstBlock = std::get<0>(blockPair);
+    Block &secondBlock = std::get<1>(blockPair);
+
+    // Compare number of arguments
+    if (firstBlock.getNumArguments() != secondBlock.getNumArguments())
+      return false;
+
+    //// Compare argument types
+    //for (auto argPair : llvm::zip(firstBlock.getArguments(),
+    //                              secondBlock.getArguments())) {
+    //  if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType())
+    //    return false;
+    //}
+
+    // Traverse the use-def chains of corresponding arguments and compare the
+    // names of the operations along them.
+    for (auto argPair : llvm::zip(firstBlock.getArguments(),
+                                  secondBlock.getArguments())) {
+      if (!compareUseDefChains(std::get<0>(argPair), std::get<1>(argPair)))
+        return false;
+    }
+
+    //// Compare operations (simplified - real implementation would be more complex)
+    //if (firstBlock.getOperations().size() != secondBlock.getOperations().size())
+    //  return false;
+
+    //// For a full implementation, you'd need more sophisticated operation comparison
+    //// based on operands, attributes, and result types
+  }
+
+  return true;
+}
+
+// Helper to check if indexing maps are equivalent
+bool areIndexingMapsEquivalent(ArrayAttr firstMaps, ArrayAttr secondMaps) {
+  if (firstMaps.size() != secondMaps.size())
+    return false;
+
+  for (auto mapPair : llvm::zip(firstMaps, secondMaps)) {
+    auto firstMap = std::get<0>(mapPair).cast<AffineMapAttr>().getValue();
+    auto secondMap = std::get<1>(mapPair).cast<AffineMapAttr>().getValue();
+
+    if (firstMap != secondMap)
+      return false;
+  }
+
+  return true;
+}
+
+// Helper to check if iterator types are equivalent
+bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) {
+  if (firstTypes.size() != secondTypes.size())
+    return false;
+
+  for (auto typePair : llvm::zip(firstTypes, secondTypes)) {
+    auto firstType = std::get<0>(typePair).cast<linalg::IteratorTypeAttr>().getValue();
+    auto secondType = std::get<1>(typePair).cast<linalg::IteratorTypeAttr>().getValue();
+
+    if (firstType != secondType)
+      return false;
+  }
+
+  return true;
+}
+
+// Number of orderings of a set of indexing maps (factorial of the count);
+// intended for matching ops whose indexing maps appear in a different order
+// (see the TODOs below).
+static int calculateCombinations(int numOfPos) {
+  // Calculate factorial of numOfPos
+  int result = 1;
+  for (int i = 1; i <= numOfPos; i++)
+    result *= i;
+  return result;
+}
+
+// Check if a linalg.generic operation matches a kernel.defn in a collection
+FailureOr<std::string> matchGenericWithDefn(
+    GenericOp genericOp,
+    kernel::DefnCollectionOp collectionOp) {
+
+  // Get attributes from the generic operation
+  ArrayAttr indexingMaps = genericOp.getIndexingMapsAttr();
+  ArrayAttr iteratorTypes = genericOp.getIteratorTypesAttr();
+  unsigned numInputs = genericOp.getNumDpsInputs();
+  unsigned numOutputs = genericOp.getNumDpsInits();
+
+  // Walk through each defn in the collection
+  for (Operation &op : collectionOp.getDefns().getOps()) {
+    auto defnOp = cast<kernel::DefnOp>(op);
+    StringAttr opName = defnOp.getNameAttr();
+
+    // Check for linalg.generic in the defn's body
+    bool foundMatch = false;
+    defnOp.getBody().walk([&](GenericOp candidateOp) {
+      // Skip if already found a match
+      if (foundMatch)
+        return;
+
+      // Check if this linalg.generic matches our target
+      if (candidateOp.getNumDpsInputs() == numInputs &&
+          candidateOp.getNumDpsInits() == numOutputs &&
+          //DONE: Generalize to a single dialect, with no special ops
+          //TODO: Indexing maps and orders might differ; the number of candidate
+          //      orderings is calculateCombinations(indexingMaps.size())
+          //TODO: More complex case - where extra loops exist around the ops we have
+          //TODO: Custom cost model ?
+          //TODO: Constants might require special handling such as bounds
+          //IDEA: Descheduling / removing tiles
+          areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) &&
+          areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) &&
+          areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) {
+        foundMatch = true;
+      }
+    });
+
+    if (foundMatch)
+      return opName.str();
+  }
+
+  return failure();
+}
+
+// Rewrite pattern to convert linalg.generic to kernel ops
+class LinalgGenericToKernelPattern : public OpRewritePattern<GenericOp> {
+public:
+  LinalgGenericToKernelPattern(MLIRContext *context,
+                               kernel::DefnCollectionOp collectionOp)
+      : OpRewritePattern<GenericOp>(context), collectionOp(collectionOp) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Try to match with a defn in the collection
+    auto matchResult = matchGenericWithDefn(genericOp, collectionOp);
+    if (failed(matchResult))
+      return failure();
+
+    std::string opName = *matchResult;
+
+    // Create the appropriate kernel operation based on the matched pattern
+    if (opName == "Kernel_gemm") {
+      // Get inputs and outputs
+      Value outputTensor = genericOp.getDpsInitOperand(0)->get();
+      Value inputA = genericOp.getDpsInputOperand(0)->get();
+      Value inputB = genericOp.getDpsInputOperand(1)->get();
+
+      // Default alpha and beta values (could be extracted from pattern)
+      FloatAttr alpha = rewriter.getF32FloatAttr(1.0);
+      FloatAttr beta = rewriter.getF32FloatAttr(0.0);
+
+      // Create the kernel.gemm operation
+      rewriter.replaceOpWithNewOp<kernel::GemmOp>(
+          genericOp, genericOp.getResultTypes(),
+          outputTensor, inputA, inputB, alpha, beta);
+
+      return success();
+    }
+    else if (opName == "Kernel_batched_gemm") {
+      // Get inputs and outputs
+      Value outputTensor = genericOp.getDpsInitOperand(0)->get();
+      Value inputA = genericOp.getDpsInputOperand(0)->get();
+      Value inputB = genericOp.getDpsInputOperand(1)->get();
+
+      // Default alpha and beta values
+      FloatAttr alpha = rewriter.getF32FloatAttr(1.0);
+      FloatAttr beta = rewriter.getF32FloatAttr(0.0);
+
+      // Create the kernel.batched_gemm operation
+      rewriter.replaceOpWithNewOp<kernel::BatchedGemmOp>(
+          genericOp, genericOp.getResultTypes(),
+          outputTensor, inputA, inputB, alpha, beta);
+
+      return success();
+    }
+    else if (opName == "Kernel_iamax") {
+      // Get input
+      Value input = genericOp.getDpsInputOperand(0)->get();
+
+      // Create the kernel.iamax operation
+      rewriter.replaceOpWithNewOp<kernel::IamaxOp>(
+          genericOp, genericOp.getResultTypes(), input);
+
+      return success();
+    }
+    else if (opName == "Kernel_iamin") {
+      // Get input
+      Value input = genericOp.getDpsInputOperand(0)->get();
+
+      // Create the kernel.iamin operation
+      rewriter.replaceOpWithNewOp<kernel::IaminOp>(
+          genericOp, genericOp.getResultTypes(), input);
+
+      return success();
+    }
+    else if (opName == "Kernel_asum") {
+      // Get input
+      Value input = genericOp.getDpsInputOperand(0)->get();
+
+      // Create the kernel.asum operation
+      rewriter.replaceOpWithNewOp<kernel::AsumOp>(
+          genericOp, genericOp.getResultTypes(), input);
+
+      return success();
+    }
+
+    return failure();
+  }
+
+private:
+  kernel::DefnCollectionOp collectionOp;
+};
+
+// Pass to apply the rewrite pattern
+class LinalgToKernelPass
+    : public PassWrapper<LinalgToKernelPass, OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgToKernelPass)
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Find the kernel.defn_collection in the module
+    kernel::DefnCollectionOp collectionOp;
+    module.walk([&](kernel::DefnCollectionOp op) {
+      collectionOp = op;
+      return WalkResult::interrupt();
+    });
+
+    if (!collectionOp) {
+      module.emitError("No kernel.defn_collection found in module");
+      return signalPassFailure();
+    }
+
+    // Apply the rewrite pattern
+    RewritePatternSet patterns(&getContext());
+    patterns.add<LinalgGenericToKernelPattern>(&getContext(), collectionOp);
+
+    if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+// Create a pass to convert linalg.generic to kernel
+std::unique_ptr<Pass> createLinalgToKernelPass() {
+  return std::make_unique<LinalgToKernelPass>();
+}
+
+// Register the pass
+void registerLinalgToKernelPasses() {
+  PassRegistration<LinalgToKernelPass>("linalg-to-kernel",
+      "Convert linalg.generic to kernel operations");
+}
\ No newline at end of file
diff --git a/generic_solver/CublasOps.td b/generic_solver/CublasOps.td
new file mode 100644
index 000000000000..56aaebba0766
--- /dev/null
+++ b/generic_solver/CublasOps.td
@@ -0,0 +1,85 @@
+//===- CublasOps.td - kernel dialect operation definitions ---*- tablegen -*-===//
+//
+// This file defines the kernel operation definitions in TableGen format.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef kernel_OPS
+#define kernel_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// kernel dialect definition
+//===----------------------------------------------------------------------===//
+
+def Kernel_Dialect : Dialect {
+  let name = "kernel";
+  let cppNamespace = "::mlir::kernel";
+  let description = [{
+    The kernel dialect provides operations for NVIDIA kernel matrix multiplication
+    routines, including standard and batched GEMM operations.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Base class for kernel dialect operations
+//===----------------------------------------------------------------------===//
+
+class Kernel_Op<string mnemonic, list<Trait> traits = []> :
+    Op<Kernel_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// kernel ops instantiation collection
+//===----------------------------------------------------------------------===//
+
+def Opinst_DefnCollection : Op<Kernel_Dialect, "defn_collection", [NoTerminator]> {
+  let summary = "Collection of operation definitions";
+  let description = [{
+    A collection of operation definitions that can be referenced elsewhere.
+    This operation serves as a container for multiple operation definitions.
+  }];
+
+  let regions = (region SizedRegion<1>:$defns);
+
+  let assemblyFormat = [{
+    $defns attr-dict
+  }];
+}
+
+def Opinst_Defn : Op<Kernel_Dialect, "defn"> {
+  let summary = "Definition of an operation";
+  let description = [{
+    A definition of an operation with inputs and arbitrary body code.
+    Can contain either literal code or a linalg.generic representation.
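+
+    As a rough illustration (the mnemonic and exact printed syntax here are
+    assumptions based on the declarations in this file, not a verified dump),
+    a definition wrapping a linalg.generic GEMM could look like:
+
+    ```mlir
+    kernel.defn "Kernel_gemm"(%A, %B, %C) {
+      // linalg.generic implementing C += A * B
+    } : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> ()
+    ```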
+ }]; + + let arguments = (ins + StrAttr:$name, + Variadic:$inputs + ); + + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = [{ + $name `(` $inputs `)` $body attr-dict `:` functional-type($inputs, results) + }]; +} + +//===----------------------------------------------------------------------===// +// Example pattern representation +//===----------------------------------------------------------------------===// + +// Patterns for gemm and batched_gemm expressed in a mathematical notation. +// These are informational and would be used by pattern matchers. + +// Standard GEMM pattern: C(i,k) += alpha * A(i,j) * B(j,k) +// Batched GEMM pattern: C(N, i,k) += alpha * A(N, i,j) * B(N, j,k) + +// Index of max absolute value pattern: result = argmax_i |x_i| +// Index of min absolute value pattern: result = argmin_i |x_i| +// Sum of absolute values pattern: result = sum_i |x_i| + +#endif // kernel_OPS \ No newline at end of file diff --git a/generic_solver/example.mlir b/generic_solver/example.mlir new file mode 100644 index 000000000000..1dade3ef3afd --- /dev/null +++ b/generic_solver/example.mlir @@ -0,0 +1,49 @@ +//RUN: polygeist-opt --linalg-to-kernel="kernel-library-path=/home/arjaiswal/Polygeist/generic_solver/kernel_library.mlir" -allow-unregistered-dialect generic_solver/example.mlir +// Example MLIR module demonstrating kernel operations and their linalg.generic representations +module { + //Func that uses simple gemm + func.func @simple_gemm(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i, j, k) -> (i, k)>, // A(i,k) + affine_map<(i, j, k) -> (k, j)>, // B(k,j) + affine_map<(i, j, k) -> (i, j)> // C(i,j) + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } + + // Function that uses iamin (index of minimum absolute value) + func.func @find_min_abs_index(%X: tensor, %init: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (index) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%init : tensor) { + ^bb0(%in: f32, %out: i32): + %idx = linalg.index 0 : index + %abs_val = math.absf %in : f32 + %curr_min_idx = arith.index_cast %out : i32 to index + %curr_min = tensor.extract %X[%curr_min_idx] : tensor + %curr_min_abs = math.absf %curr_min : f32 + %cmp = arith.cmpf olt, %abs_val, %curr_min_abs : f32 + %new_idx = arith.select %cmp, %idx, %curr_min_idx : index + %result = arith.index_cast %new_idx : index to i32 + linalg.yield %result : i32 + } -> tensor + return %result : tensor + } + +} \ No newline at end of file diff --git a/generic_solver/kernel_library.mlir b/generic_solver/kernel_library.mlir new file mode 100644 index 000000000000..fd4fd6a48a70 --- /dev/null +++ b/generic_solver/kernel_library.mlir @@ -0,0 +1,218 @@ +// Kernel Library - Reusable kernel definitions +// This file contains a collection of kernel definitions that can be loaded +// by the linalg-to-kernel pass and applied to different MLIR modules. 
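+//
+// Illustrative sketch (not part of the library semantics): once the
+// linalg-to-kernel pass matches a linalg.generic in user code against one of
+// the definitions below, the computation can be referred to through the
+// kernel dialect, for example via a launch of the matched definition. The
+// exact replacement op is decided by the pass; the launch form shown here only
+// assumes the kernel.launch syntax declared in KernelOps.td:
+//
+//   %r = kernel.launch @gemm_linalg(%A, %B, %C, %alpha, %beta)
+//          : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, f32, f32) -> tensor<?x?xf32>
+//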
+ +module { + // Collection of kernel operation definitions + kernel.defn_collection { + + // Simple GEMM operation definition with linalg.generic representation + kernel.defn @simple_gemm_linalg(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // Simple matrix multiplication: C = A * B + C + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // Scaled GEMM operation definition with alpha and beta coefficients + kernel.defn @gemm_linalg(%A: tensor, %B: tensor, %C: tensor, %alpha: f32, %beta: f32) -> tensor { + // GEMM with scaling: C = alpha * A * B + beta * C + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %scaled = arith.mulf %product, %alpha : f32 + %scaled_c = arith.mulf %c, %beta : f32 + %result = arith.addf %scaled, %scaled_c : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // Alpha-scaled GEMM accumulation (matches the second operation in the user's pattern) + kernel.defn @alpha_gemm_accumulate(%A: tensor, %B: tensor, %C: tensor, %alpha: f64) -> tensor { + // Matrix multiplication with alpha scaling: C += alpha * A * B + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %6 = arith.mulf %alpha, %in : f64 + %7 = arith.mulf %6, %in_0 : f64 + %8 = arith.addf %out, %7 : f64 + linalg.yield %8 : f64 + } -> tensor + kernel.yield %result : tensor + } + + // Element-wise beta scaling (matches the first operation in the user's pattern) + kernel.defn @beta_scale(%C: tensor, %beta: f64) -> tensor { + // Element-wise scaling: C = beta * C + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d1, d0)> + ], + iterator_types = ["parallel", "parallel"] + } outs(%C : tensor) { + ^bb0(%out: f64): + %6 = arith.mulf %out, %beta : f64 + linalg.yield %6 : f64 + } -> tensor + kernel.yield %result : tensor + } + + // Matrix multiplication with alpha scaling (second operation standalone) + kernel.defn @gemm_alpha_only(%A: tensor, %B: tensor, %C: tensor, %alpha: f64) -> tensor { + // Matrix multiplication: C += alpha * A * B + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d1, d0)>, + affine_map<(d0, d1, d2) -> (d2, d0)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%in: f64, %in_0: f64, %out: f64): + %6 = arith.mulf %alpha, %in : f64 + %7 = arith.mulf %6, %in_0 : f64 + %8 = arith.addf %out, %7 : f64 + linalg.yield %8 : f64 + } -> tensor + kernel.yield %result : tensor + } + + // Sum of absolute values 
operation (ASUM) + kernel.defn @asum_linalg(%X: tensor, %init: tensor) -> tensor { + // Sum of absolute values: result = sum_i |x_i| + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%init : tensor) { + ^bb0(%in: f32, %out: f32): + %abs_val = math.absf %in : f32 + %result = arith.addf %abs_val, %out : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // Vector dot product + kernel.defn @dot_linalg(%X: tensor, %Y: tensor, %init: tensor) -> tensor { + // Dot product: result = sum_i x_i * y_i + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X, %Y : tensor, tensor) + outs(%init : tensor) { + ^bb0(%x: f32, %y: f32, %out: f32): + %product = arith.mulf %x, %y : f32 + %result = arith.addf %product, %out : f32 + linalg.yield %result : f32 + } -> tensor + kernel.yield %result : tensor + } + + // Index of maximum absolute value operation definition with linalg.generic representation + kernel.defn @iamax_linalg(%X: tensor, %init: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (index) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%init : tensor) { + ^bb0(%in: f32, %out: i32): + %idx = linalg.index 0 : index + %abs_val = math.absf %in : f32 + %curr_max_idx = arith.index_cast %out : i32 to index + %curr_max = tensor.extract %X[%curr_max_idx] : tensor + %curr_max_abs = math.absf %curr_max : f32 + %cmp = arith.cmpf ogt, %abs_val, %curr_max_abs : f32 + %new_idx = arith.select %cmp, %idx, %curr_max_idx : index + %result = arith.index_cast %new_idx : index to i32 + linalg.yield %result : i32 + } -> tensor + kernel.yield %result : tensor + } + + // General Matrix-Vector Multiply (GEMV) + kernel.defn @gemv_simple(%A: tensor, %x: tensor, %y: tensor) -> tensor { + // Simple matrix-vector multiplication: y += A * x + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d1, d0)>, // Matrix A[d0, d1] + affine_map<(d0, d1) -> (d0)>, // Vector x[d1] + affine_map<(d0, d1) -> (d1)> // Vector y[d0] + ], + iterator_types = ["parallel", "reduction"] + } ins(%A, %x : tensor, tensor) + outs(%y : tensor) { + ^bb0(%a: f64, %x_val: f64, %y_val: f64): + %product = arith.mulf %a, %x_val : f64 + %result = arith.addf %y_val, %product : f64 + linalg.yield %result : f64 + } -> tensor + kernel.yield %result : tensor + } + + // Index of minimum absolute value operation definition with linalg.generic representation + kernel.defn @iamin_linalg(%X: tensor, %init: tensor) -> tensor { + // Implementation using linalg.generic + %result = linalg.generic { + indexing_maps = [ + affine_map<(i) -> (i)>, // Input vector + affine_map<(i) -> ()> // Result scalar (index) + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%init : tensor) { + ^bb0(%in: f32, %out: i32): + %idx = linalg.index 0 : index + %abs_val = math.absf %in : f32 + %curr_min_idx = arith.index_cast %out : i32 to index + %curr_min = tensor.extract %X[%curr_min_idx] : tensor + %curr_min_abs = math.absf %curr_min : f32 + %cmp = arith.cmpf olt, %abs_val, %curr_min_abs : f32 + %new_idx = arith.select %cmp, %idx, %curr_min_idx : index + %result = arith.index_cast %new_idx : index to i32 + linalg.yield %result : 
i32 + } -> tensor + kernel.yield %result : tensor + } + } +} \ No newline at end of file diff --git a/generic_solver/test_input_simple.mlir b/generic_solver/test_input_simple.mlir new file mode 100644 index 000000000000..8fa0e6df4edf --- /dev/null +++ b/generic_solver/test_input_simple.mlir @@ -0,0 +1,71 @@ +// Test input file - contains linalg.generic operations to be matched +// This file does NOT contain kernel.defn_collection - those will be loaded externally + +module { + // Function that performs simple matrix multiplication + func.func @simple_gemm(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // This linalg.generic should match @simple_gemm_linalg from kernel_library.mlir + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %product = arith.mulf %a, %b : f32 + %result = arith.addf %product, %c : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } + + // Function that computes sum of absolute values + func.func @compute_asum(%X: tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // This linalg.generic should match @asum_linalg from kernel_library.mlir + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X : tensor) + outs(%fill : tensor) { + ^bb0(%in: f32, %out: f32): + %abs_val = math.absf %in : f32 + %result = arith.addf %abs_val, %out : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } + + // Function that computes dot product + func.func @compute_dot(%X: tensor, %Y: tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor + %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor) -> tensor + + // This linalg.generic should match @dot_linalg from kernel_library.mlir + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()> + ], + iterator_types = ["reduction"] + } ins(%X, %Y : tensor, tensor) + outs(%fill : tensor) { + ^bb0(%x: f32, %y: f32, %out: f32): + %product = arith.mulf %x, %y : f32 + %result = arith.addf %product, %out : f32 + linalg.yield %result : f32 + } -> tensor + return %result : tensor + } +} \ No newline at end of file diff --git a/include/polygeist/CMakeLists.txt b/include/polygeist/CMakeLists.txt index efcf93f70329..06fb9a05da90 100644 --- a/include/polygeist/CMakeLists.txt +++ b/include/polygeist/CMakeLists.txt @@ -2,4 +2,5 @@ add_mlir_dialect(PolygeistOps polygeist) add_mlir_doc(PolygeistDialect -gen-dialect-doc PolygeistDialect Polygeist/) add_mlir_doc(PolygeistOps -gen-op-doc PolygeistOps Polygeist/) -add_subdirectory(Passes) \ No newline at end of file +add_subdirectory(Passes) +add_subdirectory(Kernel) \ No newline at end of file diff --git a/include/polygeist/Kernel/CMakeLists.txt b/include/polygeist/Kernel/CMakeLists.txt new file mode 100644 index 000000000000..6bc7f03a564c --- /dev/null +++ b/include/polygeist/Kernel/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(KernelOps kernel) \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelDialect.h b/include/polygeist/Kernel/KernelDialect.h new file mode 100644 index 
000000000000..6dbf888f97fc --- /dev/null +++ b/include/polygeist/Kernel/KernelDialect.h @@ -0,0 +1,25 @@ +//===- KernelDialect.h - Kernel dialect declaration -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef POLYGEIST_KERNEL_KERNELDIALECT_H +#define POLYGEIST_KERNEL_KERNELDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +namespace mlir { +namespace polygeist { +namespace kernel { + +} // namespace kernel +} // namespace polygeist +} // namespace mlir + +#include "polygeist/Kernel/KernelOpsDialect.h.inc" + +#endif // POLYGEIST_KERNEL_KERNELDIALECT_H \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelDialect.td b/include/polygeist/Kernel/KernelDialect.td new file mode 100644 index 000000000000..68ffc856b65f --- /dev/null +++ b/include/polygeist/Kernel/KernelDialect.td @@ -0,0 +1,36 @@ +//===- KernelDialect.td - Kernel dialect definition -------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef KERNEL_DIALECT +#define KERNEL_DIALECT + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Kernel dialect definition +//===----------------------------------------------------------------------===// + +def Kernel_Dialect : Dialect { + let name = "kernel"; + let cppNamespace = "::mlir::polygeist::kernel"; + let description = [{ + The kernel dialect provides operations for NVIDIA kernel matrix multiplication + routines, including standard and batched GEMM operations. This dialect enables + representation and optimization of high-performance linear algebra kernels + within the Polygeist infrastructure. + }]; +} + +//===----------------------------------------------------------------------===// +// Base class for kernel dialect operations +//===----------------------------------------------------------------------===// + +class Kernel_Op traits = []> : + Op; + +#endif // KERNEL_DIALECT \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelOps.h b/include/polygeist/Kernel/KernelOps.h new file mode 100644 index 000000000000..966ef77d6379 --- /dev/null +++ b/include/polygeist/Kernel/KernelOps.h @@ -0,0 +1,32 @@ +//===- KernelOps.h - Kernel dialect operations ------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef POLYGEIST_KERNEL_KERNELOPS_H +#define POLYGEIST_KERNEL_KERNELOPS_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "polygeist/Kernel/KernelDialect.h" + +namespace mlir { +namespace polygeist { +namespace kernel { + +} // namespace kernel +} // namespace polygeist +} // namespace mlir + +#define GET_OP_CLASSES +#include "polygeist/Kernel/KernelOps.h.inc" + +#endif // POLYGEIST_KERNEL_KERNELOPS_H \ No newline at end of file diff --git a/include/polygeist/Kernel/KernelOps.td b/include/polygeist/Kernel/KernelOps.td new file mode 100644 index 000000000000..aa5c758cf179 --- /dev/null +++ b/include/polygeist/Kernel/KernelOps.td @@ -0,0 +1,200 @@ +//===- KernelOps.td - Kernel dialect operation definitions -*-- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef KERNEL_OPS +#define KERNEL_OPS + +include "polygeist/Kernel/KernelDialect.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/OpAsmInterface.td" + +//===----------------------------------------------------------------------===// +// Kernel operation definitions +//===----------------------------------------------------------------------===// + +def Kernel_DefnCollectionOp : Kernel_Op<"defn_collection", [NoTerminator]> { + let summary = "Collection of kernel operation definitions"; + let description = [{ + A collection of operation definitions that can be referenced elsewhere. + This operation serves as a container for multiple kernel operation definitions, + enabling modular organization of kernel implementations. + }]; + + let regions = (region SizedRegion<1>:$defns); + + let assemblyFormat = [{ + $defns attr-dict + }]; +} + +def Kernel_DefnOp : Kernel_Op<"defn", [ + AffineScope, + AutomaticAllocationScope, + IsolatedFromAbove, + FunctionOpInterface, + Symbol +]> { + let summary = "Definition of a kernel operation"; + let description = [{ + A definition of a kernel operation with inputs and arbitrary body code. + Can contain either literal CUDA/HIP code or a linalg.generic representation + for high-performance linear algebra operations. + + This operation is particularly useful for defining custom GEMM variants, + batched operations, and other specialized linear algebra kernels. 
+ + Example: + ```mlir + kernel.defn @custom_gemm(%A: memref, %B: memref, + %C: memref, %alpha: f32) -> tensor { + // Kernel implementation + kernel.yield %some_result : tensor + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + + let hasCustomAssemblyFormat = 1; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Returns the argument types of this kernel. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this kernel. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + /// Returns the region on the current operation that is callable. + ::mlir::Region *getCallableRegion() { return &getBody(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return getBody().empty(); } + }]; +} + +//===----------------------------------------------------------------------===// +// LaunchOp +//===----------------------------------------------------------------------===// + +def Kernel_LaunchOp : Kernel_Op<"launch", + [CallOpInterface, MemRefsNormalizable, + DeclareOpInterfaceMethods]> { + let summary = "kernel launch operation"; + let description = [{ + The `kernel.launch` operation represents a launch of a kernel that is + within the same symbol scope as the launch. The operands and result types of + the launch must match the specified kernel type. The kernel is encoded as a + symbol reference attribute named "kernel". + + Example: + + ```mlir + %result = kernel.launch @custom_gemm(%A, %B, %C, %alpha) : (memref, memref, memref, f32) -> tensor + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$kernel, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "DefnOp":$kernel, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("kernel", SymbolRefAttr::get(kernel)); + $_state.addTypes(kernel.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$kernel, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("kernel", kernel); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$kernel, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(kernel), results, operands); + }]>, + OpBuilder<(ins "StringRef":$kernel, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), kernel), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getKernelType(); + + /// Get the argument operands to the launched kernel. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the kernel of this operation. 
+ CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("kernel"); + } + + /// Set the kernel for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("kernel", callee.get()); + } + }]; + + let assemblyFormat = [{ + $kernel `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def Kernel_YieldOp : Kernel_Op<"yield", [Pure, HasParent<"DefnOp">, + MemRefsNormalizable, ReturnLike, Terminator]> { + let summary = "Terminator for kernel.defn operation"; + let description = [{ + The `kernel.yield` operation terminates regions within kernel operations. + It optionally returns values from the kernel definition. + }]; + + let arguments = (ins Variadic:$operands); + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + + let builders = [ + OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]> + ]; + + let hasVerifier = 1; +} + +#endif // KERNEL_OPS \ No newline at end of file diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index 92c5812e8c4c..e70660153540 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -32,6 +32,9 @@ std::unique_ptr createOpenMPOptPass(); std::unique_ptr createCanonicalizeForPass(); std::unique_ptr createRaiseSCFToAffinePass(); std::unique_ptr createRaiseAffineToLinalgPass(); +std::unique_ptr createRaiseAffineToLinalgPipelinePass(); +std::unique_ptr createLinalgDebufferizePass(); +std::unique_ptr createRemoveIterArgsPass(); std::unique_ptr createCPUifyPass(StringRef method = ""); std::unique_ptr createBarrierRemovalContinuation(); std::unique_ptr detectReductionPass(); @@ -71,6 +74,9 @@ createGpuSerializeToHsacoPass(StringRef arch, StringRef features, int llvmOptLevel, int hsaOptLevel, std::string rocmPath, bool outputIntermediate); +std::unique_ptr createLinalgToKernelPass(); +std::unique_ptr createLinalgToKernelPass(const std::string& kernelLibraryPath); + void registerGpuSerializeToCubinPass(); void registerGpuSerializeToHsacoPass(); @@ -96,6 +102,11 @@ namespace omp { class OpenMPDialect; } // end namespace omp +namespace polygeist { +namespace kernel { +class KernelDialect; +} // end namespace kernel +} namespace polygeist { class PolygeistDialect; } // end namespace polygeist @@ -128,6 +139,18 @@ namespace linalg { class LinalgDialect; } +namespace tensor { +class TensorDialect; +} + +namespace bufferization { +class BufferizationDialect; +} + +namespace Tensor { +class TensorDialect; +} + namespace LLVM { class LLVMDialect; } diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 5c17a9d6dc25..368eb59d28ab 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -151,12 +151,43 @@ def SCFRaiseToAffine : Pass<"raise-scf-to-affine"> { ]; } +def RemoveIterArgs : Pass<"remove-iter-args"> { + let summary = "Remove scf iter args"; + let constructor = "mlir::polygeist::createRemoveIterArgsPass()"; + let dependentDialects = [ + "affine::AffineDialect", + "scf::SCFDialect", + ]; +} + +def LinalgDebufferize : Pass<"linalg-debufferize"> { + let summary = "Raise affine to linalg"; + let constructor = "mlir::polygeist::createLinalgDebufferizePass()"; + let dependentDialects = [ + "affine::AffineDialect", + "linalg::LinalgDialect", + "bufferization::BufferizationDialect", + "polygeist::PolygeistDialect", + ]; +} + def AffineRaiseToLinalg : Pass<"raise-affine-to-linalg"> { let summary = "Raise affine to linalg"; let 
constructor = "mlir::polygeist::createRaiseAffineToLinalgPass()"; let dependentDialects = [ "affine::AffineDialect", "linalg::LinalgDialect", + "polygeist::PolygeistDialect", + ]; +} + +def AffineRaiseToLinalgPipeline : Pass<"raise-affine-to-linalg-pipeline"> { + let summary = "Pipeline: affine-parallelize followed by raise-affine-to-linalg"; + let constructor = "mlir::polygeist::createRaiseAffineToLinalgPipelinePass()"; + let dependentDialects = [ + "affine::AffineDialect", + "linalg::LinalgDialect", + "polygeist::PolygeistDialect", ]; } @@ -234,6 +265,54 @@ def RemoveTrivialUse : Pass<"trivialuse"> { let constructor = "mlir::polygeist::createRemoveTrivialUsePass()"; } +def LinalgToKernel : Pass<"linalg-to-kernel", "mlir::ModuleOp"> { + let summary = "Convert linalg.generic operations to kernel operations by matching with kernel.defn patterns"; + let description = [{ + This pass matches linalg.generic operations against patterns defined in + kernel.defn_collection operations and converts them to the corresponding + specialized kernel operations (e.g., kernel.gemm, kernel.batched_gemm). + + The pass performs semantic matching of linalg.generic operations by: + - Comparing indexing maps and iterator types + - Matching the operation structure within regions + - Checking input/output operand counts + + Example transformation: + ```mlir + // Input: linalg.generic performing matrix multiplication + linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%A, %B : tensor, tensor) + outs(%C : tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %mul = arith.mulf %a, %b : f32 + %add = arith.addf %mul, %c : f32 + linalg.yield %add : f32 + } -> tensor + + // Output: Specialized kernel operation + %result = kernel.gemm %C, %A, %B, %alpha, %beta : tensor + ``` + }]; + let constructor = "mlir::polygeist::createLinalgToKernelPass()"; + let dependentDialects = [ + "linalg::LinalgDialect", + "polygeist::kernel::KernelDialect", + "tensor::TensorDialect", + "arith::ArithDialect", + "bufferization::BufferizationDialect", + ]; + let options = [ + Option<"kernelLibraryPath", "kernel-library-path", "std::string", + /*default=*/"\"\"", + "Path to external MLIR file containing kernel.defn_collection definitions. 
" + "If empty, looks for kernel.defn_collection in the input module."> + ]; +} + def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { let summary = "Convert scalar and vector operations from the Standard to the " "LLVM dialect"; diff --git a/include/polygeist/PolygeistOps.td b/include/polygeist/PolygeistOps.td index 159f6c144947..ff59deb22bbd 100644 --- a/include/polygeist/PolygeistOps.td +++ b/include/polygeist/PolygeistOps.td @@ -259,4 +259,21 @@ def TypeAlignOp : Polygeist_Op<"typeAlign", [Pure]> { let hasFolder = 1; let hasCanonicalizer = 1; } + +def SubmapOp : Polygeist_Op<"submap", [Pure, ViewLikeOpInterface]> { + let arguments = (ins Arg:$memref, + Variadic:$indices_and_sizes, + AffineMapAttr:$map + ); + let results = (outs AnyMemRef : $result); + let hasFolder = 1; + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + ::mlir::ValueRange getSymbols() { return getOperands().slice(1, getMap().getNumSymbols()); } + ::mlir::ValueRange getSizes() { return getOperands().slice(getMap().getNumSymbols()+1, getType().getShape().size()); } + ::mlir::Value getViewSource() { return getMemref(); } + }]; +} + #endif // POLYGEIST_OPS diff --git a/lib/polygeist/CMakeLists.txt b/lib/polygeist/CMakeLists.txt index 88aea0de4dd5..b2a410a77872 100644 --- a/lib/polygeist/CMakeLists.txt +++ b/lib/polygeist/CMakeLists.txt @@ -19,3 +19,4 @@ MLIRSCFTransforms ) add_subdirectory(Passes) add_subdirectory(ExecutionEngine) +add_subdirectory(Kernel) diff --git a/lib/polygeist/Kernel/CMakeLists.txt b/lib/polygeist/Kernel/CMakeLists.txt new file mode 100644 index 000000000000..371724504a5e --- /dev/null +++ b/lib/polygeist/Kernel/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRPolygeistKernel + KernelDialect.cpp + KernelOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/polygeist/Kernel + + DEPENDS + MLIRKernelOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRMemRefDialect + MLIRArithDialect + MLIRFuncDialect + MLIRControlFlowInterfaces + MLIRSideEffectInterfaces + MLIRSupport +) \ No newline at end of file diff --git a/lib/polygeist/Kernel/KernelDialect.cpp b/lib/polygeist/Kernel/KernelDialect.cpp new file mode 100644 index 000000000000..0e239ff2565c --- /dev/null +++ b/lib/polygeist/Kernel/KernelDialect.cpp @@ -0,0 +1,33 @@ +//===- KernelDialect.cpp - Kernel dialect implementation --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "polygeist/Kernel/KernelDialect.h" +#include "polygeist/Kernel/KernelOps.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Transforms/InliningUtils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +using namespace mlir; +using namespace mlir::polygeist; +using namespace mlir::polygeist::kernel; + +#include "polygeist/Kernel/KernelOpsDialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// Kernel dialect initialization +//===----------------------------------------------------------------------===// + +void KernelDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "polygeist/Kernel/KernelOps.cpp.inc" + >(); +} \ No newline at end of file diff --git a/lib/polygeist/Kernel/KernelOps.cpp b/lib/polygeist/Kernel/KernelOps.cpp new file mode 100644 index 000000000000..8ad84f79e6ea --- /dev/null +++ b/lib/polygeist/Kernel/KernelOps.cpp @@ -0,0 +1,150 @@ +//===- KernelOps.cpp - Kernel dialect operations ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "polygeist/Kernel/KernelOps.h" +#include "polygeist/Kernel/KernelDialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::polygeist; +using namespace mlir::polygeist::kernel; + +//===----------------------------------------------------------------------===// +// DefnOp +//===----------------------------------------------------------------------===// + +LogicalResult DefnOp::verify() { + // Check that the body region has exactly one block + if (!getBody().hasOneBlock()) + return emitOpError("body region must have exactly one block"); + + // The block can have any number of arguments + // No special verification needed for block arguments + + return success(); +} + +ParseResult DefnOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = [](Builder &builder, ArrayRef argTypes, + ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { + return builder.getFunctionType(argTypes, results); + }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void DefnOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +LogicalResult YieldOp::verify() { + auto defnOp = 
cast<DefnOp>((*this)->getParentOp());
+
+  // The operand number and types must match the kernel signature.
+  const auto &results = defnOp.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing kernel (@"
+           << defnOp.getName() << ") returns " << results.size();
+
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    if (getOperand(i).getType() != results[i])
+      return emitError() << "type of yield operand " << i << " ("
+                         << getOperand(i).getType()
+                         << ") doesn't match kernel result type ("
+                         << results[i] << ")"
+                         << " in kernel @" << defnOp.getName();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LaunchOp
+//===----------------------------------------------------------------------===//
+
+FunctionType LaunchOp::getKernelType() {
+  // Get the kernel symbol reference
+  auto kernelAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("kernel");
+  if (!kernelAttr)
+    return nullptr;
+
+  // Look up the kernel DefnOp in the symbol table
+  auto *symbolTableOp = (*this)->getParentWithTrait<OpTrait::SymbolTable>();
+  if (!symbolTableOp)
+    return nullptr;
+
+  auto kernelOp = dyn_cast_or_null<DefnOp>(
+      SymbolTable::lookupSymbolIn(symbolTableOp, kernelAttr));
+  if (!kernelOp)
+    return nullptr;
+
+  return kernelOp.getFunctionType();
+}
+
+LogicalResult LaunchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Check that the kernel attribute was specified.
+  auto kernelAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("kernel");
+  if (!kernelAttr)
+    return emitOpError("requires a 'kernel' symbol reference attribute");
+
+  // Check that the kernel symbol exists and is a DefnOp.
+  auto kernelOp = symbolTable.lookupNearestSymbolFrom<DefnOp>(*this, kernelAttr);
+  if (!kernelOp)
+    return emitOpError() << "'" << kernelAttr.getValue()
+                         << "' does not reference a valid kernel";
+
+  // Verify that the operand and result types match the kernel signature.
+  auto kernelType = kernelOp.getFunctionType();
+  if (kernelType.getNumInputs() != getNumOperands())
+    return emitOpError("incorrect number of operands for kernel");
+
+  for (unsigned i = 0, e = kernelType.getNumInputs(); i != e; ++i)
+    if (getOperand(i).getType() != kernelType.getInput(i))
+      return emitOpError("operand type mismatch: expected operand type ")
+             << kernelType.getInput(i) << ", but provided "
+             << getOperand(i).getType() << " for operand number " << i;
+
+  if (kernelType.getNumResults() != getNumResults())
+    return emitOpError("incorrect number of results for kernel");
+
+  for (unsigned i = 0, e = kernelType.getNumResults(); i != e; ++i)
+    if (getResult(i).getType() != kernelType.getResult(i))
+      return emitOpError("result type mismatch: expected result type ")
+             << kernelType.getResult(i) << ", but provided "
+             << getResult(i).getType() << " for result number " << i;
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "polygeist/Kernel/KernelOps.cpp.inc"
\ No newline at end of file
diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp
index d9a60fbcce45..b203bdcce137 100644
--- a/lib/polygeist/Ops.cpp
+++ b/lib/polygeist/Ops.cpp
@@ -22,9 +22,11 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
@@ -39,7 +41,6 @@
 using namespace mlir;
 using namespace polygeist;
 using namespace mlir::arith;
-
 llvm::cl::opt<bool> BarrierOpt("barrier-opt", llvm::cl::init(true),
                                llvm::cl::desc("Optimize barriers"));
@@ -673,6 +674,8 @@ bool isCaptured(Value v, Operation *potentialUser = nullptr,
   for (auto u : v.getUsers()) {
     if (seenuse && u == potentialUser)
       *seenuse = true;
+    if (isa<SubmapOp>(u))
+      continue;
     if (isa(u))
       continue;
@@ -815,25 +818,43 @@ bool mayAlias(Value v, Value v2) {
   isAlloca[1] = isStackAlloca(v2);
   isGlobal[1] = v2.getDefiningOp() ||
-                v2.getDefiningOp();
+                v2.getDefiningOp();

   // Non-equivalent allocas/global's cannot conflict with each other
   if ((isAlloca[0] || isGlobal[0]) && (isAlloca[1] || isGlobal[1]))
     return false;

-  bool isArg[2];
-  isArg[0] = v.isa<BlockArgument>() &&
-             isa(
-                 v.cast<BlockArgument>().getOwner()->getParentOp());
+  bool isArg[2] = {false, false};
+  bool isNoAliasArg[2] = {false, false};
+
+  if (auto ba = dyn_cast<BlockArgument>(v)) {
+    if (auto fn = dyn_cast<FunctionOpInterface>(ba.getOwner()->getParentOp())) {
+      isArg[0] = true;
+      if (fn.getArgAttr(ba.getArgNumber(), LLVM::LLVMDialect::getNoAliasAttrName())) {
+        isNoAliasArg[0] = true;
+      }
+    }
+  }

-  isArg[1] = v.isa<BlockArgument>() &&
-             isa(
-                 v.cast<BlockArgument>().getOwner()->getParentOp());
+  if (auto ba = dyn_cast<BlockArgument>(v2)) {
+    if (auto fn = dyn_cast<FunctionOpInterface>(ba.getOwner()->getParentOp())) {
+      isArg[1] = true;
+      if (fn.getArgAttr(ba.getArgNumber(), LLVM::LLVMDialect::getNoAliasAttrName())) {
+        isNoAliasArg[1] = true;
+      }
+    }
+  }

   // Stack allocations cannot have been passed as an argument.
if ((isAlloca[0] && isArg[1]) || (isAlloca[1] && isArg[0])) return false; + if ((isArg[0] && isNoAliasArg[1]) || (isArg[1] && isNoAliasArg[0])) + return false; + + if ((isGlobal[0] && isNoAliasArg[1]) || (isGlobal[1] && isNoAliasArg[0])) + return false; + // Non captured base allocas cannot conflict with another base value. if (isAlloca[0] && !isCaptured(v)) return false; @@ -4487,7 +4508,6 @@ struct MergeNestedAffineParallelIf return success(); } }; - struct MergeParallelInductions : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -4497,7 +4517,7 @@ struct MergeParallelInductions // Reductions are not supported yet. if (!op.getReductions().empty()) return failure(); - + auto getIndUsage = [&op](AffineExpr cst, ValueRange operands, std::map &indUsage, bool &legal) -> AffineExpr { @@ -5733,6 +5753,629 @@ struct MulDivMul : public OpRewritePattern { } }; +struct SubMapOpCanonicalize : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SubmapOp op, + PatternRewriter &rewriter) const override { + /// if submap %x is identity map and has the same size as the static size of + /// %x + ///. replace submap with memref.cast of memref<4x5xf32> to memref + /// %x = ... : memref<4x5xf32> + // %y = polygeist.submap %x(#identity_map, %constant_4, %constant_5) : + // memref<4x5xf32> -> memref + // + //. becomes + // + /// %x = ... : memref<4x5xf32> + // %y = memref.cast %x : memref<4x5xf32> -> memref + // + auto source_memref = op.getMemref(); + bool isIdentity = op.getMap().isIdentity(); + bool isInputSameDim = llvm::all_of( + llvm::zip_equal(op.getSizes(), + cast(source_memref.getType()).getShape()), + [&](auto pair) { + if (std::get<1>(pair) == -1) + return false; + APInt matched; + if (matchPattern(std::get<0>(pair), m_ConstantInt(&matched))) { + return std::get<1>(pair) == matched; + } + return false; + }); + if (isIdentity && isInputSameDim) { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getMemref()); + return success(); + } + if (auto sapOp = source_memref.getDefiningOp()) { + auto load_map = op.getMap(); + auto submap_map = sapOp.getMap(); + auto new_map = submap_map.compose(load_map); + SmallVector operands; + operands.append(op.getSymbols().begin(), op.getSymbols().end()); + operands.append(op.getSymbols().begin(), op.getSymbols().end()); + operands.append(op.getSizes().begin(), op.getSizes().end()); + rewriter.replaceOpWithNewOp( + op, op.getType(), sapOp.getMemref(), operands, new_map); + return success(); + } + return failure(); + } +}; + +struct StrideAndBound { + int64_t stride; + int64_t lowerBound; + unsigned dimOrSymbol; // Which dimension/symbol this applies to + bool isDimension; // true if dimension, false if symbol + + StrideAndBound(int64_t s, int64_t lb, unsigned idx, bool isDim) + : stride(s), lowerBound(lb), dimOrSymbol(idx), isDimension(isDim) {} +}; + +struct ExpressionAnalysis { + SmallVector coefficients; // Coefficients for dims/symbols + int64_t constantTerm = 0; // Pure constant term + + void addDimCoeff(unsigned dim, int64_t coeff) { + coefficients.emplace_back(coeff, 0, dim, true); + } + + void addSymCoeff(unsigned sym, int64_t coeff) { + coefficients.emplace_back(coeff, 0, sym, false); + } +}; + +// Recursively analyze an affine expression to extract coefficients and constants +static ExpressionAnalysis analyzeAffineExpression(AffineExpr expr) { + ExpressionAnalysis result; + + if (auto constExpr = expr.dyn_cast()) { + // Pure constant + result.constantTerm = constExpr.getValue(); + + } 
else if (auto dimExpr = expr.dyn_cast()) { + // Single dimension with coefficient 1 + result.addDimCoeff(dimExpr.getPosition(), 1); + + } else if (auto symExpr = expr.dyn_cast()) { + // Single symbol with coefficient 1 + result.addSymCoeff(symExpr.getPosition(), 1); + + } else if (auto binaryExpr = expr.dyn_cast()) { + auto lhs = binaryExpr.getLHS(); + auto rhs = binaryExpr.getRHS(); + + if (binaryExpr.getKind() == AffineExprKind::Add) { + // Addition: combine results from both sides + auto lhsAnalysis = analyzeAffineExpression(lhs); + auto rhsAnalysis = analyzeAffineExpression(rhs); + + result.coefficients.append(lhsAnalysis.coefficients); + result.coefficients.append(rhsAnalysis.coefficients); + result.constantTerm = lhsAnalysis.constantTerm + rhsAnalysis.constantTerm; + + } else if (binaryExpr.getKind() == AffineExprKind::Mul) { + // Multiplication: one side should be constant, other should be dim/symbol + auto lhsConst = lhs.dyn_cast(); + auto rhsConst = rhs.dyn_cast(); + + if (lhsConst && !rhsConst) { + // Constant * expr + auto rhsAnalysis = analyzeAffineExpression(rhs); + for (auto &coeff : rhsAnalysis.coefficients) { + coeff.stride *= lhsConst.getValue(); + } + result.coefficients = std::move(rhsAnalysis.coefficients); + result.constantTerm = rhsAnalysis.constantTerm * lhsConst.getValue(); + + } else if (rhsConst && !lhsConst) { + // expr * Constant + auto lhsAnalysis = analyzeAffineExpression(lhs); + for (auto &coeff : lhsAnalysis.coefficients) { + coeff.stride *= rhsConst.getValue(); + } + result.coefficients = std::move(lhsAnalysis.coefficients); + result.constantTerm = lhsAnalysis.constantTerm * rhsConst.getValue(); + + } else if (lhsConst && rhsConst) { + // Constant * Constant + result.constantTerm = lhsConst.getValue() * rhsConst.getValue(); + } + // Note: expr * expr is not affine, so we don't handle it + + } else if (binaryExpr.getKind() == AffineExprKind::Mod) { + // Modulo: more complex, for now just mark as having the base expression + auto lhsAnalysis = analyzeAffineExpression(lhs); + result.coefficients = std::move(lhsAnalysis.coefficients); + result.constantTerm = lhsAnalysis.constantTerm; + + } else if (binaryExpr.getKind() == AffineExprKind::FloorDiv || + binaryExpr.getKind() == AffineExprKind::CeilDiv) { + // Division: handle simple cases where RHS is constant + if (auto rhsConst = rhs.dyn_cast()) { + auto lhsAnalysis = analyzeAffineExpression(lhs); + for (auto &coeff : lhsAnalysis.coefficients) { + coeff.stride = coeff.stride / rhsConst.getValue(); + } + result.coefficients = std::move(lhsAnalysis.coefficients); + result.constantTerm = lhsAnalysis.constantTerm / rhsConst.getValue(); + } + } + } + + return result; +} + +struct MapAnalysis { + SmallVector outputAnalyses; + + // Get all unique strides from all outputs + SmallVector getAllStrides() const { + SmallVector strides; + llvm::DenseSet seen; + + for (const auto &analysis : outputAnalyses) { + for (const auto &coeff : analysis.coefficients) { + // TODO: Need to add a check that if more than one coeffs in an outputAnalysis + // then we need to return failure. 
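+        // Illustrative worked example (hypothetical map, not from a test): for
+        //   (d0, d1)[s0] -> (d0 * 4 + s0, d1 + 2)
+        // analyzeAffineExpression records coefficients {d0: 4, s0: 1} with
+        // constant term 0 for the first result and {d1: 1} with constant term 2
+        // for the second, so getAllStrides() returns [4, 1, 1] and
+        // getAllLowerBounds() returns [0, 2]. Note that the `seen` set declared
+        // above is currently unused, so duplicate strides are not filtered here.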
+ strides.push_back(coeff.stride); + } + } + return strides; + } + + // Get all lower bounds (constant terms) from all outputs + SmallVector getAllLowerBounds() const { + SmallVector bounds; + for (const auto &analysis : outputAnalyses) { + bounds.push_back(analysis.constantTerm); + } + return bounds; + } +}; + +// Main function to analyze an affine map +static MapAnalysis analyzeAffineMap(AffineMap map) { + MapAnalysis result; + + for (auto expr : map.getResults()) { + result.outputAnalyses.push_back(analyzeAffineExpression(expr)); + } + + return result; +} + +// Extract both strides and bounds +std::pair, SmallVector> +extractStridesAndBounds(AffineMap map) { + auto analysis = analyzeAffineMap(map); + return {analysis.getAllStrides(), analysis.getAllLowerBounds()}; +} + +// Helper function to check if an expression is a simple offset + stride pattern +static bool isSimpleOffsetStride(AffineExpr expr) { + // Check if expression is of the form: d0 + constant, d0 * constant + constant, etc. + if (auto dimExpr = expr.dyn_cast()) { + return true; // Simple dimension access + } + + if (auto constExpr = expr.dyn_cast()) { + return true; // Constant offset + } + + if (auto binaryExpr = expr.dyn_cast()) { + auto kind = binaryExpr.getKind(); + + // Allow simple addition and multiplication patterns + if (kind == AffineExprKind::Add || kind == AffineExprKind::Mul) { + return isSimpleOffsetStride(binaryExpr.getLHS()) && + isSimpleOffsetStride(binaryExpr.getRHS()); + } + + // Allow simple division by constants (for stride calculation) + if (kind == AffineExprKind::FloorDiv || kind == AffineExprKind::CeilDiv) { + if (auto rhsConst = binaryExpr.getRHS().dyn_cast()) { + return rhsConst.getValue() > 0 && isSimpleOffsetStride(binaryExpr.getLHS()); + } + } + } + + return false; +} + +// Main function to check if SubmapOp can be converted to SubViewOp +static bool canConvertSubmapToSubView(polygeist::SubmapOp submapOp) { + auto map = submapOp.getMap(); + auto sizes = submapOp.getSizes(); + auto symbols = submapOp.getSymbols(); + auto source_memref = submapOp.getMemref(); + + // 0. Only convert if map has symbols + if (submapOp.getMap().getNumSymbols() == 0) { + return false; + } + + // 1. Identity maps are always valid + if (map.isIdentity()) { + return true; + } + + // 2. Check if we can extract meaningful strides and bounds + auto [strides, lowerBounds] = extractStridesAndBounds(map); + if (strides.empty() || lowerBounds.empty()) { + return false; + } + + // 3. Ensure the number of results matches expected dimensions + if (map.getNumResults() != sizes.size()) { + return false; + } + + // 4. Check each expression in the map for complexity + for (auto expr : map.getResults()) { + if (!isSimpleOffsetStride(expr)) { + return false; + } + } + + // 5. 
Check for unsupported complex transformations + for (auto expr : map.getResults()) { + // Reject expressions that involve multiple dimensions in complex ways + if (auto binaryExpr = expr.dyn_cast()) { + // For now, reject modulo operations as they're hard to represent in SubView + if (binaryExpr.getKind() == AffineExprKind::Mod) { + return false; + } + + // Reject complex multi-dimensional expressions + if (binaryExpr.getKind() == AffineExprKind::Mul) { + auto lhs = binaryExpr.getLHS(); + auto rhs = binaryExpr.getRHS(); + + // Both sides are dimensions = complex interaction + if (lhs.isa() && rhs.isa()) { + return false; + } + + // Multiplication by symbols might be too complex for simple SubView + if (lhs.isa() || rhs.isa()) { + // Allow simple symbol multiplication, but check it's not too complex + if (!lhs.isa() && !rhs.isa()) { + return false; + } + } + } + } + } + + // 6. Check for rank-changing transformations that SubView can't handle + auto sourceType = source_memref.getType().cast(); + auto resultType = submapOp.getType().cast(); + + // SubView can do rank-reduction, but not rank-expansion + if (resultType.getRank() > sourceType.getRank()) { + return false; + } + + return true; +} + +// Convenience function to check and extract conversion info +struct SubmapToSubViewConversionInfo { + bool isValid; + SmallVector strides; + SmallVector offsets; + SmallVector sizes; + SmallVector dynamicOffsets; // For symbol-based offsets + + SubmapToSubViewConversionInfo() : isValid(false) {} +}; + +static SubmapToSubViewConversionInfo +analyzeSubmapToSubViewConversion(polygeist::SubmapOp submapOp) { + SubmapToSubViewConversionInfo info; + + if (!canConvertSubmapToSubView(submapOp)) { + return info; // isValid = false + } + + auto map = submapOp.getMap(); + auto [strides, lowerBounds] = extractStridesAndBounds(map); + + info.isValid = true; + info.strides = strides; + info.offsets = lowerBounds; + info.sizes.append(submapOp.getSizes().begin(), submapOp.getSizes().end()); + + return info; +} + + +struct SubmapToSubviewOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(polygeist::SubmapOp submapOp, + PatternRewriter &rewriter) const override { + auto conversionInfo = analyzeSubmapToSubViewConversion(submapOp); + if (!conversionInfo.isValid) + return failure(); + + SmallVector offsetValues, sizeValues, strideValues; + for (int64_t offset : conversionInfo.offsets) { + offsetValues.push_back(rewriter.getI64IntegerAttr(offset)); + } + for (int64_t stride : conversionInfo.strides) { + strideValues.push_back(rewriter.getI64IntegerAttr(stride)); + } + for (Value size : conversionInfo.sizes) { + sizeValues.push_back(size); + } + rewriter.replaceOpWithNewOp(submapOp, submapOp.getType(), submapOp.getMemref(), offsetValues, sizeValues, strideValues); + return success(); + } +}; + +// Enhanced analysis structure to handle symbols and transposes +struct EnhancedSubmapAnalysis { + bool isValid = false; + bool needsTranspose = false; + SmallVector permutation; // For transpose: [1,0] means swap dims + SmallVector offsets; // Mix of constants and symbol values + SmallVector strides; // Mix of constants and symbol values + SmallVector sizes; // From submapOp.getSizes() +}; + +// Helper to analyze affine expressions with symbol support +static bool analyzeExpressionWithSymbols(AffineExpr expr, unsigned expectedDim, + ValueRange symbolValues, + OpFoldResult &offset, OpFoldResult &stride, + unsigned &actualDim, OpBuilder &builder) { + offset = 
builder.getI64IntegerAttr(0); // Default offset = 0 + stride = builder.getI64IntegerAttr(1); // Default stride = 1 + actualDim = expectedDim; + + // Case 1: Simple dimension access: d0, d1, etc. + if (auto dimExpr = expr.dyn_cast()) { + actualDim = dimExpr.getPosition(); + return true; + } + + // Case 2: Constant (pure offset) + if (auto constExpr = expr.dyn_cast()) { + offset = builder.getI64IntegerAttr(constExpr.getValue()); + actualDim = 0; // Degenerate case + return true; + } + + // Case 3: Symbol (pure offset from symbol) + if (auto symbolExpr = expr.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + offset = symbolValues[symbolExpr.getPosition()]; + actualDim = 0; // Degenerate case + return true; + } + return false; + } + + // Case 4: Binary operations + if (auto binaryExpr = expr.dyn_cast()) { + auto lhs = binaryExpr.getLHS(); + auto rhs = binaryExpr.getRHS(); + + if (binaryExpr.getKind() == AffineExprKind::Add) { + // d0 + constant, d0 + symbol, constant + symbol, etc. + if (auto dimExpr = lhs.dyn_cast()) { + actualDim = dimExpr.getPosition(); + if (auto constExpr = rhs.dyn_cast()) { + offset = builder.getI64IntegerAttr(constExpr.getValue()); + return true; + } + if (auto symbolExpr = rhs.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + offset = symbolValues[symbolExpr.getPosition()]; + return true; + } + } + } + // Try reverse: constant + d0, symbol + d0 + if (auto dimExpr = rhs.dyn_cast()) { + actualDim = dimExpr.getPosition(); + if (auto constExpr = lhs.dyn_cast()) { + offset = builder.getI64IntegerAttr(constExpr.getValue()); + return true; + } + if (auto symbolExpr = lhs.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + offset = symbolValues[symbolExpr.getPosition()]; + return true; + } + } + } + } + + if (binaryExpr.getKind() == AffineExprKind::Mul) { + // d0 * constant, d0 * symbol + if (auto dimExpr = lhs.dyn_cast()) { + actualDim = dimExpr.getPosition(); + if (auto constExpr = rhs.dyn_cast()) { + stride = builder.getI64IntegerAttr(constExpr.getValue()); + return true; + } + if (auto symbolExpr = rhs.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + stride = symbolValues[symbolExpr.getPosition()]; + return true; + } + } + } + // Try reverse: constant * d0, symbol * d0 + if (auto dimExpr = rhs.dyn_cast()) { + actualDim = dimExpr.getPosition(); + if (auto constExpr = lhs.dyn_cast()) { + stride = builder.getI64IntegerAttr(constExpr.getValue()); + return true; + } + if (auto symbolExpr = lhs.dyn_cast()) { + if (symbolExpr.getPosition() < symbolValues.size()) { + stride = symbolValues[symbolExpr.getPosition()]; + return true; + } + } + } + } + } + + return false; +} + +// Enhanced analysis function +static EnhancedSubmapAnalysis analyzeEnhancedSubmap(polygeist::SubmapOp submapOp, + OpBuilder &builder) { + EnhancedSubmapAnalysis analysis; + auto map = submapOp.getMap(); + auto symbolValues = submapOp.getSymbols(); + auto sizes = submapOp.getSizes(); + auto sourceType = submapOp.getViewSource().getType().cast(); + int64_t sourceRank = sourceType.getRank(); + + // Only handle maps with reasonable complexity + if (map.getNumResults() == 0 || map.getNumResults() > 4) { + return analysis; + } + + // Initialize arrays with default values for all dimensions of source memref + SmallVector offsets(sourceRank, builder.getI64IntegerAttr(0)); + SmallVector strides(sourceRank, builder.getI64IntegerAttr(1)); + SmallVector resultSizes; + SmallVector actualDims; + + // Build default sizes from source memref 
shape + for (int64_t i = 0; i < sourceRank; ++i) { + int64_t dimSize = sourceType.getDimSize(i); + if (dimSize == ShapedType::kDynamic) { + // For dynamic dimensions, we need to use the actual size + Value dimSizeValue = builder.create( + submapOp.getLoc(), submapOp.getViewSource(), i); + resultSizes.push_back(dimSizeValue); + } else { + resultSizes.push_back(builder.getI64IntegerAttr(dimSize)); + } + } + + // Analyze each result expression and update corresponding dimension + for (unsigned i = 0; i < map.getNumResults(); ++i) { + auto expr = map.getResult(i); + OpFoldResult offset, stride; + unsigned actualDim; + + if (!analyzeExpressionWithSymbols(expr, i, symbolValues, offset, stride, + actualDim, builder)) { + return analysis; // Failed to analyze + } + + // Make sure actualDim is within bounds + if (actualDim >= sourceRank) { + return analysis; // Invalid dimension + } + + // Update the arrays for this dimension + offsets[actualDim] = offset; + strides[actualDim] = stride; + actualDims.push_back(actualDim); + } + + analysis.isValid = true; + analysis.offsets = std::move(offsets); + analysis.strides = std::move(strides); + + // Copy sizes - use provided sizes if available, otherwise use computed ones + if (sizes.size() == map.getNumResults()) { + for (auto size : sizes) { + analysis.sizes.push_back(size); + } + } else { + // Use default sizes for all dimensions + analysis.sizes = std::move(resultSizes); + } + + return analysis; +} + +// Enhanced pattern implementation +struct EnhancedSubmapToSubviewOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(polygeist::SubmapOp submapOp, + PatternRewriter &rewriter) const override { + auto analysis = analyzeEnhancedSubmap(submapOp, rewriter); + if (!analysis.isValid) { + return failure(); + } + + Value currentMemref = submapOp.getViewSource(); + Location loc = submapOp.getLoc(); + + // Step 1: Apply subview if we have non-trivial offsets/strides + bool hasNonTrivialSubview = false; + for (auto offset : analysis.offsets) { + if (auto attr = offset.dyn_cast()) { + if (auto intAttr = attr.dyn_cast()) { + if (intAttr.getInt() != 0) { + hasNonTrivialSubview = true; + break; + } + } + } else { + hasNonTrivialSubview = true; // Non-constant offset + break; + } + } + + for (auto stride : analysis.strides) { + if (auto attr = stride.dyn_cast()) { + if (auto intAttr = attr.dyn_cast()) { + if (intAttr.getInt() != 1) { + hasNonTrivialSubview = true; + break; + } + } + } else { + hasNonTrivialSubview = true; // Non-constant stride + break; + } + } + + if (hasNonTrivialSubview) { + // Create subview operation + auto subviewOp = rewriter.create( + loc, currentMemref, analysis.offsets, analysis.sizes, analysis.strides); + currentMemref = subviewOp.getResult(); + } + + // Step 2: Apply transpose if needed + if (analysis.needsTranspose) { + // Create transpose using linalg.transpose or memref.transpose + // For now, let's use a simple approach with linalg + SmallVector permutation = analysis.permutation; + + // Create transpose using linalg.transpose (if available) + // This is a simplified version - you might need to adjust based on available ops + auto transposeType = MemRefType::get( + submapOp.getType().cast().getShape(), + submapOp.getType().cast().getElementType()); + + // For simplicity, let's create an identity operation for now + // In practice, you'd want to create the actual transpose operation + currentMemref = currentMemref; // TODO: Implement actual transpose + } + + // Replace the original 
submap + rewriter.replaceOp(submapOp, currentMemref); + return success(); + } +}; + static llvm::cl::opt BufferElim("enable-buffer-elim", llvm::cl::init(true), llvm::cl::desc("Enable buffer elimination")); @@ -5764,7 +6407,6 @@ void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results, SimplifyDeadAllocV2, SimplifyDeadAllocV2, MulDivMul, MergeParallelInductions, - // RankReduction, AggressiveAllocaScopeInliner, InductiveVarRemoval>(context); } @@ -5880,3 +6522,186 @@ LogicalResult GetFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } + +class LoadSubMap final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineLoadOp op, + PatternRewriter &rewriter) const override { + auto subMapOp = op.getMemRef().getDefiningOp(); + if (!subMapOp) + return failure(); + + auto submap_map = subMapOp.getMap(); + auto submap_operands = subMapOp.getSymbols(); + auto source_memref = subMapOp.getMemref(); + + auto load_map = op.getAffineMap(); + auto load_operands = op.getMapOperands(); + + auto new_map = submap_map.compose(load_map); + + SmallVector operands; + operands.append(load_operands.begin(), + load_operands.begin() + load_map.getNumDims()); + operands.append(submap_operands.begin(), submap_operands.end()); + operands.append(load_operands.begin() + load_map.getNumDims(), + load_operands.end()); + + rewriter.replaceOpWithNewOp(op, source_memref, + new_map, operands); + return success(); + } +}; + +class StoreSubMap final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineStoreOp op, + PatternRewriter &rewriter) const override { + auto subMapOp = op.getMemRef().getDefiningOp(); + if (!subMapOp) + return failure(); + + auto submap_map = subMapOp.getMap(); + auto submap_operands = subMapOp.getSymbols(); + auto source_memref = subMapOp.getMemref(); + + auto load_map = op.getAffineMap(); + auto load_operands = op.getMapOperands(); + + auto new_map = submap_map.compose(load_map); + + SmallVector operands; + operands.append(load_operands.begin(), + load_operands.begin() + load_map.getNumDims()); + operands.append(submap_operands.begin(), submap_operands.end()); + operands.append(load_operands.begin() + load_map.getNumDims(), + load_operands.end()); + + rewriter.replaceOpWithNewOp( + op, op.getValue(), source_memref, new_map, operands); + return success(); + } +}; + +OpFoldResult mlir::polygeist::SubmapOp::fold( + mlir::polygeist::SubmapOp::FoldAdaptor adaptor) { + // TODO if submap is identity return nothing + // if submap of submap return new submap + return nullptr; +} + +class DimSubMap final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DimOp op, + PatternRewriter &rewriter) const override { + auto subMapOp = op.getSource().getDefiningOp(); + if (!subMapOp) + return failure(); + + auto idx = op.getIndex().getDefiningOp(); + if (!idx) + return failure(); + + rewriter.replaceOp(op, subMapOp.getSizes()[idx.value()]); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// LinalgGenericEliminateSubmaps Pattern +//===----------------------------------------------------------------------===// + +struct LinalgGenericEliminateSubmaps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter 
&rewriter) const override { + bool hasSubmaps = false; + SmallVector newInputs; + SmallVector newOutputs; + SmallVector newIndexingMaps; + + // Get the indexing maps as AffineMap array + auto indexingMaps = genericOp.getIndexingMapsArray(); + + // Check inputs for submaps + for (auto [input, map] : llvm::zip(genericOp.getInputs(), indexingMaps)) { + if (auto submapOp = input.getDefiningOp()) { + // Skip submaps with symbols for now to avoid invalid map composition + if (submapOp.getMap().getNumSymbols() > 0) { + newInputs.push_back(input); + newIndexingMaps.push_back(map); + continue; + } + + hasSubmaps = true; + newInputs.push_back(submapOp.getViewSource()); + // Compose: submap_map.compose(linalg_map) → f(g(x)) + AffineMap composedMap = submapOp.getMap().compose(map); + newIndexingMaps.push_back(composedMap); + } else { + newInputs.push_back(input); + newIndexingMaps.push_back(map); + } + } + + // Check outputs for submaps + auto outputMaps = ArrayRef(indexingMaps).drop_front(genericOp.getInputs().size()); + for (auto [output, map] : llvm::zip(genericOp.getOutputs(), outputMaps)) { + if (auto submapOp = output.getDefiningOp()) { + // Skip submaps with symbols for now to avoid invalid map composition + if (submapOp.getMap().getNumSymbols() > 0) { + newOutputs.push_back(output); + newIndexingMaps.push_back(map); + continue; + } + + hasSubmaps = true; + newOutputs.push_back(submapOp.getViewSource()); + // Compose: submap_map.compose(linalg_map) → f(g(x)) + AffineMap composedMap = submapOp.getMap().compose(map); + newIndexingMaps.push_back(composedMap); + } else { + newOutputs.push_back(output); + newIndexingMaps.push_back(map); + } + } + + if (!hasSubmaps) { + return failure(); + } + + // Create new linalg.generic with composed maps + auto newGenericOp = rewriter.create( + genericOp.getLoc(), + genericOp.getResultTypes(), + newInputs, + newOutputs, + newIndexingMaps, + genericOp.getIteratorTypesArray(), + /*bodyBuild=*/nullptr); + + // Clone the region + IRMapping mapping; + genericOp.getRegion().cloneInto(&newGenericOp.getRegion(), mapping); + + rewriter.replaceOp(genericOp, newGenericOp.getResults()); + return success(); + } +}; + +void polygeist::SubmapOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + // results.insert(context); + results.insert(context); + // results.insert(context); +} + diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index d6947a1931c5..07a559ae00e8 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -11,7 +11,10 @@ add_mlir_dialect_library(MLIRPolygeistTransforms OpenMPOpt.cpp BarrierRemovalContinuation.cpp RaiseToAffine.cpp + RemoveIterArgs.cpp RaiseToLinalg.cpp + LinalgDebufferize.cpp + LinalgToKernel.cpp ParallelLower.cpp TrivialUse.cpp ConvertPolygeistToLLVM.cpp @@ -43,15 +46,18 @@ add_mlir_dialect_library(MLIRPolygeistTransforms MLIRGPUToNVVMTransforms MLIRIR MLIRLLVMDialect + MLIRLinalgDialect MLIRMathDialect MLIRMathToLLVM MLIRMemRefDialect MLIRNVVMDialect MLIRPass MLIRPolygeist + MLIRPolygeistKernel MLIRSideEffectInterfaces MLIRSCFToControlFlow MLIRTargetLLVMIRImport + MLIRTensorDialect MLIRTransformUtils MLIRGPUToROCDLTransforms MLIRControlFlowToLLVM diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp new file mode 100644 index 000000000000..1a4e22e39dec --- /dev/null +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -0,0 +1,752 @@ +#include "PassDetails.h" + +#include 
"mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "polygeist/Ops.h" +#include "polygeist/Passes/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "linalg-debufferize" + +using namespace mlir; +using namespace mlir::arith; +using namespace polygeist; +using namespace affine; +using namespace linalg; +using namespace tensor; +using namespace bufferization; + +using opTuple = std::tuple; //First: result, Second: prev_tensor ? + +bool isCaptured(Value v, Operation *potentialUser = nullptr, + bool *seenuse = nullptr); + +bool isAncestor(Operation *potentialAncestor, Operation *op) { + Operation *current = op->getParentOp(); + while (current != nullptr) { + if (current == potentialAncestor) + return true; + current = current->getParentOp(); + } + return false; +} + +//Checks if a comes before b +bool comesBefore(Operation *a, Operation *b) { + if (a == b) return false; + + if (isAncestor(a, b)) return true; + if (isAncestor(b, a)) return false; + + Operation *aParent = a->getParentOp(); + Operation *bParent = b->getParentOp(); + // Walk up b's hierarchy until we reach a's level + Operation *bAncestor = b; + //We traverse B's ancestors here + while (Operation *parent = bAncestor->getParentOp()) { + if (parent == aParent) { + // Compare positions within aParent's regions/blocks + Region *aRegion = a->getParentRegion(); + Region *bRegion = bAncestor->getParentRegion(); + + if (aRegion == bRegion) { + // Same region: compare block order + Block *aBlock = a->getBlock(); + Block *bBlock = bAncestor->getBlock(); + if (aBlock != bBlock) { + auto get_block_pos = [](Region *region, Block *block) { + auto &blocks = region->getBlocks(); + auto it = llvm::find_if(blocks, [block](Block &b) { + return &b == block; // Address comparison + }); + assert(it != blocks.end() && "Block not found in region"); + return std::distance(blocks.begin(), it); + }; + return get_block_pos(aRegion, aBlock) < + get_block_pos(bRegion, bBlock); + }; + // Same block: compare operation order + return a->isBeforeInBlock(bAncestor); + } + + // Different regions: compare region order + auto compareRegions = [parent](Region *x, Region *y) { + auto get_region_position = [](Operation *parent, Region *target) { + auto regions = parent->getRegions(); // Get reference to region list + auto begin = regions.begin(); + auto it = llvm::find_if(regions, [&](Region &r) { + return &r == target; + }); + return std::distance(begin, it); + }; + return get_region_position(parent, x) < + get_region_position(parent, y); + }; + return compareRegions(aRegion, bRegion); + } + bAncestor = parent; + } + + Operation *aAncestor = a; + //We traverse A's ancestors here + while (Operation *parent = aAncestor->getParentOp()) { + if (parent == bParent) { + // Compare positions within aParent's regions/blocks + Region *bRegion = b->getParentRegion(); + Region *aRegion = aAncestor->getParentRegion(); + + if (aRegion == 
bRegion) { + // Same region: compare block order + Block *bBlock = b->getBlock(); + Block *aBlock = aAncestor->getBlock(); + if (aBlock != bBlock) { + auto get_block_pos = [](Region *region, Block *block) { + auto &blocks = region->getBlocks(); + auto it = llvm::find_if(blocks, [block](Block &b) { + return &b == block; // Address comparison + }); + assert(it != blocks.end() && "Block not found in region"); + return std::distance(blocks.begin(), it); + }; + return !(get_block_pos(bRegion, bBlock) < + get_block_pos(aRegion, aBlock)); + }; + // Same block: compare operation order + return !b->isBeforeInBlock(aAncestor); + } + + // Different regions: compare region order + auto compareRegions = [parent](Region *x, Region *y) { + auto get_region_position = [](Operation *parent, Region *target) { + auto regions = parent->getRegions(); // Get reference to region list + auto begin = regions.begin(); + auto it = llvm::find_if(regions, [&](Region &r) { + return &r == target; + }); + return std::distance(begin, it); + }; + return get_region_position(parent, x) < + get_region_position(parent, y); + }; + return !compareRegions(bRegion, aRegion); + } + aAncestor = parent; + } + + //llvm_unreachable("Operations do not share a common ancestor"); + //// Recursive case: compare parent operations + return comesBefore(aParent, bParent); +} + +std::vector getSortedUsers(Value val) { + std::vector users; + for (Operation *user : val.getUsers()) { + //This logic is to prevent duplication of users + auto it = std::find_if(users.begin(), users.end(), + [user](const Operation* op) { + return op == user; + }); + if(it == users.end()) + users.push_back(user); + } + + std::sort(users.begin(), users.end(), [](Operation *a, Operation *b) { + return comesBefore(a,b); + }); + + return users; +} + +// std::vector getSortedUsers(Operation *op) { +// // Find the parent function +// auto funcOp = op->getParentOfType(); +// if (!funcOp) +// return {}; + +// // Map to store order of operations +// llvm::DenseMap opOrder; +// size_t order = 0; + +// funcOp.walk([&](Operation *curOp) { opOrder[curOp] = order++; }); + +// std::vector sortedUsers(op->getUsers().begin(), +// op->getUsers().end()); + +// std::sort( +// sortedUsers.begin(), sortedUsers.end(), +// [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; }); + +// return sortedUsers; +// } + +Region* findCommonAncestorRegion(Operation* a, Operation* b) { + DenseMap regionCounts; + + // Walk up from operation A + Operation* currentOp = a; + while (Region* region = currentOp->getParentRegion()) { + regionCounts[region]++; + currentOp = region->getParentOp(); + } + + // Walk up from operation B to find common region + currentOp = b; + while (Region* region = currentOp->getParentRegion()) { + if (regionCounts.count(region)) + return region; + currentOp = region->getParentOp(); + } + return nullptr; +} + + +struct debufferizationAllocaRemoval : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::AllocaOp allocaOp, + PatternRewriter &rewriter) const final { + Value allocaResult = allocaOp.getResult(); + bool userToTensorOp = false; + bool userCopyOp = false; + bool userOtherOp = false; + memref::CopyOp copyOp; + bufferization::ToTensorOp toTensorOp; + for (Operation *user : allocaResult.getUsers()) { + if (isa(user)) { + userToTensorOp = true; + toTensorOp = cast(user); + } + else if (isa(user)) { + userCopyOp = true; + copyOp = cast(user); + } + else + userOtherOp = true; + } + + 
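+    // Sketch of the rewrite performed below (illustrative IR, names hypothetical):
+    // when the alloca's only users are a single memref.copy and a single
+    // bufferization.to_tensor, e.g. roughly
+    //   %a = memref.alloca() : memref<4xf32>
+    //   memref.copy %src, %a : memref<4xf32> to memref<4xf32>
+    //   %t = bufferization.to_tensor %a : memref<4xf32>
+    // the to_tensor result %t is replaced by a fresh tensor.empty of the same
+    // shape and element type, and the copy and to_tensor ops are erased.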
if(!(!userOtherOp&&userCopyOp&&userToTensorOp)) + return failure(); + + auto emptyTensor = + rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), + allocaOp.getType().getElementType()); + + rewriter.replaceAllUsesWith(toTensorOp.getResult(), emptyTensor.getResult()); + + rewriter.eraseOp(copyOp); + rewriter.eraseOp(toTensorOp); + return success(); + } +}; + +void findUsersInRegion( + mlir::Value value, + mlir::Region& region, + llvm::SmallVectorImpl& users +) { + for (mlir::Block& block : region) { + for (mlir::Operation& op : block) { + for (mlir::Value operand : op.getOperands()) { + if (operand == value) { + users.push_back(&op); + break; // No need to check other operands for this op + } + } + + // Recursively check all sub-regions of this operation + for (mlir::Region& subRegion : op.getRegions()) { + findUsersInRegion(value, subRegion, users); + } + } + } +} + +void propagateValueThroughRegion(Value ¤tValue, SmallVector regions, std::vector expandedUserList, llvm::DenseMap opResultMap, PatternRewriter &rewriter) { + auto module = currentValue.getDefiningOp()->getParentOfType(); + for (Region* region : regions) { + Block& block = region->front(); + Operation* terminator = block.getTerminator(); + Operation *parentOp = region->getParentOp(); + + //Find init Tensor for the given for loop, i.e first match to expanded user list + mlir::Value initTensor; + int insertIdx = 0; + bool insertIdxFound = false; + for(auto user: expandedUserList) { + mlir::Region *opRegion = user->getParentRegion(); + if(region->isAncestor(opRegion)) { + insertIdxFound = true; + //Maintain a map data structure for tracking every user and if they have been processed then the corresponding result + auto it = opResultMap.find(user); + if(it == opResultMap.end()) + continue; + auto keys_value = it->second; + auto op_result = std::get<0>(keys_value); + initTensor = std::get<1>(keys_value); + break; + } + if(!insertIdxFound) + insertIdx++; + } + + //Compare use Values with + + if( auto prevIf = dyn_cast_or_null(parentOp)) { + auto prevResults = prevIf.getResults(); + SmallVector newResultTypes; + for (auto res : prevResults) + newResultTypes.push_back(res.getType()); + newResultTypes.push_back(currentValue.getType()); + + // Yield original results + new value + auto thenYieldArgs = prevIf.thenYield().getOperands(); + SmallVector thenYieldValues; + for (const auto &it :thenYieldArgs) { + thenYieldValues.push_back(it); + } + thenYieldValues.push_back(currentValue); + + SmallVector elseYieldValues; + if(!prevIf.getElseRegion().empty()){ + auto elseYieldArgs = prevIf.elseYield().getOperands(); + for (const auto &it :elseYieldArgs) { + elseYieldValues.push_back(it); + } + } + elseYieldValues.push_back(initTensor); + + //Create new Ifop + rewriter.setInsertionPoint(prevIf); + auto newIf = rewriter.create(prevIf.getLoc(), + newResultTypes, // Combined types + prevIf.getCondition(), // New condition value + true + ); + if (newIf.thenBlock()) + rewriter.eraseBlock(newIf.thenBlock()); + + newIf.getThenRegion().takeBody(prevIf.getThenRegion()); + if(!prevIf.getElseRegion().empty()) + newIf.getElseRegion().takeBody(prevIf.getElseRegion()); + + + //Update yield ops + rewriter.setInsertionPointToEnd(newIf.thenBlock()); + rewriter.replaceOpWithNewOp(newIf.thenYield(), thenYieldValues); + if(!prevIf.getElseRegion().empty()) { + rewriter.setInsertionPointToEnd(newIf.elseBlock()); + rewriter.replaceOpWithNewOp(newIf.elseYield(), elseYieldValues); + } else { + rewriter.setInsertionPointToEnd(newIf.elseBlock()); + 
rewriter.create(newIf.getLoc(), elseYieldValues); + } + + //TODO: need to update results of prevIf and else with the new ones + opResultMap[newIf] = std::make_tuple(newIf->getResult(newIf->getNumResults() - 1), currentValue); + currentValue = newIf->getResult(newIf->getNumResults() - 1); + + } + else if (auto prevFor = dyn_cast_or_null(parentOp)) { + + //After first match, now find all the users of the init Tensor in a region. + llvm::SmallVector initOpUsers; + findUsersInRegion(initTensor, *region, initOpUsers); + + SmallVector newInitOperands = prevFor.getInitArgs(); + newInitOperands.push_back(initTensor); //Needs to be the earliest use inside the region. + //TODO: Does this require fix in if as well? + + SmallVector newResultTypes(prevFor.getResultTypes().begin(), prevFor.getResultTypes().end()); + newResultTypes.push_back(currentValue.getType()); + + rewriter.setInsertionPoint(prevFor); + scf::ForOp newLoop = rewriter.create( + prevFor.getLoc(), + prevFor.getLowerBound(), + prevFor.getUpperBound(), + prevFor.getStep(), + newInitOperands + ); + newLoop->setAttrs(prevFor.getOperation()->getAttrs()); + + // Create block with induction variable + original args + new arg + SmallVector blockArgTypes; + blockArgTypes.push_back(newLoop.getInductionVar().getType()); // IV + llvm::append_range(blockArgTypes, newLoop.getResultTypes()); // Original args + + // Transfer operations from original block to new block + Block *newBlock = &newLoop.getRegion().front(); + Block *originalBlock = &prevFor.getRegion().front(); + newBlock->getOperations().splice( + newBlock->end(), + originalBlock->getOperations() + ); + + // Replace uses of original block arguments with new ones + for (unsigned i = 0; i < originalBlock->getNumArguments()-1; ++i) { + originalBlock->getArgument(i + 1) // +1 for IV + .replaceAllUsesWith(newBlock->getArgument(i + 1)); + } + + auto yieldOp = cast(newBlock->getTerminator()); + SmallVector newYieldValues = yieldOp.getOperands(); + // Add new iteration arg from block arguments + newYieldValues.push_back(currentValue); + + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp, newYieldValues); + + //Update users of initOp to use iterArgs + for(auto initOpUser: initOpUsers) { + // Iterate over all operands (both inputs and outputs) + for (const auto &en : llvm::enumerate(initOpUser->getOperands())) { + if (en.value() == initTensor) { + OpOperand &operand = initOpUser->getOpOperand(en.index()); + Value newValue = newLoop.getRegionIterArg(newLoop.getRegion().front().getNumArguments()-2); //-1 for IV + operand.set(newValue); + } + } + } + + //Update users of prev For loops results + for (auto [oldResult, newResult] : llvm::zip(prevFor.getResults(), newLoop.getResults().drop_back())) { + oldResult.replaceAllUsesWith(newResult); + } + rewriter.eraseOp(prevFor); + currentValue = newLoop.getResults().back(); + + //Store this in the user list for this region, need to create a data structure for users + opResultMap[newLoop] = std::make_tuple(currentValue, initTensor); + //Update the user list with the for Loop + expandedUserList.insert(expandedUserList.begin() + insertIdx, newLoop); + } + } +} + +bool isDirectUser(Operation *consumer, Operation *producer) { + for (Value operand : consumer->getOperands()) { + if (operand.getDefiningOp() == producer) + return true; + } + return false; +} + +// Problems with this implementation: The way this implementation works is by jumping over users +// of alloca/args. The users we get are not in sorted order. 
We write a function to sort out the users across +// regions, blocks and ops as long as they lie in the same ancestry. +// Now as we update an op, and use the output tensor to give input to the next op- it works fine for simple cases with no region. +// But things becomes more complicated when we have nested regions like in scf.if and scf.for ops +// Why? Because we need to update scf.if and scf.for ops to yield correct tensors to be used by the next user. +// So how to do it? Well the best way is to traverse all the IR in a walk and and as we encouter a user and it's linalg.generic then we update +// it's params to tensor and generate an output tensor if it can, and move to the next op and repeat this until we encounter an end of region. +// At this point we need to decide if we need to yield the tensor or not? This depends if there is an external user of the original arg/alloca +// still left over. I think this can be done by tracking users of an op, and eliminating the ones which have been used. +// In the current way it's done- we can go the next user and check if the previous user is in the same block if not we need to propagate the previous +// users output tensor through regions with yield. +// How does this work if the user is not actually outputing data, that means it didn't generate an output tensor. In which case the original tensor needs to be continued. +// In current flow, we are tracking updated output tensor, now we can iteratively yield the value until it reaches the same block as next user. +struct LinalgDebufferization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::FuncOp funcOp, + PatternRewriter &rewriter) const final { + + auto module = funcOp->getParentOfType(); + + //SmallVector opsToDelete; + //llvm::SmallPtrSet opsToDeleteSet; + // Tracks both old linalg.generics and linalg.generics with repeated values + // in ins and outs + + LogicalResult passResult = failure(); + + auto handleMemref = [&](Value memVal) -> LogicalResult { + llvm::SmallPtrSet processedGenericOps; + auto module = memVal.getParentRegion()->getParentOfType(); + + if (!memVal.getType().isa()) { + return failure(); + } + + bool isNoalias = false; + if (auto mem = memVal.getDefiningOp()) { + if (auto defOp = memVal.getDefiningOp()) {//if (mem has allocation like) { + if (isa(defOp)) { + isNoalias = true; + } + } + } else if (auto ba = dyn_cast(memVal)) { + if (auto fn = dyn_cast(ba.getOwner()->getParentOp())) { + if (fn.getArgAttr(ba.getArgNumber(), LLVM::LLVMDialect::getNoAliasAttrName())) { + isNoalias = true; + } + } + } else if (memVal.getDefiningOp() || + memVal.getDefiningOp()) { + isNoalias = true; //TODO: is this correct? 
+ } + + // if we are no alias we can just look at all users of the value + // if we are not noalias, or we are captured, then we have to look at all users that + // could read or write + //TODO: skipping noalias for now + //if ((!isNoalias) || isCaptured(memVal)) { + // return failure(); + //} + + MemRefType memrefType; + if (auto blockArg = memVal.dyn_cast()) { + memrefType = blockArg.getType().dyn_cast(); + } else if (auto allocaOp = memVal.getDefiningOp()) { + memrefType = allocaOp.getType(); + } else if (auto allocOp = memVal.getDefiningOp()) { + memrefType = allocOp.getType(); + } else { + return failure(); + } + + + rewriter.setInsertionPointAfterValue(memVal); + auto tensorType = RankedTensorType::get( + memrefType.getShape(), memrefType.getElementType()); + + // Check to see if only linalg.generic are users of the Value op for now. + //// TODO: Extend this + //if (!llvm::all_of(memVal.getUsers(), [](Operation *op) { + // return isa(op) || isa(op); + // })) { + // return failure(); + //} + + // auto emptyTensor = + // rewriter.create(allocaOp.getLoc(),allocaOp.getType().getShape(), + // allocaOp.getType().getElementType()); + auto sortedUsers = getSortedUsers(memVal); + + // If the first user is already a to_tensor op, don't try to debufferize + if (!sortedUsers.empty() && isa(sortedUsers[0])) { + return failure(); + } + + auto toTensorOp = rewriter.create( + memVal.getLoc(), tensorType, memVal); + Value currentTensor = toTensorOp; + + + //Other algorithm: + // 1. Walk over all ops + // 2. If you find a directUser - function defined then do the things for sortedUsers + // 3. If you encounter region based ops, like scf.for op and scf.if op, then track the + // op to be used for yield in scf.if + // For scf.for track the the op to be used for init, as well as the op to be updated by init. + // Op to be used by yield comes at the end. + // Problem walk.break will break things and won't be able to track recursive stuff - so would have to restart every time! + + //Variables to track results and init value with an operation that has been changed to tensor from memref + llvm::DenseMap opResultMap; + + + // Check if allocaOp is an output in current genericOp + std::vector expandedUserList(sortedUsers); + int userIdx = 0; + for (auto user : sortedUsers) { + if (auto genericOp = dyn_cast(user)) { + + // auto genericOp = cast(user); + //if (processedGenericOps.count(genericOp) > 0) + // continue; + rewriter.setInsertionPointAfter(genericOp); + + SmallVector newInputs; + SmallVector newOutputs; + SmallVector resultTypes; + // Create a new linalg.generic in Destination Style Passing format + + //check_if_current_tensor_is_available_to_user_if_not_propagate_to_scope() { + // extract_common_ancestor of curentTensor and userOp. + // propagte currentTensor all the way to common ancestor. + // Make the propagated value the current tensor. + //} + auto commonRegion = findCommonAncestorRegion(currentTensor.getDefiningOp(), user); + if (!commonRegion) return failure(); + // Collect regions from source to common ancestor + SmallVector regions; + for (Region* r = currentTensor.getParentRegion(); r != commonRegion; + r = r->getParentOp()->getParentRegion()) { + regions.push_back(r); + } + + // Propagate value through each region + propagateValueThroughRegion(currentTensor, regions, expandedUserList, opResultMap, rewriter); + + ArrayAttr indexingMaps = genericOp.getIndexingMaps(); + for (auto input : genericOp.getInputs()) { + newInputs.push_back(input == memVal ? 
currentTensor : input); + } + + // ArrayRef resultTypes; + int newCurrentTensorIndex = -1; + int index = 0; + for (auto output : genericOp.getOutputs()) { + newOutputs.push_back(output == memVal ? currentTensor : output); + resultTypes.push_back(output == memVal ? currentTensor.getType() + : output.getType()); + if (output == memVal) { + newCurrentTensorIndex = index; + } + index++; + } + + rewriter.setInsertionPointAfter(genericOp); + StringAttr empty = StringAttr::get(genericOp.getContext()); + ArrayRef resultTypesRef(resultTypes); + auto newGenericOp = rewriter.create( + genericOp.getLoc(), resultTypesRef, newInputs, newOutputs, + genericOp.getIndexingMaps(), genericOp.getIteratorTypes(), empty, + empty); + + Region &opRegion = newGenericOp.getRegion(); + rewriter.cloneRegionBefore(genericOp.getRegion(), + newGenericOp.getRegion(), + newGenericOp.getRegion().end()); + + // Replace all uses of original generic op with the new one + for (unsigned i = 0; i < genericOp->getNumResults(); ++i) { + genericOp->getResult(i).replaceAllUsesWith( + newGenericOp->getResult(i)); + } + + // Delete the original genericOp + if (newCurrentTensorIndex != -1){ + opResultMap[newGenericOp] = std::make_tuple(newGenericOp.getResult(newCurrentTensorIndex), currentTensor); + currentTensor = newGenericOp.getResult(newCurrentTensorIndex); + } + + rewriter.eraseOp(genericOp); + //Updated expanded user list, as this op is deleted + expandedUserList.insert(expandedUserList.begin() + userIdx, newGenericOp); + userIdx++; + expandedUserList.erase(expandedUserList.begin() + userIdx); + + } + else if (auto subviewOp = dyn_cast(user)) { + if (subviewOp.getSource() == memVal) { + // Convert memref.subview to tensor.extract_slice + rewriter.setInsertionPointAfter(subviewOp); + auto extractSliceOp = rewriter.create( + subviewOp.getLoc(), + currentTensor, // Use the tensor version + subviewOp.getOffsets(), + subviewOp.getSizes(), + subviewOp.getStrides()); + + // This creates a new tensor that can be used by subsequent operations + // Need to handle this tensor in the debufferization chain + } + } + } + + //For adding yields for the last use all the way to the outer most region + auto commonRegion = findCommonAncestorRegion(currentTensor.getDefiningOp(), toTensorOp); + if (!commonRegion) return failure(); + // Collect regions from source to common ancestor + SmallVector regions; + for (Region* r = currentTensor.getParentRegion(); r != commonRegion; + r = r->getParentOp()->getParentRegion()) { + regions.push_back(r); + } + + propagateValueThroughRegion(currentTensor, regions, expandedUserList, opResultMap, rewriter); + + //if(!regions.empty()) { + // auto lastRegion = regions.back(); + // Operation *parentOp = lastRegion->getParentOp(); + // rewriter.setInsertionPointAfter(parentOp); + //} + //if(currentTensor != prevTensor) { + + // Only insert to_memref and copy if currentTensor was actually transformed + if (currentTensor != toTensorOp) { + rewriter.setInsertionPointAfter(currentTensor.getDefiningOp()); + auto toMemrefOp = rewriter.create( + memVal.getLoc(), memrefType, currentTensor); + rewriter.create(memVal.getLoc(), toMemrefOp, memVal); + } + //} + // opsToDelete.push_back(allocaOp.getOperation()); + return success(); + }; + + + bool anySuccess = false; + //Fix instead of walk, just get the list of allocaOp users, so that you can easily delete ops inside + SmallVector listOfAllocaOps; + SmallVector listOfAllocOps; + + funcOp.walk([&](memref::AllocaOp alloca) { + listOfAllocaOps.push_back(alloca); + }); + //TODO: Adding 
allocOp for now, without alias check + funcOp.walk([&](memref::AllocOp alloc) { + listOfAllocOps.push_back(alloc); + }); + + for (auto alloca : listOfAllocaOps) { + anySuccess |= succeeded(handleMemref(alloca)); + } + + for (auto alloc : listOfAllocOps) { + anySuccess |= succeeded(handleMemref(alloc)); + } + + for(auto arg: funcOp.getArguments()){ + anySuccess |= succeeded(handleMemref(arg)); + } + + passResult = anySuccess ? success() : failure(); + //for (Operation *op : opsToDelete) { + // op->erase(); + //} + //opsToDelete.clear(); + + return passResult; + } +}; + +namespace { +struct LinalgDebufferize : public LinalgDebufferizeBase { + void runOnOperation() override; +}; +} // namespace + +void LinalgDebufferize::runOnOperation() { + auto module = getOperation()->getParentOfType(); + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); +} + +namespace mlir { +namespace polygeist { +std::unique_ptr createLinalgDebufferizePass() { + return std::make_unique(); +} +} // namespace polygeist +} // namespace mlir diff --git a/lib/polygeist/Passes/LinalgToKernel.cpp b/lib/polygeist/Passes/LinalgToKernel.cpp new file mode 100644 index 000000000000..3563c0ae4731 --- /dev/null +++ b/lib/polygeist/Passes/LinalgToKernel.cpp @@ -0,0 +1,765 @@ +//===- LinalgToKernel.cpp - Pattern to match linalg.generic with kernel.defn ------===// +// +// This file implements a pattern to rewrite linalg.generic operations to kernel +// operations by matching against patterns defined in kernel.defn_collection. +// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Support/FileUtilities.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/Debug.h" +#include "polygeist/Kernel/KernelDialect.h" +#include "polygeist/Kernel/KernelOps.h" +#include "polygeist/Passes/Passes.h" + +#include +#include +#include + +#define DEBUG_TYPE "linalg-to-kernel" + +using namespace mlir; +using namespace mlir::linalg; +using namespace mlir::polygeist; +using namespace mlir::polygeist::kernel; + +namespace { + +// Structure to represent an operation node in the dependency graph +struct OpNode { + Operation *op; + StringRef opName; + SmallVector operandTypes; + SmallVector resultTypes; + SmallVector dependencies; // Operations this depends on + SmallVector dependents; // Operations that depend on this + + OpNode(Operation *operation) : op(operation) { + if (operation) { + // Regular operation node + opName = operation->getName().getStringRef(); + for (Value operand : operation->getOperands()) { + operandTypes.push_back(operand.getType()); + } + for (Value result : operation->getResults()) { + resultTypes.push_back(result.getType()); + } + } else { + // Special node for block arguments - will be set later + opName = "block_arg"; + } + } + + // Check if two nodes are structurally equivalent (same operation type and types) + bool isEquivalentTo(const 
OpNode &other) const { + return opName == other.opName && + operandTypes == other.operandTypes && + resultTypes == other.resultTypes; + } +}; + +// Structure to represent a dependency graph for a region +struct DependencyGraph { + SmallVector> nodes; + DenseMap opToNode; + SmallVector blockArgNodes; // Special nodes for block arguments + + void buildFromRegion(Region ®ion) { + // Process each block in the region + for (Block &block : region.getBlocks()) { + + // Create pseudo-nodes for block arguments + for (BlockArgument arg : block.getArguments()) { + // Block arguments are represented as special nodes + auto argNode = std::make_unique(nullptr); + argNode->resultTypes.push_back(arg.getType()); + blockArgNodes.push_back(argNode.get()); + + // Map the block argument value to this node for dependency tracking + // We'll use a separate map for this + nodes.push_back(std::move(argNode)); + } + + // Create nodes for each operation + for (Operation &op : block.getOperations()) { + auto node = std::make_unique(&op); + OpNode *nodePtr = node.get(); + opToNode[&op] = nodePtr; + nodes.push_back(std::move(node)); + } + + // Build dependency edges + for (Operation &op : block.getOperations()) { + OpNode *currentNode = opToNode[&op]; + + // For each operand, find what it depends on + for (Value operand : op.getOperands()) { + if (auto blockArg = dyn_cast(operand)) { + // Depends on a block argument + size_t argIndex = blockArg.getArgNumber(); + if (argIndex < blockArgNodes.size()) { + OpNode *argNode = blockArgNodes[argIndex]; + currentNode->dependencies.push_back(argNode); + argNode->dependents.push_back(currentNode); + } + } else if (Operation *definingOp = operand.getDefiningOp()) { + // Depends on another operation + if (opToNode.count(definingOp)) { + OpNode *depNode = opToNode[definingOp]; + currentNode->dependencies.push_back(depNode); + depNode->dependents.push_back(currentNode); + } + } + } + } + } + } + + // Get nodes in topological order (dependencies first) + SmallVector getTopologicalOrder() const { + SmallVector result; + DenseSet visited; + + std::function dfs = [&](OpNode* node) { + if (visited.contains(node)) return; + visited.insert(node); + + // Visit all dependencies first + for (OpNode* dep : node->dependencies) { + dfs(dep); + } + + result.push_back(node); + }; + + // Start DFS from all nodes + for (const auto &node : nodes) { + dfs(node.get()); + } + + return result; + } +}; + +// Enhanced region equivalence check using dependency graphs +bool areRegionsEquivalent(Region &first, Region &second, DenseMap &nodeMapping, + DenseMap &operationMapping) { + // Clear the output mappings + nodeMapping.clear(); + operationMapping.clear(); + + // Fast early checks before expensive graph construction + + // Check number of blocks + if (first.getBlocks().size() != second.getBlocks().size()) { + return false; + } + + // Check each block's basic properties + for (auto blockPair : llvm::zip(first.getBlocks(), second.getBlocks())) { + Block &firstBlock = std::get<0>(blockPair); + Block &secondBlock = std::get<1>(blockPair); + + // Check number of arguments + if (firstBlock.getNumArguments() != secondBlock.getNumArguments()) { + return false; + } + + // Check argument types + for (auto argPair : llvm::zip(firstBlock.getArguments(), secondBlock.getArguments())) { + if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType()) { + return false; + } + } + + // Check number of operations + if (firstBlock.getOperations().size() != secondBlock.getOperations().size()) { + return false; + } + } + 
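+  // Illustrative example of what the graph-based comparison below establishes
+  // (hypothetical region bodies, not from a test): two regions that both contain
+  //   %0 = arith.mulf %in0, %in1 : f32
+  //   %1 = arith.addf %0, %out : f32
+  //   linalg.yield %1 : f32
+  // compare equal when per-node op names, operand/result types, and dependency
+  // edges (block args -> mulf -> addf -> yield) all line up; nodeMapping and
+  // operationMapping then record the resulting node and operation pairing.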
+ // If basic checks pass, proceed with detailed graph-based analysis + // Build dependency graphs for both regions + DependencyGraph firstGraph, secondGraph; + firstGraph.buildFromRegion(first); + secondGraph.buildFromRegion(second); + + // Quick structural checks + if (firstGraph.nodes.size() != secondGraph.nodes.size()) { + return false; + } + + if (firstGraph.blockArgNodes.size() != secondGraph.blockArgNodes.size()) { + return false; + } + + // Get topological orderings + auto firstOrder = firstGraph.getTopologicalOrder(); + auto secondOrder = secondGraph.getTopologicalOrder(); + + if (firstOrder.size() != secondOrder.size()) { + return false; + } + + // Compare nodes in topological order and build mapping + for (size_t i = 0; i < firstOrder.size(); ++i) { + OpNode *firstNode = firstOrder[i]; + OpNode *secondNode = secondOrder[i]; + + // Check if the nodes are structurally equivalent + if (!firstNode->isEquivalentTo(*secondNode)) { + return false; + } + + // Check if dependency structure matches + if (firstNode->dependencies.size() != secondNode->dependencies.size()) { + return false; + } + + // Verify that dependencies map correctly + for (size_t j = 0; j < firstNode->dependencies.size(); ++j) { + OpNode *firstDep = firstNode->dependencies[j]; + OpNode *secondDep = secondNode->dependencies[j]; + + // Check if we've established a mapping for these dependencies + auto it = nodeMapping.find(firstDep); + if (it != nodeMapping.end()) { + if (it->second != secondDep) { + return false; // Inconsistent mapping + } + } else { + nodeMapping[firstDep] = secondDep; + } + } + + // Establish mapping for current nodes + nodeMapping[firstNode] = secondNode; + + // Build the operation mapping directly from OpNode data while still valid + if (firstNode->op && secondNode->op) { + operationMapping[firstNode->op] = secondNode->op; + } + } + + return true; +} + +// Helper to check if indexing maps are equivalent +bool areIndexingMapsEquivalent(ArrayAttr firstMaps, ArrayAttr secondMaps) { + if (firstMaps.size() != secondMaps.size()) + return false; + + for (auto mapPair : llvm::zip(firstMaps, secondMaps)) { + auto firstMap = std::get<0>(mapPair).cast().getValue(); + auto secondMap = std::get<1>(mapPair).cast().getValue(); + + if (firstMap != secondMap) + return false; + } + + return true; +} + +// Helper to check if iterator types are equivalent +bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) { + if (firstTypes.size() != secondTypes.size()) + return false; + + for (auto typePair : llvm::zip(firstTypes, secondTypes)) { + auto firstType = std::get<0>(typePair).cast().getValue(); + auto secondType = std::get<1>(typePair).cast().getValue(); + + if (firstType != secondType) + return false; + } + + return true; +} + +// Helper function to find the corresponding value in actual IR for a kernel block argument +Value findCorrespondingValue(BlockArgument kernelArg, + const DenseMap &operationMapping, + GenericOp genericOp) { + + LLVM_DEBUG(llvm::dbgs() << "Finding corresponding value for kernel arg #" << kernelArg.getArgNumber() + << " with type " << kernelArg.getType() << "\n"); + + // First, check if this kernel argument is used as an operand to the linalg.generic itself + // This handles tensor arguments that become ins/outs operands + for (Operation *kernelUser : kernelArg.getUsers()) { + LLVM_DEBUG(llvm::dbgs() << "Kernel arg used by: " << *kernelUser << "\n"); + + // Check if the user is a linalg.generic operation + if (auto kernelGeneric = dyn_cast(kernelUser)) { + 
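+      // Position-based matching, illustrated with a hypothetical defn: if the
+      // defn's body is `linalg.generic ins(%k0, %k1 : ...) outs(%k2 : ...)` and
+      // kernelArg is %k1, it sits at operand index 1 of that generic, so the
+      // value returned is operand index 1 of the matched linalg.generic in the
+      // actual IR.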
LLVM_DEBUG(llvm::dbgs() << "Kernel arg is used by linalg.generic as operand\n"); + + // Find which operand position kernelArg occupies in the kernel's linalg.generic + size_t operandIndex = 0; + for (Value operand : kernelGeneric->getOperands()) { + if (operand == kernelArg) { + LLVM_DEBUG(llvm::dbgs() << "Kernel arg is at operand index " << operandIndex + << " of kernel linalg.generic\n"); + + // The corresponding operand in the actual linalg.generic should be at the same position + if (operandIndex < genericOp->getNumOperands()) { + Value actualOperand = genericOp->getOperand(operandIndex); + LLVM_DEBUG(llvm::dbgs() << "Found corresponding actual operand: " << actualOperand << "\n"); + return actualOperand; + } else { + LLVM_DEBUG(llvm::dbgs() << "ERROR - operand index out of bounds in actual generic\n"); + } + break; + } + operandIndex++; + } + + // If we found a linalg.generic usage, we're done with this user + break; + } + } + + // If we reach here, this might be a scalar argument used inside the region + // For scalar arguments like %arg3, %arg4, use operation mapping to trace usage + LLVM_DEBUG(llvm::dbgs() << "Checking if kernel arg is a scalar used inside region\n"); + + for (Operation *kernelUser : kernelArg.getUsers()) { + // Skip if this is the linalg.generic itself (already handled above) + if (isa(kernelUser)) continue; + + LLVM_DEBUG(llvm::dbgs() << "Kernel arg used by operation: " << *kernelUser << "\n"); + + // Find the corresponding operation in actual IR using the fixed mapping + // Note: operationMapping is actualOp -> kernelOp, so we need to reverse-search + auto it = std::find_if(operationMapping.begin(), operationMapping.end(), + [kernelUser](const auto& pair) { + return pair.second == kernelUser; + }); + if (it != operationMapping.end()) { + Operation *actualUser = it->first; // The actual IR operation + LLVM_DEBUG(llvm::dbgs() << "Found corresponding actual operation: " << *actualUser << "\n"); + + // Find which operand position kernelArg occupies in kernelUser + size_t operandIndex = 0; + for (Value operand : kernelUser->getOperands()) { + if (operand == kernelArg) { + LLVM_DEBUG(llvm::dbgs() << "Kernel arg is at operand index " << operandIndex << "\n"); + + // Get the corresponding operand from actual IR + if (operandIndex < actualUser->getNumOperands()) { + Value actualOperand = actualUser->getOperand(operandIndex); + LLVM_DEBUG(llvm::dbgs() << "Found corresponding actual operand: " << actualOperand << "\n"); + return actualOperand; + } else { + LLVM_DEBUG(llvm::dbgs() << "ERROR - operand index out of bounds\n"); + } + break; + } + operandIndex++; + } + } else { + LLVM_DEBUG(llvm::dbgs() << "Could not find corresponding operation in operationMapping\n"); + } + } + + // Fallback: if operation mapping fails, try type matching as last resort + LLVM_DEBUG(llvm::dbgs() << "Fallback to type matching for function arguments\n"); + + auto func = genericOp->getParentOfType(); + if (func) { + LLVM_DEBUG(llvm::dbgs() << "Found parent function with " << func.getNumArguments() << " arguments\n"); + + // Look for function arguments with matching type + for (auto funcArg : func.getArguments()) { + if (funcArg.getType() == kernelArg.getType()) { + LLVM_DEBUG(llvm::dbgs() << "Found function argument with matching type: " << funcArg << "\n"); + // TODO: This is still not ideal - should be improved with better analysis + return funcArg; + } + } + } + + LLVM_DEBUG(llvm::dbgs() << "ERROR - Could not find corresponding value for kernel arg\n"); + return nullptr; +} + +// Structure to 
hold the result of matching a generic operation with a kernel definition +struct KernelMatchResult { + StringRef kernelName; + DenseMap operationMapping; // actual op -> kernel op + kernel::DefnOp matchedDefnOp; +}; + +// Check if a linalg.generic operation matches a kernel.defn in a collection +FailureOr matchGenericWithDefn( + GenericOp genericOp, + kernel::DefnCollectionOp collectionOp) { + + // Get attributes from the generic operation + ArrayAttr indexingMaps = genericOp.getIndexingMapsAttr(); + ArrayAttr iteratorTypes = genericOp.getIteratorTypesAttr(); + unsigned numInputs = genericOp.getNumDpsInputs(); + unsigned numOutputs = genericOp.getNumDpsInits(); + + // Variables to capture the match result + StringRef matchedOpName; + DenseMap matchedOperationMapping; + kernel::DefnOp matchedDefnOp; + + SmallVector defnOps; + + //llvm::errs() << "DEBUG: kernel.defn_collection contents:\n"; + //llvm::errs() << collectionOp; + //llvm::errs() << collectionOp.getOperation(); + //llvm::errs() << "\n"; + collectionOp.walk([&](kernel::DefnOp defnOp) { + defnOps.push_back(defnOp); + }); + + bool foundMatch = false; + + // Walk through each defn in the collection + for (auto defnOp : defnOps) { + + StringRef opName = defnOp.getSymName(); + LLVM_DEBUG(llvm::dbgs() << "Checking kernel defn: " << opName << "\n"); + + // Check for linalg.generic in the defn's body + GenericOp candidateOp; + + defnOp.walk([&](GenericOp genericOp) { + candidateOp = genericOp; //TODO: Add checks to make sure there is only single linalg.generic in the defn + }); + + if(!candidateOp) { + LLVM_DEBUG(llvm::dbgs() << "No linalg.generic found in defn " << opName << "\n"); + continue; + } + + LLVM_DEBUG(llvm::dbgs() << "Found linalg.generic in defn " << opName << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Candidate numInputs=" << candidateOp.getNumDpsInputs() + << ", target numInputs=" << numInputs << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Candidate numOutputs=" << candidateOp.getNumDpsInits() + << ", target numOutputs=" << numOutputs << "\n"); + + // Check if this linalg.generic matches our target + DenseMap nodeMapping; + DenseMap operationMapping; // Added for findCorrespondingValue + if (candidateOp.getNumDpsInputs() == numInputs && + candidateOp.getNumDpsInits() == numOutputs && + areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) && + areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) && + areRegionsEquivalent(genericOp.getRegion(), candidateOp.getRegion(), nodeMapping, operationMapping)) { + LLVM_DEBUG(llvm::dbgs() << "MATCH FOUND for defn " << opName << "\n"); + foundMatch = true; + matchedOpName = opName; + matchedOperationMapping = operationMapping; // Store the operation mapping + matchedDefnOp = defnOp; // Store the matched defnOp + } else { + LLVM_DEBUG(llvm::dbgs() << "No match for defn " << opName << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Input/output check: " + << (candidateOp.getNumDpsInputs() == numInputs) << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Maps check: " + << areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Iterator types check: " + << areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Regions check: " + << areRegionsEquivalent(genericOp.getRegion(), candidateOp.getRegion(), nodeMapping, operationMapping) << "\n"); + } + + if (foundMatch) { + return KernelMatchResult{matchedOpName, matchedOperationMapping, matchedDefnOp}; + } + } + + 
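  // For reference, a hypothetical library entry this loop is meant to match
  // (the exact kernel-dialect syntax and the @cublas_sgemm name are
  // assumptions for illustration, not taken from this patch):
  //
  //   kernel.defn_collection {
  //     kernel.defn @cublas_sgemm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
  //                               %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  //       %0 = linalg.generic { /* matmul indexing maps, iterator types */ }
  //              ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
  //              outs(%C : tensor<?x?xf32>) { /* mul/add body */ }
  //              -> tensor<?x?xf32>
  //       /* ... */
  //     }
  //   }
  //
  // A user linalg.generic with the same operand counts, indexing maps,
  // iterator types, and body structure is reported as a match and later
  // replaced by a kernel.launch of the defn's symbol name.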
return failure(); +} + +// Rewrite pattern to convert linalg.generic to kernel ops +class LinalgGenericToKernelPattern : public OpRewritePattern { +public: + LinalgGenericToKernelPattern(MLIRContext *context, + kernel::DefnCollectionOp collectionOp) + : OpRewritePattern(context), collectionOp(collectionOp) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + + LLVM_DEBUG(llvm::dbgs() << "matchAndRewrite called for genericOp:\n"); + LLVM_DEBUG(llvm::dbgs() << genericOp << "\n"); + + auto module = genericOp->getParentOfType(); + //Check if the parent of the generic op is a kernel.defn + if (auto parentOp = genericOp->getParentOp()) { + if (isa(parentOp)) { + LLVM_DEBUG(llvm::dbgs() << "Skipping genericOp inside kernel.defn\n"); + return failure(); + } + } + + // Try to match with a defn in the collection + auto matchResult = matchGenericWithDefn(genericOp, collectionOp); + if (failed(matchResult)) { + LLVM_DEBUG(llvm::dbgs() << "No match found in collection\n"); + return failure(); + } + + StringRef opName = matchResult->kernelName; + LLVM_DEBUG(llvm::dbgs() << "Match found with kernel: " << opName << "\n"); + + // Find the matched kernel.defn operation + kernel::DefnOp matchedDefnOp = matchResult->matchedDefnOp; + + if (!matchedDefnOp) { + return failure(); + } + + // Check if the kernel.defn already exists in the target module + kernel::DefnOp existingDefn; + module.walk([&](kernel::DefnOp defnOp) { + if (defnOp.getSymName() == opName) { + // Check if this defn is inside a defn_collection (template) or at module level (callable) + if (!defnOp->getParentOfType()) { + existingDefn = defnOp; + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + + // If the kernel.defn doesn't exist in the module, copy it + if (!existingDefn) { + // Clone the matched kernel.defn operation + rewriter.setInsertionPointToStart(module.getBody()); + auto clonedDefn = rewriter.clone(*matchedDefnOp.getOperation()); + (void)clonedDefn; // Suppress unused variable warning + } + + // Create kernel.launch operation to replace the genericOp + Location loc = genericOp.getLoc(); + + // Set insertion point to the genericOp location + rewriter.setInsertionPoint(genericOp); + + // Get the kernel function signature to map all arguments + Block &kernelBlock = matchedDefnOp.getRegion().front(); + auto kernelArgs = kernelBlock.getArguments(); + + // Use the operationMapping from the match result (no need to call areRegionsEquivalent again) + const DenseMap &operationMapping = matchResult->operationMapping; + + // Use unified approach: map ALL kernel arguments to their corresponding actual values + SmallVector operands; + LLVM_DEBUG(llvm::dbgs() << "Starting to map " << kernelArgs.size() << " kernel arguments\n"); + + for (BlockArgument kernelArg : kernelArgs) { + Value actualValue = findCorrespondingValue(kernelArg, operationMapping, genericOp); + if (!actualValue) { + LLVM_DEBUG(llvm::dbgs() << "Failed to find corresponding value for kernel arg #" + << kernelArg.getArgNumber() << " - returning failure\n"); + return failure(); // Could not find corresponding value + } + operands.push_back(actualValue); + } + + LLVM_DEBUG(llvm::dbgs() << "Successfully mapped all kernel arguments, creating kernel.launch\n"); + + // Get kernel function signature types for casting + auto kernelFuncType = matchedDefnOp.getFunctionType(); + auto kernelInputTypes = kernelFuncType.getInputs(); + auto kernelResultTypes = kernelFuncType.getResults(); + + // Cast operands to 
match kernel signature types if needed + SmallVector castedOperands; + for (size_t i = 0; i < operands.size(); ++i) { + Value operand = operands[i]; + Type expectedType = (i < kernelInputTypes.size()) ? kernelInputTypes[i] : operand.getType(); + + if (operand.getType() != expectedType) { + // Insert tensor.cast for type conversion + if (isa(operand.getType()) && isa(expectedType)) { + LLVM_DEBUG(llvm::dbgs() << "Casting operand " << i << " from " << operand.getType() + << " to " << expectedType << "\n"); + auto castOp = rewriter.create(loc, expectedType, operand); + castedOperands.push_back(castOp.getResult()); + } else { + // For non-tensor types, use the operand as-is + castedOperands.push_back(operand); + } + } else { + castedOperands.push_back(operand); + } + } + + // Get result types from the generic operation + TypeRange originalResultTypes = genericOp.getResultTypes(); + + // Create the kernel.launch operation with casted operands and kernel result types + auto launchOp = rewriter.create( + loc, + kernelResultTypes, // Use kernel result types for the launch op + opName, + castedOperands // Use casted operands + ); + + // Cast results back to original types if needed + SmallVector finalResults; + for (size_t i = 0; i < launchOp.getResults().size(); ++i) { + Value result = launchOp.getResult(i); + Type originalType = (i < originalResultTypes.size()) ? originalResultTypes[i] : result.getType(); + + if (result.getType() != originalType) { + // Insert tensor.cast to convert back to original type + if (isa(result.getType()) && isa(originalType)) { + LLVM_DEBUG(llvm::dbgs() << "Casting result " << i << " from " << result.getType() + << " to " << originalType << "\n"); + auto castOp = rewriter.create(loc, originalType, result); + finalResults.push_back(castOp.getResult()); + } else { + finalResults.push_back(result); + } + } else { + finalResults.push_back(result); + } + } + + // Replace the generic operation with the final results + rewriter.replaceOp(genericOp, finalResults); + + return success(); + } + +private: + kernel::DefnCollectionOp collectionOp; +}; + +// Pass to apply the rewrite pattern +struct LinalgToKernelPass : public LinalgToKernelBase { + using LinalgToKernelBase::LinalgToKernelBase; + + // Constructor that allows setting the kernel library path + LinalgToKernelPass() = default; + LinalgToKernelPass(const std::string& libraryPath) : externalLibraryPath(libraryPath) {} + + void runOnOperation() override { + ModuleOp module = getOperation(); + + kernel::DefnCollectionOp collectionOp = nullptr; + OwningOpRef externalModule; + // Determine which path to use for kernel library + std::string effectiveLibraryPath = externalLibraryPath; + // If no external path was provided via constructor, try the command line option + if (effectiveLibraryPath.empty()) { + effectiveLibraryPath = std::string(kernelLibraryPath); + } + + //// Debug output + //llvm::errs() << "DEBUG: externalLibraryPath = '" << externalLibraryPath << "'\n"; + //llvm::errs() << "DEBUG: kernelLibraryPath = '" << std::string(kernelLibraryPath) << "'\n"; + //llvm::errs() << "DEBUG: effectiveLibraryPath = '" << effectiveLibraryPath << "'\n"; + + // Check if we should load kernel definitions from an external file + if (!effectiveLibraryPath.empty()) { + //llvm::errs() << "DEBUG: Loading kernel definitions from external file: " << effectiveLibraryPath << "\n"; + // Load kernel definitions from external file + std::string errorMessage; + auto memoryBuffer = mlir::openInputFile(effectiveLibraryPath, &errorMessage); + if 
(!memoryBuffer) { + module.emitError("Failed to open kernel library file: ") << effectiveLibraryPath + << " - " << errorMessage; + return signalPassFailure(); + } + + // Parse the external file + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc()); + + externalModule = mlir::parseSourceFile(sourceMgr, &getContext()); + if (!externalModule) { + module.emitError("Failed to parse kernel library file: ") << effectiveLibraryPath; + return signalPassFailure(); + } + + // Debug: Print the loaded external module + //llvm::errs() << "DEBUG: Successfully loaded external module:\n"; + //externalModule->print(llvm::errs()); + //llvm::errs() << "\n"; + + // Find the kernel.defn_collection in the external module + externalModule->walk([&](kernel::DefnCollectionOp op) { + collectionOp = op; + LLVM_DEBUG(llvm::dbgs() << "Found kernel.defn_collection in external module\n"); + return WalkResult::interrupt(); + }); + + if (!collectionOp) { + module.emitError("No kernel.defn_collection found in external kernel library: ") + << effectiveLibraryPath; + return signalPassFailure(); + } + + // Debug: Print the found collection + //llvm::errs() << "DEBUG: kernel.defn_collection contents:\n"; + //llvm::errs() << collectionOp; + //llvm::errs() << collectionOp.getOperation(); + //llvm::errs() << "\n"; + } else { + // Find the kernel.defn_collection in the current module (original behavior) + module.walk([&](kernel::DefnCollectionOp op) { + collectionOp = op; + return WalkResult::interrupt(); + }); + + if (!collectionOp) { + module.emitError("No kernel.defn_collection found in module. " + "Either include one in the input module or specify " + "--kernel-library-path to load from external file."); + return signalPassFailure(); + } + } + + // Apply the rewrite pattern + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), collectionOp); + + //llvm::errs() << "DEBUG: kernel.defn_collection contents:\n"; + //llvm::errs() << collectionOp.getOperation(); + //llvm::errs() << "\n"; + //llvm::errs() << collectionOp; + //llvm::errs() << "\n"; + + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) + return signalPassFailure(); + } + +private: + std::string externalLibraryPath; +}; + +} // namespace + +namespace mlir::polygeist { + +// Create a pass to convert linalg.generic to kernel +std::unique_ptr createLinalgToKernelPass() { + return std::make_unique(); +} + +// Create a pass to convert linalg.generic to kernel with kernel library path +std::unique_ptr createLinalgToKernelPass(const std::string& kernelLibraryPath) { + return std::make_unique(kernelLibraryPath); +} + +} // namespace mlir::polygeist \ No newline at end of file diff --git a/lib/polygeist/Passes/RaiseToLinalg.cpp b/lib/polygeist/Passes/RaiseToLinalg.cpp index 254d3a11881b..7182cb4b0cca 100644 --- a/lib/polygeist/Passes/RaiseToLinalg.cpp +++ b/lib/polygeist/Passes/RaiseToLinalg.cpp @@ -1,21 +1,24 @@ #include "PassDetails.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/Dominance.h" #include 
"mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" #include "polygeist/Passes/Passes.h" #include "llvm/Support/Debug.h" -#include "mlir/IR/AffineExpr.h" #define DEBUG_TYPE "raise-to-linalg" @@ -23,175 +26,642 @@ using namespace mlir; using namespace mlir::arith; using namespace polygeist; using namespace affine; +using namespace linalg; -namespace { -struct RaiseAffineToLinalg : public AffineRaiseToLinalgBase { - void runOnOperation() override; -}; -} // namespace - -// Also want to add support for affine.for ( ) { linalg.generic } -> bigger linalg.generic -// Also probably want to try to do { linalg.generc1(); linalg.generic2(); } -> bigger linalg.generic() +// Also want to add support for affine.for ( ) { linalg.generic } -> bigger +// linalg.generic Also probably want to try to do { linalg.generc1(); +// linalg.generic2(); } -> bigger linalg.generic() /* affine.for() { affine.for() { - } + } affine.for() { } } */ struct Condition { - bool ifTrue; - AffineIfOp op; - Condition(bool ifTrue, AffineIfOp op) : ifTrue(ifTrue), op(op) {} + bool ifTrue; + AffineIfOp op; + Condition(bool ifTrue, AffineIfOp op) : ifTrue(ifTrue), op(op) {} }; bool isLinearInIndex(AffineExpr expr, size_t idx) { - if (!expr.isFunctionOfDim(idx)) { - return true; - } + if (!expr.isFunctionOfDim(idx)) { + return true; + } - if (expr.getKind() == AffineExprKind::DimId) { - return true; - } + if (expr.getKind() == AffineExprKind::DimId) { + return true; + } - if (expr.getKind() == AffineExprKind::Add) { - auto binop = expr.cast(); - return isLinearInIndex(binop.getLHS(), idx) && isLinearInIndex(binop.getRHS(), idx); - } - if (expr.getKind() == AffineExprKind::Mul) { - auto binop = expr.cast(); - return (isLinearInIndex(binop.getLHS(), idx) && !binop.getRHS().isFunctionOfDim(idx)) || - (isLinearInIndex(binop.getRHS(), idx) && !binop.getLHS().isFunctionOfDim(idx)); - } + if (expr.getKind() == AffineExprKind::Add) { + auto binop = expr.cast(); + return isLinearInIndex(binop.getLHS(), idx) && + isLinearInIndex(binop.getRHS(), idx); + } + if (expr.getKind() == AffineExprKind::Mul) { + auto binop = expr.cast(); + return (isLinearInIndex(binop.getLHS(), idx) && + !binop.getRHS().isFunctionOfDim(idx)) || + (isLinearInIndex(binop.getRHS(), idx) && + !binop.getLHS().isFunctionOfDim(idx)); + } - return false; + return false; } bool isLinearInIndex(AffineMap map, size_t idx) { - for (auto expr : map.getResults()) { - if (!isLinearInIndex(expr, idx)) - return false; + for (auto expr : map.getResults()) { + if (!isLinearInIndex(expr, idx)) + return false; + } + return true; +} + +AffineExpr shiftDimsDown1(AffineExpr expr, unsigned numDims, unsigned offset) { + SmallVector dims; + for (unsigned idx = 0; idx < offset; ++idx) + dims.push_back(getAffineDimExpr(idx, expr.getContext())); + for (unsigned idx = offset; idx < numDims; ++idx) + dims.push_back(getAffineDimExpr(idx - 1, expr.getContext())); + return expr.replaceDimsAndSymbols(dims, {}); +} + +// This is reducing the number of input dims in expression by 1 +AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, unsigned offset) { + assert(offset <= expr.getNumDims()); + return AffineMap::get(expr.getNumDims() - 1, expr.getNumSymbols(), + llvm::map_to_vector<4>(expr.getResults(), + [&](AffineExpr e) { + return shiftDimsDown1( + e, expr.getNumDims(), + offset); + }), + expr.getContext()); +} + +// Helper 
function to check if an operation dominates the target region +bool dominatesTarget(Operation* op, Region* targetRegion) { + return op->getParentRegion()->isAncestor(targetRegion); +} + +Value recursiveCloneWithDominanceCheck( + OpBuilder& builder, + Value value, + Region* targetRegion, + IRMapping& mapping, + DenseSet& processedOps) { + + // If value is already mapped, return the mapped value + if (mapping.contains(value)) { + return mapping.lookup(value); } - return true; + + // Handle block arguments + if (auto blockArg = dyn_cast(value)) { + if (blockArg.getParentBlock()->getParent()->isAncestor(targetRegion)) { + mapping.map(value, value); + return value; + } else { + llvm::errs() << "Non-dominating block argument encountered\n"; + return nullptr; + } + } + + Operation* defOp = value.getDefiningOp(); + if (!defOp) { + return value; + } + + // Check if this operation dominates the target region + if (dominatesTarget(defOp, targetRegion)) { + // Operation dominates, use it directly + mapping.map(value, value); + return value; + } + + // Avoid processing the same operation multiple times + if (processedOps.contains(defOp)) { + // Operation was already processed, should be in mapping + auto resultNum = cast(value).getResultNumber(); + auto mappedOp = mapping.lookup(defOp->getResult(0)).getDefiningOp(); + auto clonedValue = mappedOp->getResult(resultNum); + mapping.map(value, clonedValue); + return clonedValue; + } + + // Check if operation is safe to clone + if (!isReadOnly(defOp)) { + llvm::errs() << "Cannot clone non-read-only operation: " << *defOp << "\n"; + return nullptr; + } + + processedOps.insert(defOp); + + // Recursively process ALL operands first to populate the mapping + for (Value operand : defOp->getOperands()) { + Value clonedOperand = recursiveCloneWithDominanceCheck( + builder, operand, targetRegion, mapping, processedOps); + if (!clonedOperand) { + return nullptr; + } + // clonedOperand is automatically added to mapping by recursive call + } + + // Now clone the operation using the populated mapping + Operation* clonedOp = builder.clone(*defOp, mapping); + + // The clone automatically maps all results, so we can just return what we need + auto resultNum = cast(value).getResultNumber(); + return clonedOp->getResult(resultNum); +} + +// Check if the affine apply is a constant and return the constant value +std::optional getConstantFromAffineApply(AffineApplyOp applyOp) { + AffineMap map = applyOp.getAffineMap(); + + // Must have no dimensions and no symbols + if (map.getNumDims() != 0 || map.getNumSymbols() != 0) { + return std::nullopt; + } + + // Must have exactly one result that is a constant + if (map.getNumResults() != 1) { + return std::nullopt; + } + + // Check if the single result is a constant expression + AffineExpr result = map.getResult(0); + if (auto constExpr = result.dyn_cast()) { + return constExpr.getValue(); + } + + return std::nullopt; } - AffineExpr shiftDimsDown1(AffineExpr expr, unsigned numDims, - unsigned offset) { - SmallVector dims; - for (unsigned idx = 0; idx < offset; ++idx) - dims.push_back(getAffineDimExpr(idx, expr.getContext())); - for (unsigned idx = offset; idx < numDims; ++idx) - dims.push_back(getAffineDimExpr(idx - 1, expr.getContext())); - return expr.replaceDimsAndSymbols(dims, {}); - } - -//This is reducing the number of input dims in expression by 1 - AffineMap shiftDimsDown1(AffineMap expr, unsigned numDim, - unsigned offset) { - assert(offset <= expr.getNumDims()); - return AffineMap::get(expr.getNumDims() - 1, 
expr.getNumSymbols(), - llvm::map_to_vector<4>( - expr.getResults(), - [&](AffineExpr e) { - return shiftDimsDown1(e, expr.getNumDims(), offset); - }), - expr.getContext()); - } - -// Given an affine map `oldmap`, memref `val`, and corresponding input values (which are a list of indicies, then symbols), -// and a loop index `ind` produce the following: -// 1. A (potentially new) memref value `newval` which does not have any dependence on `ind` +// Given an affine map `oldmap`, memref `val`, and corresponding input values +// (which are a list of indicies, then symbols), and a set of loop indices +// `indices` produce the following: +// 1. A (potentially new) memref value `newval` which does not have any +// dependence on `indices` // and -// 2. an affine map `newmap` which takes a single index (`ind`) and produces indices into `newval` such that -// indexing `newval[map(ind)]` produces the same result as indexing the original map. -std::pair remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, Value val, Value idx, Value idx_size, int loopLowerBound, int loopStepSize, mlir::OperandRange vals) { - // First we need to remove any dependence on the loop index from the affine map - SmallVector vals_without_idx; - ssize_t dim_idx = -1; - //To check if induction variable of for loop in an operand of this op (load/store) - for (auto &&[i, v] : llvm::enumerate(vals)) { - if (v == idx) { - // Offset we're replacing must be an index (not a symbol). - // If we guarantee to run AffineCFG first, this should always be true. - assert(i < oldmap.getNumDims()); - // There should only be one use of the index. - assert(dim_idx == -1); - dim_idx = i; - continue; +// 2. an affine map `newmap` which takes size(indices) values (`indices`) and +// produces indices into `newval` such that +// indexing `newval[map(indices)]` produces the same result as indexing the +// original map. +// check_reduction is set true, when passed from store/linalg.generic's output +// variable. And it is returned true, only if index was not encountered in +// oldmap operands and check_reduction was set true. +Value remap_in_affine_dim(bool &legal, OpBuilder &builder, AffineMap oldmap, + Value memref_val, Value index, Value bound, AffineApplyOp lower_bound, + int firstNDims, ValueRange oldmap_operands, + Value origmemref, bool &check_reduction) { + + int lower_bound_val = getConstantFromAffineApply(lower_bound).value_or(0); + + assert(oldmap_operands.size() == + oldmap.getNumSymbols() + oldmap.getNumDims()); + // Operands which don't correspond to indices + SmallVector operands_without_indices; + ssize_t dimidx = -1; + for (auto [i, v] : llvm::enumerate(oldmap_operands)) { + if (v == nullptr) { + assert(i < firstNDims); + continue; + } + assert(i >= firstNDims); + if (v != index) { + // Check if the symbol value is read-only or defined in a scope where it + // is always visible. 
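        // (Sketch with hypothetical values: an operand such as a memref.dim
        //  or arith.constant defined between the loop and the new insertion
        //  point is read-only, so it can simply be re-cloned where the new
        //  view is built; an operand produced by an op with side effects
        //  cannot be hoisted this way, and the rewrite is marked illegal.)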
+ if (auto ba = dyn_cast(v)) { + // check if it dominates the current scope + if (ba.getParentBlock()->getParent()->isAncestor( + builder.getBlock()->getParent())) + operands_without_indices.push_back(v); + else { + assert(false); + legal = false; + return nullptr; } - vals_without_idx.push_back(v); + } else { + auto op = v.getDefiningOp(); + // check if this dominates the current scope + if (op->getParentRegion()->isAncestor( + builder.getBlock()->getParent())) { + operands_without_indices.push_back(v); + } else if (isReadOnly(op)) { + // if not, check if it is readnone + // Technically this isn't quite sufficient yet, and does require that + // the operands to this op are also able to be hoisted, but for now we + // will assume this + auto op2 = builder.clone(*op); + operands_without_indices.push_back( + op2->getResult(cast(v).getResultNumber())); + } else { + // if so clone it in the right scope + // otherwise set illegal and don't continue + assert(false); + legal = false; + return nullptr; + } + } + } else + dimidx = i; + } + if ((dimidx == -1) && (check_reduction)) + check_reduction = true; + else + check_reduction = false; + + SmallVector dimReplacements; + size_t validSims = 0; + size_t validDims = 0; + for (int i = 0; i < oldmap.getNumDims(); i++) { + if (i < firstNDims) { + assert(i != dimidx); + dimReplacements.push_back(builder.getAffineDimExpr(validDims)); + validDims++; + } else if (i == dimidx) { + dimReplacements.push_back(builder.getAffineDimExpr(validDims) + builder.getAffineConstantExpr(lower_bound_val)); + validDims++; + } else { + // TODO: Why are we using symbol here instead of dim? + dimReplacements.push_back(builder.getAffineSymbolExpr(validSims)); + validSims++; } + } - if (dim_idx != -1 && !isLinearInIndex(oldmap, dim_idx)) { - legal = false; - return {val, oldmap}; + SmallVector symReplacements; + for (int i = 0; i < oldmap.getNumSymbols(); i++) { + if (i + oldmap.getNumDims() == dimidx) { + symReplacements.push_back(builder.getAffineDimExpr(validDims) + builder.getAffineConstantExpr(lower_bound_val)); + validDims++; + } else { + symReplacements.push_back(builder.getAffineSymbolExpr(validSims)); + validSims++; } + } + if (validSims != operands_without_indices.size()) { + llvm::errs() << " oldmap: " << oldmap << "\n"; + llvm::errs() << " dimidx=" << dimidx << "\n"; + llvm::errs() << " index: " << index << "\n"; + llvm::errs() << " oldmap_operands: size=" << oldmap_operands.size() + << "\n"; + for (auto op : oldmap_operands) { + if (op) { + llvm::errs() << " -" << op << " &" << op.getAsOpaquePointer() << "\n"; + } else { + llvm::errs() << " -" + << "null" + << " &nullptr\n"; + } + } + llvm::errs() << " validSims: " << validSims << "\n"; + llvm::errs() << " operands_without_indices: size=" + << operands_without_indices.size() << "\n"; + for (auto op : operands_without_indices) { + llvm::errs() << " -" << op << " &" << op.getAsOpaquePointer() << "\n"; + } + } + assert(validSims == operands_without_indices.size()); + auto map2 = oldmap.replaceDimsAndSymbols(dimReplacements, symReplacements, + firstNDims + 1/*Number of dims in new map*/, + operands_without_indices.size() /*Number of symbols in new map*/); + + SmallVector idx_sizes; + for (size_t i = 0; i < firstNDims; i++) { + // memref.dimOp captures the size of the memref + if (auto submap = origmemref.getDefiningOp()) + idx_sizes.push_back(submap.getSizes()[i]); + else + llvm_unreachable("Won't reach this case"); + // idx_sizes.push_back(builder.create(origmemref.getLoc(), + // origmemref, i)); + } + 
idx_sizes.push_back(bound); + + legal = true; + SmallVector sizes(idx_sizes.size(), mlir::ShapedType::kDynamic); + for (auto sz : idx_sizes) { + DenseSet processedOps; + IRMapping mapping; + auto clonedOp = recursiveCloneWithDominanceCheck(builder, sz, builder.getBlock()->getParent(), mapping, processedOps); + if (!clonedOp) { + legal = false; + return nullptr; + } + operands_without_indices.push_back(clonedOp); + } + //for (auto sz : idx_sizes) { + // // Check if the symbol value is read-only or defined in a scope where it is + // // always visible. + // if (auto ba = dyn_cast(sz)) { + // // check if it dominates the current scope + // if (ba.getParentBlock()->getParent()->isAncestor( + // builder.getBlock()->getParent())) + // operands_without_indices.push_back(sz); + // else { + // llvm::errs() << " value is a non-dominating block arg: " << sz << "\n"; + // legal = false; + // assert(false); + // return nullptr; + // } + // } else { + // auto op = sz.getDefiningOp(); + // // check if this dominates the current scope + // if (op->getParentRegion()->isAncestor(builder.getBlock()->getParent())) { + // operands_without_indices.push_back(sz); + // } else if (isReadOnly(op)) { + // // if not, check if it is readnone + // // Technically this isn't quite sufficient yet, and does require that + // // the operands to this op are also able to be hoisted, but for now we + // // will assume this + // // We need to clone the op along and check if it's operands are dominating or not, else do a recursive clone + // auto op2 = builder.clone(*op); + // operands_without_indices.push_back( + // op2->getResult(cast(sz).getResultNumber())); + // } else { + // llvm::errs() << " op is not readonly: " << *op << "\n"; + // // if so clone it in the right scope + // // otherwise set illegal and don't continue + // legal = false; + // assert(false); + // return nullptr; + // } + // } + //} + auto ty = MemRefType::get( + sizes, cast(memref_val.getType()).getElementType()); + + ////TODO: Can we have a case where stride is not 1? + //Value stride = builder.create(memref_val.getLoc(), 1); + + //// Create a subview op using lower bound, stride and size + //// Convert AffineApplyOp to its result Value and wrap in ValueRange + //Value lowerBoundValue = lower_bound.getResult(); + //auto subViewOp = builder.create( + // memref_val.getLoc(), // Location + // memref_val, // Source memref + // ValueRange{lowerBoundValue}, // Offsets (array) + // ValueRange{bound}, // Sizes (array) + // ValueRange{stride} // Strides (array) + //); + + //Value subview = subViewOp.getResult(); + + return builder.create( + memref_val.getLoc(), ty, memref_val, operands_without_indices, map2); +} - // Evaluate offsets as oldmap replacing idx with 0, and evaluating at the remaining variables +// store A[...] +// val = load A[...] 
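// Worked example (hypothetical; it assumes the builder.create above
// constructs polygeist's submap-style view op, whose exact type is elided in
// this hunk). For
//
//   affine.for %i = 2 to %N {
//     %v = affine.load %A[%i * 8 + %c] : memref<?xf32>
//   }
//
// the load's map is (d0)[s0] -> (d0 * 8 + s0) with d0 = %i and s0 = %c.
// remap_in_affine_dim replaces d0 by (d0 + 2), folding the constant lower
// bound into the access, keeps %c as a symbol operand, and builds a new view
// of %A with map (d0)[s0] -> ((d0 + 2) * 8 + s0) whose dynamic extent is the
// loop trip count %N - 2. The linalg.generic constructed later indexes that
// view with the identity map (d0) -> (d0).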
- //Instead of lower bound we are using 0 (assumption as the lower bound) - AffineMap offsetMap = oldmap; - if (dim_idx != -1) { - offsetMap = oldmap.replace(builder.getAffineDimExpr(dim_idx), builder.getAffineConstantExpr(loopLowerBound),offsetMap.getNumDims(), offsetMap.getNumSymbols()); - offsetMap = shiftDimsDown1(offsetMap, oldmap.getNumDims(), dim_idx); - } +/* prevA : + store A + val is now prevA +*/ - //Instead of using loop step we are using 1 (Assumption as the stride size) - AffineMap strideMap = oldmap; - if (dim_idx != -1) { - strideMap = oldmap.replace(builder.getAffineDimExpr(dim_idx), builder.getAffineConstantExpr(loopLowerBound + loopStepSize),strideMap.getNumDims(), strideMap.getNumSymbols()); - strideMap = shiftDimsDown1(strideMap, oldmap.getNumDims(), dim_idx); - } +/* - //Subtracting maps of stride and offset, gives you the offset value in the result of the map - { - SmallVector subtracts; - for (auto &&[lhs, rhs] : llvm::zip(strideMap.getResults(), offsetMap.getResults())) { - subtracts.push_back(lhs - rhs); - } - strideMap = AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), subtracts, builder.getContext()); +f(%memref ) + +%memref = ... + +affine.for { + + %inp = .. subview %memref [ ... ] + + linalg.generic %inp #map { + body() } +} - // Expression to index into the generated subview given the loop index - SmallVector loop_idxs; - // List of starting offsets into the subview - SmallVector offsets; - SmallVector sizes; - SmallVector strides; +-> - for (auto &&[expr, offset_expr, stride_expr] : llvm::zip(oldmap.getResults(), offsetMap.getResults(),strideMap.getResults() )) { - offsets.push_back(builder.create(val.getLoc(),AffineMap::get(offsetMap.getNumDims(), offsetMap.getNumSymbols(), offset_expr, builder.getContext()), vals_without_idx)); //What is there are symbols in the expression? - strides.push_back(builder.create(val.getLoc(),AffineMap::get(strideMap.getNumDims(), strideMap.getNumSymbols(), stride_expr, builder.getContext()), vals_without_idx)); //What is there are symbols in the expression? - if (!expr.isFunctionOfDim(dim_idx)) { - loop_idxs.push_back(builder.getAffineConstantExpr(0)); - sizes.push_back(builder.create(val.getLoc(), 1)); - } else { - loop_idxs.push_back(builder.getAffineDimExpr(0)); - sizes.push_back(idx_size); - } + +affine.for j { + + linalg.generic %memref #map2(j) { + body() } +} + + + + +#map2 = #map with the indexing done to %inp + - auto newval = builder.create(val.getLoc(), val, offsets, sizes, strides); - legal = true; - //Does this need fix? Here we are constraining to dims as 1 and symbols as 0, should it be, original - return {newval, AffineMap::get(/*dims*/1, /*symbols*/0, loop_idxs, builder.getContext())}; + + + +%memref = .. subview %memref_base [ ... ] + +linalg.generic %[[[memref]]] [[[[#map]]]]([[[[operands]]]]) { + body() } +-> + + +output_memref = memref_base +output_map = subvmap() + + compose +# uts are memref, map, and operands +# outputs are o +memref[map(operands)] ==== output_memref[output_map(output_operands)] + + + +bas= memref<40x40> + +B + +u + +tput_memref, output_map and output_operands +# possible intermediate is ... + +getLinalgArgMap(memref, map, operands to map [e.g. input symbols/dims]) + if memref is alloca/unknown/etc + return memref/map/operands + else + memref = subview memref_base[map2(operands2)] + + return memref_base and a new output_map such that + memref_base[output_map(output_operands)] === memref[map(operands)] + + + -// store A[...] -// val = load A[...] 
-/* prevA : - store A - val is now prevA */ +// Suppose we have a memref expression E=input[affine.map(operands)] +// if input = memref.subview A[starts, offsets] +// can we rewrite E as A[affine.map2(operands2)] +// We update lgMap and lgOperands in place with this coresponding map2 and +// operands2 +LogicalResult getLinalgArgMap(Operation *loop, Value &input, AffineMap &lgMap, + SmallVector &lgOperands) { + OpBuilder builder(loop->getContext()); + + while (Operation *defOp = input.getDefiningOp()) { + + assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); + // If the input is defined outside of the loop, we are finished. + if (!loop->isAncestor(defOp)) + break; + + if (auto SM = dyn_cast(defOp)) { + auto submap = SM.getMap(); + + // TODO: Do we achieve anything with this compose? + // As lgMap in our case is 1 to 1 identity map + auto composeMap = submap.compose(lgMap); + + SmallVector operands0; + + // First the dims + for (size_t i = 0; i < lgMap.getNumDims(); i++) + operands0.push_back(lgOperands[i]); + + // Then the symbols of submap + for (size_t i = 0; i < submap.getNumSymbols(); i++) + operands0.push_back(SM.getSymbols()[i]); + + // Then the symbols of lgMap + for (size_t i = 0; i < lgMap.getNumSymbols(); i++) + operands0.push_back(lgOperands[i + lgMap.getNumDims()]); + + lgMap = composeMap; + lgOperands = operands0; + input = SM.getMemref(); + assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); + continue; + } + + // if (auto SV = dyn_cast(defOp)) { + + // // TODO update map with the new indexing from here + + // // Create affine map + // // i. Track number of running dims and symbols + // // ii. shift dims and symbols to generate shifted expressions. + // // Extract corresponding operands + // // Use affineMap::get with numOperands and numSymbols along with shifted + // // expressions to get a map. Use affine map simplify to simplify this + + // SmallVector startExprs; + // SmallVector strideExprs; + // SmallVector dimOperands; + // SmallVector symOperands; + // for (auto &&[first, second] : llvm::zip(SV.getOffsets(), + // SV.getStrides())) { + // for (auto &&[index, val] : llvm::enumerate(SmallVector({first, + // second}))) { + // auto &exprOutput = (index == 0) ? 
startExprs : strideExprs; + // // Only support constants, symbols, or affine apply as offsets + // if (auto cop = val.getDefiningOp()) { + // exprOutput.push_back(builder.getAffineConstantExpr(cop.value())); + // continue; + // } else if (auto cop = val.getDefiningOp()) { + // exprOutput.push_back(builder.getAffineConstantExpr(cop.value())); + // continue; + // } + // if (auto ba = dyn_cast(val)) { + // Block *parentBlock = ba.getOwner(); + // if (isa(parentBlock->getParentOp())) { + // exprOutput.push_back( + // builder.getAffineDimExpr(dimOperands.size())); + // dimOperands.push_back(ba); + // continue; + + // } + // } + + // auto valOp = val.getDefiningOp(); + // // Defined outside loop, consider it a symbol [for now] + // //if (!valOp || loop->isAncestor(defOp)) { + // if (valOp&&!loop->isAncestor(defOp)) { + // exprOutput.push_back( + // builder.getAffineSymbolExpr(symOperands.size())); + // symOperands.push_back(val); + // continue; + // } + + // //TODO: Maybe it's a case to add, but are we sure we need it for + // starts and offsets + // // and not for operands + // if (auto apply = dyn_cast(valOp)) { + // auto map = apply.getAffineMap(); + // auto *scope = affine::getAffineScope(valOp)->getParentOp(); + // DominanceInfo DI(scope); + // auto map_operands = apply.getOperands(); + // //fully2ComposeAffineMapAndOperands(builder, &map, &map_operands, + // DI); + //// Instead of using loop step we are using 1 (Assumption as the stride + /// size) + // auto newexpr = map.shiftDims(dimOperands.size()) + // .shiftSymbols(symOperands.size()); + + // for (auto expr : newexpr.getResults()) { + // exprOutput.push_back(expr); + // } + + // for (size_t i = 0; i < map.getNumDims(); i++) + // dimOperands.push_back(apply.getOperands()[i]); + + // for (size_t i = 0; i < map.getNumSymbols(); i++) + // symOperands.push_back(apply.getOperands()[i + + // map.getNumDims()]); + + // continue; + // } + + // //return failure(); + // } + // } + + // SmallVector inputExprs; + // for (auto expr : lgMap.shiftDims(dimOperands.size()) + // .shiftSymbols(symOperands.size()).getResults()) { + // inputExprs.push_back(expr); + // } + // for (size_t i = 0; i < lgMap.getNumDims(); i++) + // dimOperands.push_back(lgOperands[i]); + + // for (size_t i = 0; i < lgMap.getNumSymbols(); i++) + // symOperands.push_back(lgOperands[i + lgMap.getNumDims()]); + + // SmallVector mergedExprs; + // for (auto && [start, stride, idx] : + // llvm::zip(startExprs, strideExprs, inputExprs)) { + // mergedExprs.push_back(start + idx * stride); + // } + + // lgMap = + // AffineMap::get(dimOperands.size(), symOperands.size(), mergedExprs, + // loop->getContext()); + // lgOperands.clear(); + // lgOperands.insert(lgOperands.begin(), dimOperands.begin(), + // dimOperands.end()); + // lgOperands.insert(lgOperands.begin()+lgOperands.size(), + // symOperands.begin(), symOperands.end()); input = SV.getSource(); break; + //} + + // return failure(); + } + assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); + return success(); +} struct AffineForOpRaising : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -199,260 +669,736 @@ struct AffineForOpRaising : public OpRewritePattern { LogicalResult matchAndRewrite(affine::AffineForOp loop, PatternRewriter &rewriter) const final { + auto module = loop->getParentOfType(); + // Don't handle accumulations in registers for the moment, we can have // a separate pattern move them into memref's if (loop.getNumResults() != 0) { - return failure(); + return failure(); } 
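    // Overall shape of the rewrite, on a hypothetical input (not a test from
    // this patch):
    //
    //   affine.for %i = 0 to %N {
    //     %v = affine.load %A[%i] : memref<?xf32>
    //     %w = arith.mulf %v, %v : f32
    //     affine.store %w, %B[%i] : memref<?xf32>
    //   }
    //
    // becomes
    //
    //   linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
    //                                    affine_map<(d0) -> (d0)>],
    //                   iterator_types = ["parallel"]}
    //       ins(%A' : memref<?xf32>) outs(%B' : memref<?xf32>) {
    //   ^bb0(%v: f32, %out: f32):
    //     %w = arith.mulf %v, %v : f32
    //     linalg.yield %w : f32
    //   }
    //
    // where %A' and %B' are views of %A and %B restricted to the loop's
    // iteration space. The walk below collects the loads, stores, and nested
    // linalg.generics that the rewrite has to account for.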
SmallVector, AffineLoadOp>> loads; SmallVector, AffineStoreOp>> stores; + SmallVector, GenericOp>> linalgGenerics; + bool check_reduction; + // TODO Also collect all the linalg generics! // Check that the only operations within the region are either: // affine.load, affine.store, affine.if, affine.yield // Additionally, for each load/store, remember what conditions are // required for that load or store to execute. - auto result = loop->walk([&](Operation* op) { - if (op == loop) return WalkResult::advance(); - // TODO extend this, any non-memory operation is also legal here. - // mul, add, etc (we can just check propety) - if (isa(op)) { - return WalkResult::advance(); + auto result = loop->walk([&](Operation *op) { + if (op == loop) + return WalkResult::advance(); + // TODO extend this, any non-memory operation is also legal here. + // mul, add, etc (we can just check propety) + if (isa(op)) { + return WalkResult::advance(); + } + if (isa(op) || isa(op)) { + Operation *cur = op->getParentOp(); + std::vector conditions; + while (cur != loop) { + auto ifstmt = dyn_cast(cur); + if (!ifstmt) { + return WalkResult::interrupt(); + } + bool ifTrue = + ifstmt.getThenRegion().isAncestor(cur->getParentRegion()); + conditions.emplace_back(ifTrue, ifstmt); + cur = ifstmt->getParentOp(); } - if (isa(op)) { - Operation *cur = op->getParentOp(); - std::vector conditions; - while (cur != loop) { - auto ifstmt = dyn_cast(cur); - if (!ifstmt) { - return WalkResult::interrupt(); - } - bool ifTrue = ifstmt.getThenRegion().isAncestor(cur->getParentRegion()); - conditions.emplace_back(ifTrue, ifstmt); - cur = ifstmt->getParentOp(); - } - if (auto load = dyn_cast(op)) { - loads.emplace_back(conditions, load); - } else { - auto store = cast(op); - stores.emplace_back(conditions, store); - } - return WalkResult::advance(); - } - if (isReadNone(op)) { - return WalkResult::advance(); + if (auto linalgGeneric = dyn_cast(op)) { + linalgGenerics.emplace_back(conditions, linalgGeneric); + } else if (auto load = dyn_cast(op)) { + loads.emplace_back(conditions, load); + } else { + auto store = cast(op); + stores.emplace_back(conditions, store); } - return WalkResult::interrupt(); + return WalkResult::advance(); + } + // IsReadNone takes care of apply and subview too? + if (isReadNone(op)) { + return WalkResult::advance(); + } + return WalkResult::interrupt(); }); - - if (result.wasInterrupted()) return failure(); + + if (result.wasInterrupted()) + return failure(); DominanceInfo DI(loop); - // Check that all of the stores do not alias the loaded values (otherwise we could get an incorrect result) - // TODO we can extend this and handle things like reductions, but we're going to start easy for now - // TODO + // Check that all of the stores do not alias the loaded values (otherwise we + // could get an incorrect result) + // TODO we can extend this and handle things like reductions, but we're + // going to start easy for now + // TODO DenseMap stores_map; for (auto &&[_, store] : stores) { - for (auto &&[_, load]: loads) { - if (mayAlias(load.getMemref(), store.getMemref())) { - // We have one exception in this case -- if the load and store are from the exact same location, it is permitted. 
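          // e.g. (hypothetical) an in-place update
          //   %v = affine.load  %A[%i]
          //   %w = arith.addf   %v, %c : f32
          //   affine.store %w,  %A[%i]
          // is accepted: the load is recorded in stores_map and later
          // replaced by the output block argument of the linalg.generic
          // rather than becoming a separate input.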
- if (load.getMemref() == store.getMemref() && - load.getAffineMap() == store.getAffineMap() && - load.getIndices() == store.getIndices() && DI.dominates((Operation*)load,(Operation*)store)) { - stores_map[load] = store; - continue; - } - return failure(); - } + for (auto &&[_, load] : loads) { + if (mayAlias(load.getMemref(), store.getMemref())) { + // We have one exception in this case -- if the load and store are + // from the exact same location, it is permitted. + if (load.getMemref() == store.getMemref() && + load.getAffineMap() == store.getAffineMap() && + load.getIndices() == store.getIndices() && + DI.dominates((Operation *)load, (Operation *)store)) { + // Example case where load does not dominate stores - if the load + // was conditional. Or, store followed by load? Q. Can't we still + // overlook the aliasing? + stores_map[load] = store; + continue; + } + //return failure(); } - for (auto &&[_, store2]: stores) { - if (store == store2) continue; - if (mayAlias(store.getMemref(), store2.getMemref())) { - return failure(); - } + } + for (auto &&[_, store2] : stores) { + if (store == store2) + continue; + if (mayAlias(store.getMemref(), store2.getMemref())) { + return failure(); } + } } // Check that any other loads / stores do not alias with any linalg generics - // We're going to need to upgrade the defn of mayAlias for subviews (aka mayAlias(subview, x) -> mayAlias(operand(subview), x)) + // We're going to need to upgrade the defn of mayAlias for subviews (aka + // mayAlias(subview, x) -> mayAlias(operand(subview), x)) - SmallVector inputs; + SmallVector inputs, outputs; SmallVector affineMaps; + SmallVector indexingMaps; - //if (loop.getStep() != 1) { - // return failure(); - //} + // if (loop.getStep() != 1) { + // return failure(); + // } - // our remapper currently assumes 0 start to bound. + // our remapper currently assumes 0 start to bound. if (!loop.hasConstantLowerBound() /*|| loop.getConstantLowerBound() != 0*/) { - return failure(); + return failure(); } // compute this correctly later. 
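    // Bounds handling sketch (hypothetical loop `affine.for %i = 2 to %N`):
    // the lower and upper bound maps are materialised as affine.apply ops and
    // the trip count is their difference,
    //
    //   %lb   = affine.apply affine_map<() -> (2)>()
    //   %ub   = affine.apply affine_map<()[s0] -> (s0)>()[%N]
    //   %size = arith.subi %ub, %lb : index
    //
    // %size becomes the dynamic extent of the views built by
    // remap_in_affine_dim, and the constant lower bound is folded back into
    // their access maps.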
auto ubMap = loop.getUpperBoundMap(); auto ubOperands = loop.getUpperBoundOperands(); - if (!ubMap || ubMap.getNumResults() != 1) return failure(); + if (!ubMap || ubMap.getNumResults() != 1) + return failure(); // Retrieve the lower bound auto lbMap = loop.getLowerBoundMap(); auto lbOperands = loop.getLowerBoundOperands(); - if (!lbMap || lbMap.getNumResults() != 1) return failure(); - - auto ub = loop.getSingleUpperBound(); - if (!ub) return failure(); + if (!lbMap || lbMap.getNumResults() != 1) + return failure(); - auto lb = loop.getSingleLowerBound(); - if (!lb) return failure(); - + //auto ub = loop.getSingleUpperBound(); + //if (!ub) + // return failure(); - if (!loop.hasConstantUpperBound()) { - return failure(); - } + //auto lb = loop.getSingleLowerBound(); + //if (!lb) + // return failure(); + + //if (!loop.hasConstantUpperBound()) { + // return failure(); + //} // Retrieve the step size int64_t step = loop.getStep(); // Get the single result expressions AffineExpr ubExpr = ubMap.getResult(0); - auto ubValue = rewriter.create(loop.getLoc(), ubMap, ubOperands); - + auto ubValue = + rewriter.create(loop.getLoc(), ubMap, ubOperands); + AffineExpr lbExpr = lbMap.getResult(0); - auto lbValue = rewriter.create(loop.getLoc(), lbMap, lbOperands); + auto lbValue = + rewriter.create(loop.getLoc(), lbMap, lbOperands); //// Ensure the bounds are constant expressions - auto ubConst = ubExpr.dyn_cast(); - auto lbConst = lbExpr.dyn_cast(); - if (!ubConst || !lbConst) return failure(); + //auto ubConst = ubExpr.dyn_cast(); + //auto lbConst = lbExpr.dyn_cast(); + //if (!ubConst || !lbConst) + // return failure(); // Compute the loop size - //int64_t loopSize = ubConst.getValue() - lbConst.getValue(); + // int64_t loopSize = ubConst.getValue() - lbConst.getValue(); auto loopSize = rewriter.create(loop.getLoc(), ubValue, lbValue); - - //Value loopSize = rewriter.create(loop.getLoc(), loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), *ub, *lb); - - // current spec is going to be indexed off of the loop var in isolation - for (auto &&[conds, load] : loads) { - // Only support unconditional loads for the moment - if (conds.size() != 0) return failure(); - if (stores_map.find(load) != stores_map.end()) { - // We have a store that represents this load. - continue; + // Value loopSize = rewriter.create(loop.getLoc(), + // loop.getConstantUpperBound());//rewriter.create(loop.getLoc(), + // *ub, *lb); + + for (auto &&[conds, lg] : linalgGenerics) { + + // This captures the indexing map attribute from the linalg.generic being + // processed + ArrayAttr indexingMapsAttr = lg.getIndexingMaps(); + + int idx = 0; + // Iterate over input arguments + for (const Value input : lg.getInputs()) { + // Is this needed? + if (conds.size() != 0) + return failure(); + + // TODO: Implement this + // lgMap comes from offset of memref.subview, + // lgOperands comes from operands of memref.subview + + const AffineMap lgMap0 = + cast(indexingMapsAttr[idx]).getAffineMap(); + AffineMap lgMap = lgMap0; + SmallVector lgOperands; + for (int i = 0; i < lgMap.getNumDims(); i++) { + lgOperands.push_back(nullptr); } - bool legal = true; - - auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, load.getAffineMap(), load.getMemref(), loop.getInductionVar(), - loopSize, lbConst.getValue(), step, load.getMapOperands()); + Value lgMemref = input; + + // At input, this contains, current input (i.e. 
probably a subview) + // an lgMap which is obtained from LG's indexing map for corresponding + // input lgOperands contains current input (i.e probably a subview) - if (!legal) return failure(); + // Gives output ... + assert(lgOperands.size() == lgMap.getNumSymbols() + lgMap.getNumDims()); + auto result = getLinalgArgMap(loop, lgMemref, lgMap, lgOperands); + + if (!result.succeeded()) + return failure(); + + bool legal = true; + + // Takes input's/output's, affineMap of load/store (here lgMap ?), + // induction variable corresponding to the loop + // Memref corresponding the the memory accessed (in this case subview ?) + // loopSize, lower and upper bounds + // Get operands for load/store (here ?) to find dependent dim + + // Gives output newMemref which is a subviewOp, + // newAffineMap which is the LG's indexing map corresponding this + // inp/output + + // This takes load and store maps and then creates + // affine.apply+subview+linalg.generic For this case: LG within ForOp - + // Inputs should be : load map extracted from subviewOp + // Returns LG with indexingMap and subview with affine.apply - which + // are correct + + // TODO: Or is it num dims? + // size_t firstNDims = lgMap.getResults().size(); + size_t firstNDims = lgMap.getNumDims(); + check_reduction = false; + auto newMemref = remap_in_affine_dim( + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, lbValue, + firstNDims, ValueRange(lgOperands), input, check_reduction); + if (!legal) + return failure(); + + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims + 1); + + // TODO: need to mergre previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); inputs.push_back(newMemref); - } - // TODO Push all of the inputs to the linalg generics (modifying maps as needed) - - SmallVector outputs; - // Store we may need to reindex into a splat potentially later, but for now we'll be lazy - for (auto &&[conds, store] : stores) { - // Only support unconditional loads for the moment - if (conds.size() != 0) return failure(); + idx++; + } + + // Iterate over output arguments + for (const Value output : lg.getOutputs()) { + // Is this needed? 
+ if (conds.size() != 0) + return failure(); + + const AffineMap lgMap0 = + cast(indexingMapsAttr[idx]).getAffineMap(); + AffineMap lgMap = lgMap0; + + SmallVector lgOperands; + for (int i = 0; i < lgMap.getNumDims(); i++) { + lgOperands.push_back(nullptr); + } + Value lgMemref = output; + + auto result = getLinalgArgMap(loop, lgMemref, lgMap, lgOperands); + + if (!result.succeeded()) + return failure(); bool legal = true; - - auto &&[newMemref, newAffineMap] = remap_in_affine_dim(legal, rewriter, store.getAffineMap(), store.getMemref(), loop.getInductionVar(), - loopSize, lbConst.getValue(), step, store.getMapOperands()); - if (!legal) return failure(); + size_t firstNDims = lgMap.getNumDims(); + check_reduction = true; + auto newMemref = remap_in_affine_dim( + legal, rewriter, lgMap, lgMemref, loop.getInductionVar(), loopSize, lbValue, + firstNDims, ValueRange(lgOperands), output, check_reduction); + if (!legal) + return failure(); + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims + 1); + // TODO: need to merge previous indexing maps and new affine maps affineMaps.push_back(newAffineMap); outputs.push_back(newMemref); + } + } + + // current spec is going to be indexed off of the loop var in isolation + for (auto &&[conds, load] : loads) { + // Only support unconditional loads for the moment + if (conds.size() != 0) + return failure(); + + if (stores_map.find(load) != stores_map.end()) { + // We have a store that represents this load. + continue; + } + + size_t firstNDims = 0; + bool legal = true; + + check_reduction = false; + auto newMemref = remap_in_affine_dim( + legal, rewriter, load.getAffineMap(), load.getMemref(), + loop.getInductionVar(), loopSize, lbValue, firstNDims, load.getMapOperands(), + load.getMemref(), check_reduction); + + if (!legal) + return failure(); + + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims + 1); + affineMaps.push_back(newAffineMap); + inputs.push_back(newMemref); + } + // TODO Push all of the inputs to the linalg generics (modifying maps as + // needed) + + // SmallVector outputs; + // Store we may need to reindex into a splat potentially later, but for now + // we'll be lazy + for (auto &&[conds, store] : stores) { + // Only support unconditional loads for the moment + if (conds.size() != 0) + return failure(); + + bool legal = true; + + size_t firstNDims = 0; + + check_reduction = true; + auto newMemref = remap_in_affine_dim( + legal, rewriter, store.getAffineMap(), store.getMemref(), + loop.getInductionVar(), loopSize, lbValue, firstNDims, store.getMapOperands(), + store.getMemref(), check_reduction); + + if (!legal) { + return failure(); + } + + auto newAffineMap = rewriter.getMultiDimIdentityMap(firstNDims + 1); + affineMaps.push_back(newAffineMap); + outputs.push_back(newMemref); } // TODO Push all of the outputs to the linalg generics // TODO presently if linalg generic exists, assert there are no load/stores + if ((linalgGenerics.size() > 0) && + ((loads.size() != 0) || (stores.size() != 0))) { + return failure(); + } + // TODO assert only zero or one linalg generic exists + if (!(linalgGenerics.size() == 1 || linalgGenerics.size() == 0)) { + // assert(false); + return failure(); + } + SmallVector iteratorTypes; - // TODO if linalg generic exists, make this iterator type prepend to the existing iterators - iteratorTypes.push_back((stores_map.size() == 0) ? 
utils::IteratorType::parallel : utils::IteratorType::reduction); + // TODO if linalg generic exists, make this iterator type prepend to the + // existing iterators + // TODO: Just store check is not sufficient, there has to be a check for + // bool is_parallel = stores_map.size() == 0; + // TODO determine if linalg generic, whether to create parallel or + // reduction by looking at memory patterns of maps + if (linalgGenerics.size() == 1) { + // determine whether now we write to ourselves + } + + iteratorTypes.push_back(check_reduction ? utils::IteratorType::reduction + : utils::IteratorType::parallel); + + if (linalgGenerics.size() == 1) { + for (auto attr : linalgGenerics[0].second.getIteratorTypesArray()) + iteratorTypes.push_back(attr); + } StringAttr empty = StringAttr::get(loop.getContext()); auto genericOp = rewriter.create( - loop.getLoc(), TypeRange(), inputs, outputs, affineMaps, iteratorTypes, - empty, - empty); + loop.getLoc(), TypeRange(), inputs, outputs, affineMaps, iteratorTypes, + empty, empty); - // TODO if doing the linalg generic case, ignore a lot of the below and instead of injecting the old body of the affine.for, move the inner linalg.generic body - // and also add a new induction variable + // TODO if doing the linalg generic case, ignore a lot of the below and + // instead of injecting the old body of the affine.for, move the inner + // linalg.generic body and also add a new induction variable auto blk = &*loop.getRegion().begin(); rewriter.setInsertionPointToStart(blk); // This index will replace the use of the affine index - auto idx = rewriter.create(loop.getLoc(), rewriter.getIndexAttr(0)); + auto idx = rewriter.create(loop.getLoc(), + 0); rewriter.replaceAllUsesWith(loop.getInductionVar(), idx); auto &body = genericOp.getRegion(); body.takeBody(loop.getRegion()); - blk->eraseArguments(0, blk->getNumArguments()); for (auto &&[conds, load] : loads) { - if (stores_map.find(load) != stores_map.end()) { - // We have a store that represents this load. - continue; - } - auto arg = blk->addArgument(load.getType(), load.getLoc()); - rewriter.replaceOp(load, arg); - + if (stores_map.find(load) != stores_map.end()) { + // We have a store that represents this load. 
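+        // This load is rewired to the block argument created for its
+        // matching store below, so no input block argument is added here.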
+ continue; + } + auto arg = blk->addArgument(load.getType(), load.getLoc()); + rewriter.replaceOp(load, arg); } - for (auto &&[conds, store] : stores) { - auto arg = blk->addArgument(store.getValueToStore().getType(), store.getLoc()); + auto arg = + blk->addArgument(store.getValueToStore().getType(), store.getLoc()); - SmallVector inverted; - for (auto && [map_load, map_store] : stores_map) { - if (map_store == store) { - inverted.push_back(map_load); - } - } - for (size_t i=0; i inverted; + for (auto &&[map_load, map_store] : stores_map) { + if (map_store == store) { + inverted.push_back(map_load); } + } + for (size_t i = 0; i < inverted.size(); i++) { + stores_map.erase(inverted[i]); + auto tmp = inverted[i]; + inverted[i] = nullptr; + rewriter.replaceOp(tmp, arg); + } } SmallVector toreturn; + for (auto genPair : linalgGenerics) { + auto genOp = genPair.second; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(genOp); + auto &genBlock = genOp->getRegion(0).front(); + auto term = genBlock.getTerminator(); + mlir::IRMapping map; + for (auto arg : genBlock.getArguments()) { + auto arg2 = blk->addArgument(arg.getType(), arg.getLoc()); + map.map(arg, arg2); + } + for (auto &op : genBlock.without_terminator()) { + rewriter.clone(op, map); + } + for (auto op : term->getOperands()) { + toreturn.push_back(map.lookupOrDefault(op)); + } + // llvm::errs() << genOp->getParentOfType() << "\n"; + rewriter.eraseOp(genOp); + } + for (auto &&[conds, store] : stores) { - toreturn.push_back(store.getValueToStore()); - rewriter.eraseOp(store); + toreturn.push_back(store.getValueToStore()); + rewriter.eraseOp(store); } rewriter.eraseOp(blk->getTerminator()); rewriter.setInsertionPointToEnd(blk); rewriter.create(loop.getLoc(), toreturn); + auto func = loop->getParentOfType(); rewriter.eraseOp(loop); // return success! 
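+    // For reference, a minimal sketch (assuming a single unconditional
+    // load/store pair and identity maps) of what this pattern emits:
+    //   affine.for %i = 0 to 17 {
+    //     %v = affine.load %in[%i] : memref<32xf32>
+    //     affine.store %v, %out[%i] : memref<32xf32>
+    //   }
+    // is raised to
+    //   linalg.generic {indexing_maps = [#id, #id],
+    //                   iterator_types = ["parallel"]}
+    //     ins(%in : memref<32xf32>) outs(%out : memref<32xf32>) {
+    //   ^bb0(%a: f32, %b: f32):
+    //     linalg.yield %a : f32
+    //   }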
return success(); } }; -void RaiseAffineToLinalg::runOnOperation() { - RewritePatternSet patterns(&getContext()); - // TODO add the existing canonicalization patterns - // + subview of an affine apply -> subview - patterns.insert(&getContext()); +struct AffineParallelFission : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineParallelOp parallelOp, + PatternRewriter &rewriter) const override { + + auto module = parallelOp->getParentOfType(); + // Collect all top-level nested loops (affine.parallel or affine.for) + SmallVector nestedLoops; + Block *body = parallelOp.getBody(); + + for (auto &op : body->without_terminator()) { + if (isa(op)) { + nestedLoops.push_back(&op); + } else { + // Only allow pure nested loops - reject any other operations + return failure(); + } + } + + // Need at least 2 nested loops to perform fission + if (nestedLoops.size() < 2) { + return failure(); + } + + // Convert reductions ArrayAttr to ArrayRef + SmallVector reductionKinds; + for (auto attr : parallelOp.getReductions()) { + auto enumAttr = cast(attr); + reductionKinds.push_back(enumAttr.getValue()); + } + + // Convert steps to ArrayRef + SmallVector stepValues; + for (auto step : parallelOp.getSteps()) { + stepValues.push_back(step); + } + + for (Operation *nestedLoop : nestedLoops) { + + // Create new parallel loops for each nested loop + rewriter.setInsertionPoint(parallelOp); + + // Create a new outer parallel loop with same bounds + auto newParallelOp = rewriter.create( + parallelOp.getLoc(), + parallelOp.getResultTypes(), + reductionKinds, + SmallVector{parallelOp.getLowerBoundsMap()}, + parallelOp.getLowerBoundsOperands(), + SmallVector{parallelOp.getUpperBoundsMap()}, + parallelOp.getUpperBoundsOperands(), + stepValues + ); + + // Move the nested loop into the new outer loop + Block *newBody = newParallelOp.getBody(); + // Remove the existing terminator + rewriter.eraseOp(newBody->getTerminator()); + + // Set insertion point to the new body before cloning + rewriter.setInsertionPointToEnd(newBody); + + // Clone the nested loop into the new body + IRMapping mapping; + // Map the induction variables (use getIVs() instead of getInductionVars()) + for (auto [oldIV, newIV] : llvm::zip(parallelOp.getIVs(), + newParallelOp.getIVs())) { + mapping.map(oldIV, newIV); + } + + // Clone the operation (it will be automatically inserted at the current insertion point) + rewriter.clone(*nestedLoop, mapping); + + // Ensure insertion point is at the end of the outer parallel loop's body + rewriter.setInsertionPointToEnd(newBody); + + // Add the terminator back + rewriter.create(parallelOp.getLoc()); + } + + // Remove the original parallel loop + rewriter.eraseOp(parallelOp); + + return success(); + } + +private: + // Helper to check if an operation has no side effects that would + // prevent loop fission + bool isMemoryOrControlFlowNeutral(Operation *op) const { + // Allow constants, arithmetic, and other side-effect-free ops + if (isa(op)) return true; + if (op->hasTrait()) return true; + + // Check if it's a pure operation (no memory effects) + if (auto effectInterface = dyn_cast(op)) { + SmallVector effects; + effectInterface.getEffects(effects); + return effects.empty(); + } + + // Conservative: if we can't prove it's safe, assume it's not + return false; + } +}; + +struct AffineParallelToFor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AffineParallelOp parallelOp, + PatternRewriter &rewriter) 
const override { + + // Skip if there are reductions - they need special handling + if (!parallelOp.getReductions().empty()) { + return failure(); + } + + // Skip if there are result types - parallel loops with returns need special handling + if (!parallelOp.getResultTypes().empty()) { + return failure(); + } + + Location loc = parallelOp.getLoc(); + + // Get the bounds and steps + auto lowerBounds = parallelOp.getLowerBoundsMap(); + auto upperBounds = parallelOp.getUpperBoundsMap(); + auto steps = parallelOp.getSteps(); + auto lowerOperands = parallelOp.getLowerBoundsOperands(); + auto upperOperands = parallelOp.getUpperBoundsOperands(); + auto ivs = parallelOp.getIVs(); + + // Start building nested for loops from outermost to innermost + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(parallelOp); + + // Create nested affine.for loops + SmallVector forOps; + SmallVector newIVs; + + for (unsigned i = 0; i < ivs.size(); ++i) { + // Extract bounds for this dimension + auto lbMap = lowerBounds.getSliceMap(i, 1); + auto ubMap = upperBounds.getSliceMap(i, 1); + int64_t step = steps[i]; + + auto forOp = rewriter.create( + loc, + lowerOperands, lbMap, + upperOperands, ubMap, + step + ); + + forOps.push_back(forOp); + newIVs.push_back(forOp.getInductionVar()); + + // Set insertion point for next loop or body + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Move the body content from parallel to innermost for loop + Block *parallelBody = parallelOp.getBody(); + Block *targetBody = forOps.empty() ? nullptr : forOps.back().getBody(); + + if (!targetBody) { + return failure(); + } + + // Create mapping for induction variables + IRMapping mapping; + for (auto [parallelIV, newIV] : llvm::zip(ivs, newIVs)) { + mapping.map(parallelIV, newIV); + } + + // Clone operations from parallel body to for body (excluding terminator) + for (auto &op : parallelBody->without_terminator()) { + rewriter.clone(op, mapping); + } + + // Remove the original parallel loop + rewriter.eraseOp(parallelOp); + + return success(); + } +}; + +// namespace { +// struct RaiseAffineToLinalg +// : public AffineRaiseToLinalgBase { + +// std::shared_ptr patterns; + +// LogicalResult initialize(MLIRContext *context) override { +// RewritePatternSet owningPatterns(context); +// for (auto *dialect : context->getLoadedDialects()) +// dialect->getCanonicalizationPatterns(owningPatterns); +// for (RegisteredOperationName op : context->getRegisteredOperations()) +// op.getCanonicalizationPatterns(owningPatterns, context); + +// owningPatterns.insert(&getContext()); + +// patterns = std::make_shared( +// std::move(owningPatterns)); +// return success(); +// } +// void runOnOperation() override { +// GreedyRewriteConfig config; +// (void)applyPatternsAndFoldGreedily(getOperation(), *patterns, config); +// } +// }; +// } // namespace + +namespace { +struct RaiseAffineToLinalgPipeline + : public AffineRaiseToLinalgPipelineBase { + void runOnOperation() override; +}; +} // namespace + +void RaiseAffineToLinalgPipeline::runOnOperation() { + // Create a nested pass manager to run the pipeline on functions + OpPassManager pm(getOperation()->getName()); + + // Create a nested pass manager for function operations + OpPassManager &funcPM = pm.nest(); + + // Add affine-parallelize pass first (runs on func.func) + funcPM.addPass(mlir::affine::createAffineParallelizePass()); + + // Add our raise-affine-to-linalg pass second (also runs on func.func) + funcPM.addPass(createRaiseAffineToLinalgPass()); + + // Canonicalize 
after raise-to-linalg to eliminate submaps and other patterns + funcPM.addPass(createCanonicalizerPass()); + + // Run the pipeline + if (failed(runPipeline(pm, getOperation()))) { + // Warn but don't fail the pass - convergence issues shouldn't kill output + getOperation()->emitWarning("Pipeline didn't converge completely, but continuing anyway"); + } +} + +namespace { +struct RaiseAffineToLinalg + : public AffineRaiseToLinalgBase { + void runOnOperation() override; +}; +} // namespace + +void RaiseAffineToLinalg::runOnOperation() { GreedyRewriteConfig config; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + + // Step 1: Apply fission pattern first + { + RewritePatternSet fissionPatterns(&getContext()); + fissionPatterns.insert(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(fissionPatterns), config))) { + getOperation()->emitWarning("AffineParallelFission didn't converge, continuing anyway"); + } + } + + // Step 2: Apply parallel-to-for conversion + { + RewritePatternSet parallelToForPatterns(&getContext()); + parallelToForPatterns.insert(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(parallelToForPatterns), config))) { + getOperation()->emitWarning("AffineParallelToFor didn't converge, continuing anyway"); + } + } + + // Step 3: Apply raising pattern + { + RewritePatternSet raisingPatterns(&getContext()); + raisingPatterns.insert(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(raisingPatterns), config))) { + getOperation()->emitWarning("AffineForOpRaising didn't converge, continuing anyway"); + } + } } namespace mlir { @@ -460,5 +1406,9 @@ namespace polygeist { std::unique_ptr createRaiseAffineToLinalgPass() { return std::make_unique(); } + +std::unique_ptr createRaiseAffineToLinalgPipelinePass() { + return std::make_unique(); +} } // namespace polygeist } // namespace mlir diff --git a/lib/polygeist/Passes/RemoveIterArgs.cpp b/lib/polygeist/Passes/RemoveIterArgs.cpp new file mode 100644 index 000000000000..2a3e9ea4edc6 --- /dev/null +++ b/lib/polygeist/Passes/RemoveIterArgs.cpp @@ -0,0 +1,279 @@ +#include "PassDetails.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "polygeist/Passes/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "remove-scf-iter-args" + +using namespace mlir; +using namespace mlir::arith; +using namespace polygeist; +using namespace scf; +using namespace affine; + +struct RemoveSCFIterArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + + ModuleOp module = forOp->getParentOfType(); + if (!forOp.getRegion().hasOneBlock()) + return failure(); + unsigned numIterArgs = forOp.getNumRegionIterArgs(); + auto loc = forOp->getLoc(); + bool changed = false; + llvm::SetVector removed; + llvm::MapVector steps; + auto yieldOp = 
cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < numIterArgs; i++) { + auto ba = forOp.getRegionIterArgs()[i]; + auto init = forOp.getInits()[i]; + auto lastOp = yieldOp->getOperand(i); + + // General Case(TODO): + // ALGo: + // 1. Create an alloca(stack) variable + // How to know it's dims? It should be based on number of reduction + // loops + // 2. Initialize it with init value just outside the for loop if init + // value is non-zero + // 3. memref.load that value in the for loop + // 4. Replace all the uses of the iter_arg with the loaded value + // 5. Add a memref.store for the value to be yielded + // 6. Replace all uses of for-loops yielded value with a single inserted + // memref.load + // Special case: + // ALGo: + // Optimize away memref.store and memref.load, if the only users of + // memref.load are memref.store (can use affine-scalrep pass for that ? No + // it does store to load forwarding) What we need is forwarding of local + // store to final store and deleting the intermediate alloca created. This + // is only possible if the user of alloca is a storeOp. + // 1. Identify the single store of the for loop result + // 2. Initialize it with iter arg init, outside the for loop. (TODO) + // 3. Do a load from the memref + // 4. move the store to memref inside the loop. + + auto result = forOp.getResult(i); + if (result.hasOneUse()) { + auto storeOp = dyn_cast(*result.getUsers().begin()); + if (storeOp) { + { + rewriter.setInsertionPointToStart(forOp.getBody()); + auto memrefLoad = rewriter.create( + forOp.getLoc(), storeOp.getMemref(), storeOp.getIndices()); + rewriter.replaceAllUsesWith(ba, memrefLoad.getResult()); + } + { + rewriter.setInsertionPoint(yieldOp); + rewriter.create(forOp.getLoc(), lastOp, + storeOp.getMemref(), + storeOp.getIndices()); + storeOp.erase(); + } + } else { + return failure(); + } + } + // else{ + // alloca = rewriter.create( + // forOp.getLoc(), MemRefType::get(ArrayRef(), + // forOp.getType()), ValueRange()); + // //Skipping init for now + + // auto memrefLoad = rewriter.create( + // forOp.getLoc(), alloca.getMemref(), op.getIndices()); + // rewriter.replaceOp(op, memrefLoad.getResult()); + + // rewriter.create(forOp.getLoc(), lastOp, alloca, + // forOp.getBody()->getArguments()); + + // rewriter.replaceAllUsesWith(result,) + //} + + rewriter.setInsertionPointToStart(forOp.getBody()); + // rewriter.replaceAllUsesWith(ba, replacementIV); + changed = true; + } + + if (!changed) + return failure(); + + rewriter.setInsertionPoint(forOp); + auto newForOp = rewriter.create( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()); + if (!newForOp.getRegion().empty()) + newForOp.getRegion().front().erase(); + assert(newForOp.getRegion().empty()); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().begin()); + + // Delete region args + llvm::BitVector toDelete(numIterArgs + 1); + for (unsigned i = 0; i < numIterArgs; i++) + toDelete[i + 1] = true; + newForOp.getBody()->eraseArguments(toDelete); + + SmallVector newYields; + { + ValueRange empty; + rewriter.setInsertionPoint(yieldOp); + auto newYieldOp = rewriter.create(loc); + // rewriter.replaceOpWithNewOp(yieldOp, newYieldOp); + rewriter.eraseOp(yieldOp); + } + + rewriter.setInsertionPoint(newForOp); + rewriter.eraseOp(forOp); + + return success(); + } +}; + +// General Case(TODO): +// ALGo: +// 1. Create an alloca(stack) variable +// How to know it's dims? It should be based on number of reduction +// loops +// 2. 
Initialize it with init value just outside the for loop if init +// value is non-zero +// 3. memref.load that value in the for loop +// 4. Replace all the uses of the iter_arg with the loaded value +// 5. Add a memref.store for the value to be yielded +// 6. Replace all uses of for-loops yielded value with a single inserted +// memref.load +// Special case: +// ALGo: +// Optimize away memref.store and memref.load, if the only users of +// memref.load are memref.store (can use affine-scalrep pass for that ? No +// it does store to load forwarding) What we need is forwarding of local +// store to final store and deleting the intermediate alloca created. This +// is only possible if the user of alloca is a storeOp. +// 1. Identify the single store of the for loop result +// 2. Initialize it with iter arg init, outside the for loop. (TODO) +// 3. Do a load from the memref +// 4. move the store to memref inside the loop. + +struct RemoveAffineIterArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(affine::AffineForOp forOp, + PatternRewriter &rewriter) const override { + + ModuleOp module = forOp->getParentOfType(); + rewriter.setInsertionPoint(forOp); + + unsigned numIterArgs = forOp.getNumRegionIterArgs(); + if (numIterArgs == 0) + return failure(); + + auto loc = forOp->getLoc(); + auto yieldOp = + cast(forOp.getBody()->getTerminator()); + + auto ba = forOp.getRegionIterArgs()[numIterArgs - 1]; + auto init = forOp.getInits()[numIterArgs - 1]; + auto lastOp = yieldOp->getOperand(numIterArgs - 1); + + auto result = forOp.getResult(numIterArgs - 1); + if (result.hasOneUse()) { + auto storeOp = + dyn_cast(*result.getUsers().begin()); + if (storeOp) { + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(forOp.getBody()); + auto memrefLoad = rewriter.create( + forOp.getLoc(), storeOp.getMemref(), storeOp.getMap(), + storeOp.getMapOperands()); + rewriter.replaceAllUsesWith(ba, memrefLoad.getResult()); + } + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(yieldOp); + rewriter.create( + forOp.getLoc(), lastOp, storeOp.getMemref(), storeOp.getMap(), + storeOp.getMapOperands()); + storeOp.erase(); + } + } else { + return failure(); + } + } + else{ + return failure(); + } + + SmallVector newIterArgs(forOp.getInits().drop_back()); + auto newForOp = rewriter.create( + loc, forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), + forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), + forOp.getStep(), newIterArgs); + + if (!newForOp.getRegion().empty()) + newForOp.getRegion().front().erase(); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().begin()); + + // Delete region args + llvm::BitVector toDelete(numIterArgs + 1); + toDelete[numIterArgs] = true; + newForOp.getBody()->eraseArguments(toDelete); + + SmallVector newYields; + { + OpBuilder::InsertionGuard guard(rewriter); + ValueRange empty; + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp, yieldOp.getOperands().drop_back()); + } + + for(int i = 0; i < numIterArgs-1; i++){ + rewriter.replaceAllUsesWith(forOp.getResult(i), newForOp.getResult(i)); + } + + rewriter.eraseOp(forOp); + return success(); + } +}; + +namespace { +struct RemoveIterArgs : public RemoveIterArgsBase { + + void runOnOperation() override { + GreedyRewriteConfig config; + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + 
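+    // Both patterns rewrite a loop-carried scalar whose loop result feeds a
+    // single store into explicit load/store traffic. A minimal sketch of the
+    // handled special case:
+    //   %r = scf.for %i = %lb to %ub step %s iter_args(%acc = %init) -> (f32) {
+    //     %next = arith.addf %acc, %x : f32
+    //     scf.yield %next : f32
+    //   }
+    //   memref.store %r, %buf[%c0] : memref<1xf32>
+    // becomes a result-less loop that loads from %buf, adds, and stores back
+    // each iteration (initializing %buf from %init is still a TODO, as noted
+    // in the algorithm comments above).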
patterns.insert(patterns.getContext()); + patterns.insert(patterns.getContext()); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + signalPassFailure(); + return; + } + } +}; +} // namespace + +namespace mlir { +namespace polygeist { +std::unique_ptr createRemoveIterArgsPass() { + return std::make_unique(); +} +} // namespace polygeist +} // namespace mlir \ No newline at end of file diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir new file mode 100644 index 000000000000..65a5a9ef0adf --- /dev/null +++ b/test/polygeist-opt/debufferize.mlir @@ -0,0 +1,496 @@ +//polygeist-opt --canonicalize --linalg-debufferize --canonicalize debufferize.mlir + +#map16 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map17 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)> +#map18 = affine_map<(d0, d1, d2, d3) -> (d1, d0)> +#map19 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +#map22 = affine_map<(d0, d1) -> (d1, d0)> + + module @in_place_add{ + func.func @in_place_add(%value: f32) { + %c0 = arith.constant 0 : index + %buffer = memref.alloca() : memref<128xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + + module @in_place_add2{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + + module @in_place_cond_add{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + return + } + } + + module @in_place_add_for{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.for %i = %c0 to %c10 step %c1 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + return + } + } + + //Case when buffer is captured + module @in_place_add_for_loop_carried{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + //%buffer = memref.alloca() : memref<128xf32> + %result = scf.for %i = %c0 to %c10 step %c1 iter_args(%buf = %buffer) -> (memref<128xf32>) { + linalg.generic { + indexing_maps = 
[affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf : memref<128xf32>) + outs(%buf : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + scf.yield %buf : memref<128xf32> + } + return + } + } + module @cross_buffer_add{ + func.func @in_place_add(%buf: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %buf2 = memref.alloca() : memref<128xf32> + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf : memref<128xf32>) + outs(%buf2 : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf2 : memref<128xf32>) + outs(%buf : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + return + } + } + + module @in_place_add_for_loop_carried_cross_buffer{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %buffer2 = memref.alloca() : memref<128xf32> + %result:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%buf = %buffer, %buf2 = %buffer2) -> (memref<128xf32>, memref<128xf32>) { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf : memref<128xf32>) + outs(%buf2 : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buf2 : memref<128xf32>) + outs(%buf : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + scf.yield %buf, %buf2 : memref<128xf32>, memref<128xf32> + } + return + } + } + +// //TODO: Doesn't bufferize --affine loop carried iter_args doesn't canonicalizes (missing pattern?) 
+// module @in_place_add_for_loop_carried3{ +// func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %c10 = arith.constant 10 : index +// %buffer2 = memref.alloca() : memref<128xf32> +// %result:2 = affine.for %i = %c0 to %c10 iter_args(%buf = %buffer, %buf2 = %buffer2) -> (memref<128xf32>, memref<128xf32>) { +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], +// iterator_types = ["parallel"] +// } ins(%buf : memref<128xf32>) +// outs(%buf2 : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// linalg.yield %sum : f32 +// } +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], +// iterator_types = ["parallel"] +// } ins(%buf2 : memref<128xf32>) +// outs(%buf : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// %sum2 = arith.addf %sum, %value : f32 +// linalg.yield %sum2 : f32 +// } +// affine.yield %buf, %buf2 : memref<128xf32>, memref<128xf32> +// } +// return +// } +// } + +// module @in_place_add_for_loop_affine{ +// func.func @in_place_add(%buf: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %c10 = arith.constant 10 : index +// %buf2 = memref.alloca() : memref<128xf32> +// affine.for %i = %c0 to %c10 { +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], +// iterator_types = ["parallel"] +// } ins(%buf : memref<128xf32>) +// outs(%buf2 : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// linalg.yield %sum : f32 +// } +// linalg.generic { +// indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], +// iterator_types = ["parallel"] +// } ins(%buf2 : memref<128xf32>) +// outs(%buf : memref<128xf32>) { +// ^bb0(%in: f32, %out: f32): +// %sum = arith.addf %in, %value : f32 +// %sum2 = arith.addf %sum, %value : f32 +// linalg.yield %sum2 : f32 +// } +// } +// return +// } +// } + + + module @in_place_cond_add_followed_by_add{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + + module @in_place_cond_add_followed_by_add2{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1, %cond2: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond2 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, 
%value : f32 + %sum2 = arith.addf %sum, %value : f32 + %sum3 = arith.addf %sum2, %value : f32 + linalg.yield %sum3 : f32 + } + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + } + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + + module @in_place_cond_add_followed_by_add3{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1, %cond2: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond2 { + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + } + scf.if %cond2 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + %sum3 = arith.addf %sum2, %value : f32 + linalg.yield %sum3 : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + + module @conv_2 { + func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %3 = arith.muli %in, %in_0 : i32 + %4 = arith.addi %out, %3 : i32 + linalg.yield %4 : i32 + } + return %c0_i32 : i32 + } + } + + module @harris_score_with_gradient_extra_kernel { + //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + func.func @main(%input: memref<518x518xi32>, %0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> 
{llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + //%score = memref.alloca() : memref<512x512xi32> + //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.addi %out_7, %4 : i32 + %6 = arith.muli %in, %in_6 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7, %5 : i32, i32 + } + linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): + %4 = arith.muli %in, %in : i32 + %5 = arith.muli %4, %in_6 : i32 + %6 = arith.addi %out_8, %5 : i32 + %7 = arith.muli %in_5, %in_5 : i32 + %8 = arith.muli %7, %in_6 : i32 + %9 = arith.addi %out_7, %8 : i32 + %10 = arith.muli %in, %in_5 : i32 + %11 = arith.muli %10, %in_6 : i32 + %12 = arith.addi %out, %11 : i32 + linalg.yield %12, %9, %6 : i32, i32, i32 + } + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.muli %in_6, %in_6 : i32 + %6 = arith.subi %4, %5 : i32 + %7 = arith.addi %in, %in_5 : i32 + %8 = arith.muli %7, %c4_i32 : i32 + %9 = arith.muli %8, %7 : i32 + %10 = arith.subi %6, %9 : i32 + linalg.yield %10 : i32 + } + return %c0_i32 : i32 + } + } + + module @for_loop_within_for_loop{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.for %i = %c0 to %c10 step %c1 { + scf.for %j = %c0 to %c10 step %c1 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + return + } + } + + 
module @for_loop_with_if_with_for{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.for %i = %c0 to %c10 step %c1 { + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + scf.for %j = %c0 to %c10 step %c1 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + } + } + return + } + } diff --git a/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir b/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir new file mode 100644 index 000000000000..dbe09418ed75 --- /dev/null +++ b/test/polygeist-opt/linalg_debufferize_tile_fusion.mlir @@ -0,0 +1,105 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter --one-shot-bufferize="bufferize-function-boundaries" --func-bufferize --tensor-bufferize --finalizing-bufferize --convert-linalg-to-affine-loops --raise-scf-to-affine -split-input-file -verify-diagnostics | FileCheck %s +// To test bufferization : pva-opt %s -test-transform-dialect-interpreter --one-shot-bufferize="bufferize-function-boundaries test-analysis-only print-conflicts" +#map1 = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> + +memref.global @out : memref<512x64xi32> = uninitialized +memref.global @rhs : memref<64x64xi32> = uninitialized +memref.global @filter : memref<4x4xi32> = uninitialized +memref.global @im : memref<515x67xi32> = uninitialized +// Output after debufferization +// func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c512 = arith.constant 512 : index +// %c64 = arith.constant 64 : index +// %c4 = arith.constant 4 : index +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im : memref<515x67xi32> +// %1 = memref.get_global @filter : memref<4x4xi32> +// %2 = memref.get_global @out : memref<512x64xi32> +// %rhs_memref = memref.get_global @rhs : memref<64x64xi32> +// %4 = bufferization.to_tensor %0 : memref<515x67xi32> +// %5 = bufferization.to_tensor %1 : memref<4x4xi32> +// %x = tensor.empty() : tensor<512x64xi32> +// %out = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%4, %5 : tensor<515x67xi32>, tensor<4x4xi32>) outs(%x : tensor<512x64xi32>) { +// ^bb0(%in: i32, %in_0: i32, %out: i32): +// %6 = arith.muli %in, %in_0 : i32 +// %7 = arith.addi %out, %6 : i32 +// linalg.yield %7 : i32 +// } -> tensor<512x64xi32> +// +// %materialize = bufferization.to_memref %out : memref<512x64xi32> +// memref.copy %materialize, %2 : memref<512x64xi32> to memref<512x64xi32> +// +// %conv_out 
= bufferization.to_tensor %2 : memref<512x64xi32> +// %rhs = bufferization.to_tensor %rhs_memref : memref<64x64xi32> +// %y = tensor.empty() : tensor<512x64xi32> +// %matmul = linalg.matmul ins(%conv_out, %rhs: tensor<512x64xi32>, tensor<64x64xi32>) +// outs(%y: tensor<512x64xi32>) -> tensor<512x64xi32> +// %materialize2 = bufferization.to_memref %matmul : memref<512x64xi32> +// memref.copy %materialize2, %2 : memref<512x64xi32> to memref<512x64xi32> +// return %c0_i32 : i32 +// } + +//Output after linking kernels +func.func @main_opt() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %rhs_memref = memref.get_global @rhs : memref<64x64xi32> + %4 = bufferization.to_tensor %0 : memref<515x67xi32> + %5 = bufferization.to_tensor %1 : memref<4x4xi32> + %x = tensor.empty() : tensor<512x64xi32> + %conv_out = bufferization.to_tensor %2 : memref<512x64xi32> + %rhs = bufferization.to_tensor %rhs_memref : memref<64x64xi32> + %y = tensor.empty() : tensor<512x64xi32> + %out = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%4, %5 : tensor<515x67xi32>, tensor<4x4xi32>) outs(%x : tensor<512x64xi32>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } -> tensor<512x64xi32> + %matmul = linalg.matmul ins(%out, %rhs: tensor<512x64xi32>, tensor<64x64xi32>) + outs(%y: tensor<512x64xi32>) -> tensor<512x64xi32> + + %materialize2 = bufferization.to_memref %matmul : memref<512x64xi32> + memref.copy %materialize2, %2 : memref<512x64xi32> to memref<512x64xi32> + return %c0_i32 : i32 +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op) : + // Since the %arg2 handle is associated with both elementwise operations, + // we need to split it into two handles so we can target only the second + // elementwise operation. + %generic = transform.structured.match ops{["linalg.matmul","linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %conv, %mul = transform.split_handle %generic + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + + // The actual tiling transformation takes tile sizes as attributes. It + // produces a handle to the loop generated during tiling. + %tiled_mul, %loop = + transform.structured.tile_using_forall %mul tile_sizes [8, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // We can now fuse the other operations into the loop. Here, we fuse + // operations one by one. This requires the operation that is being fused to + // define the value used within the loop, so the order of such fusions is + // important. We could also use "transform.merge_handles" to obtain a single + // handle to all operations and give it to `fuse_into_containing_op` that + // would take care of the ordering in this case. 
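+  // Fusing %conv into the forall loop makes the convolution compute only the
+  // 8x32 tile consumed by each matmul tile, instead of materializing the full
+  // 512x64 intermediate result.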
+ %conv_fused, %loop_0 = + transform.structured.fuse_into_containing_op %conv into %loop + : (!transform.any_op, !transform.any_op) + -> (!transform.any_op, !transform.any_op) + + + transform.yield +} + +// ----- \ No newline at end of file diff --git a/test/polygeist-opt/linalgraise.mlir b/test/polygeist-opt/linalgraise.mlir index e0ceffa1849c..0d6b0dd61fc0 100644 --- a/test/polygeist-opt/linalgraise.mlir +++ b/test/polygeist-opt/linalgraise.mlir @@ -1,44 +1,58 @@ -// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s +//// RUN: polygeist-opt --raise-affine-to-linalg --split-input-file %s | FileCheck %s +// +// module { +// func.func @main0(%12 : i1, %18 : memref<32xf32> ) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %19 = memref.alloca() : memref<32xf32> +// scf.if %12 { +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref<32xf32> +// affine.store %ld, %19[%arg4] : memref<32xf32> +// } +// } +// return +// } + + // func.func @main(%12 : i1, %14 : i32, %18 : memref ) { + // %c0 = arith.constant 0 : index + // %c4 = arith.constant 4 : index + // %c1 = arith.constant 1 : index + // %15 = arith.index_cast %14 : i32 to index + // %16 = arith.muli %15, %c4 : index + // %17 = arith.divui %16, %c4 : index + // %19 = memref.alloca(%17) : memref + // scf.if %12 { + // affine.for %arg4 = 0 to 17 { + // %ld = affine.load %18[%arg4] : memref + // affine.store %ld, %19[%arg4] : memref + // } + // } + // return + // } -module { - func.func @main(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - scf.if %12 { - affine.for %arg4 = 0 to %17 { - %ld = affine.load %18[%arg4] : memref - affine.store %ld, %19[%arg4] : memref - } - } - return - } + // func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { + // %c0 = arith.constant 0 : index + // %c4 = arith.constant 4 : index + // %c1 = arith.constant 1 : index + // %15 = arith.index_cast %14 : i32 to index + // %16 = arith.muli %15, %c4 : index + // %17 = arith.divui %16, %c4 : index + // %19 = memref.alloca(%17) : memref + // scf.if %12 { + // affine.for %arg4 = 0 to 17 { + // %ld = affine.load %18[3 * %arg4] : memref + // %ld2 = affine.load %18[0] : memref + // %fadd = arith.addf %ld, %ld2 : f32 + // affine.store %fadd, %19[%arg4 + 17] : memref + // } + // } + // return + // } - func.func @main2(%12 : i1, %14 : i32, %18 : memref ) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %15 = arith.index_cast %14 : i32 to index - %16 = arith.muli %15, %c4 : index - %17 = arith.divui %16, %c4 : index - %19 = memref.alloca(%17) : memref - scf.if %12 { - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[3 * %arg4] : memref - %ld2 = affine.load %18[0] : memref - %fadd = arith.addf %ld, %ld2 : f32 - affine.store %fadd, %19[%arg4 + 17] : memref - } - } - return - } - -} + // } // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref, %[[arg3:.+]]: memref) { @@ -177,7 +191,7 @@ module @cond_arith{ } } -//reduction +//TODO: reduction module @reduction{ func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { %c0 = arith.constant 0 : index @@ -198,7 +212,53 @@ module @reduction{ } } 
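+//The @reduction_transformed variants below express the same reduction without
+//loop-carried iter_args: the running sum lives in a scratch memref<1xf32>
+//(or directly in the destination buffer), which is the form remove-iter-args
+//aims to produce before raising.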
-//Conditional store-1 +module @reduction_transformed{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + %sum_0 = arith.constant 0.0 : f32 + %alloca = memref.alloca() : memref<1xf32> + affine.store %sum_0, %alloca[0] : memref<1xf32> + affine.for %arg4 = 0 to 17 step 1 { + %iter_arg = affine.load %alloca[0] : memref<1xf32> + %ld1 = affine.load %18[%arg4] : memref + %sum_next = arith.addf %iter_arg, %ld1 : f32 + affine.store %sum_next, %alloca[0] : memref<1xf32> + affine.yield + } + %red = affine.load %alloca[0] : memref<1xf32> + affine.store %red, %19[0] : memref + return + } +} + +module @reduction_transformed_simplified{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref ) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %19 = memref.alloca(%17) : memref + %sum_0 = arith.constant 0.0 : f32 + affine.store %sum_0, %19[0] : memref + affine.for %arg4 = 0 to 17 step 1 { + %iter_arg = affine.load %19[0] : memref + %ld1 = affine.load %18[%arg4] : memref + %sum_next = arith.addf %iter_arg, %ld1 : f32 + affine.store %sum_next, %19[0] : memref + affine.yield + } + return + } +} +//TODO: Conditional store-1 module @cond_store_1 { func.func @main(%12 : i1, %14 : i32, %18 : memref ) { %c0 = arith.constant 0 : index @@ -219,7 +279,7 @@ module @cond_store_1 { } } -//Conditional store-2 +//TODO: Conditional store-2 module @cond_store_2{ func.func @main(%12 : i1, %14 : i32, %18 : memref ) { %c0 = arith.constant 0 : index @@ -242,8 +302,34 @@ module @cond_store_2{ } } -//Parallel for -module @parallel_for{ +// //Parallel for +// module @parallel_for{ +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg4 = 0 to 17 { +// %ld = affine.load %18[%arg4] : memref +// %mul = arith.mulf %ld, %ld : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg4] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// return +// } +// } + +//Fors inside for +module @for_within_for{ func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index @@ -251,25 +337,22 @@ module @parallel_for{ %15 = arith.index_cast %14 : i32 to index %16 = arith.muli %15, %c4 : index %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index %19 = memref.alloca(%17) : memref - affine.for %arg4 = 0 to 17 { - %ld = affine.load %18[%arg4] : memref - %mul = arith.mulf %ld, %ld : f32 - affine.store %mul, %19[%arg4] : memref - } - affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg4] : memref - %ld2 = affine.load %20[%arg4] : memref - %add = arith.addf %ld1, %ld2 : f32 - %mul = arith.mulf %add, %add : f32 - affine.store %mul, 
%19[%arg4] : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } } return } } //Fors inside for -module @for_within_for{ +module @for_within_for_2{ func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index @@ -281,7 +364,7 @@ module @for_within_for{ %19 = memref.alloca(%17) : memref affine.for %arg3 = 0 to 21 { affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref + %ld1 = affine.load %18[%arg3+2*%arg4] : memref %ld2 = affine.load %20[%arg4] : memref %mul = arith.mulf %ld1, %ld2 : f32 affine.store %mul, %19[%arg4] : memref @@ -291,8 +374,8 @@ module @for_within_for{ } } -//Parallel fors inside for -module @parallel_fors_inside_for { +//Fors inside for +module @for_within_for_3{ func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index @@ -300,19 +383,38 @@ module @parallel_fors_inside_for { %15 = arith.index_cast %14 : i32 to index %16 = arith.muli %15, %c4 : index %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index %19 = memref.alloca(%17) : memref - affine.for %arg3 = 0 to 17 { + affine.for %arg3 = 0 to 21 { affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref - %ld2 = affine.load %20[%arg4] : memref + %ld1 = affine.load %18[%arg3+2*%arg4] : memref + %ld2 = affine.load %18[%arg3] : memref + %ld3 = affine.load %20[%arg4] : memref %mul = arith.mulf %ld1, %ld2 : f32 - affine.store %mul, %19[%arg4] : memref + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref } + } + return + } +} + +//Fors inside for +module @for_within_for_4{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { affine.for %arg4 = 0 to 17 { - %ld1 = affine.load %18[%arg3] : memref + %ld1 = affine.load %18[%arg4+2*%arg3] : memref %ld2 = affine.load %20[%arg4] : memref - %add = arith.addf %ld1, %ld2 : f32 - %mul = arith.mulf %add, %add : f32 + %mul = arith.mulf %ld1, %ld2 : f32 affine.store %mul, %19[%arg4] : memref } } @@ -320,6 +422,229 @@ module @parallel_fors_inside_for { } } +//Fors no-loop dependency +module @for_no_loop_dependency{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 15 { + %ld1 = affine.load %18[0] : memref + affine.store %ld1, %19[0] : memref + } + return + } +} +//Fors no-loop dependency +module @for_2_levels_no_loop_dependency{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = 
arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg4 = 0 to 17 { + affine.for %arg3 = 0 to 15 { + %ld1 = affine.load %18[%arg4] : memref + affine.store %ld1, %19[%arg4] : memref + } + } + return + } +} +//Fors inside for +module @for_3_levels_0{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 15 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg5] : memref + } + } + } + return + } +} + +//Fors inside for +module @for_3_levels_1{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg5 = 0 to 21 { + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + affine.store %mul, %19[%arg4] : memref + } + } + } + return + } +} + +//Fors inside for +module @for_3_levels_2{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref, %23 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %ld3 = affine.load %23[%arg5] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + } + return + } +} + +//Fors inside for +module @for_3_levels_3{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3] : memref + %ld2 = affine.load %20[%arg4] : memref + %ld3 = affine.load %20[%arg5] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + } + return + } +} + +//Fors inside for +module @for_3_levels_4{ + func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %15 = arith.index_cast %14 : i32 to index + %16 = arith.muli %15, %c4 : index + %17 = arith.divui %16, %c4 : index + %21 = arith.muli %16, %c4 : index + %19 = 
memref.alloca(%17) : memref + affine.for %arg3 = 0 to 21 { + affine.for %arg4 = 0 to 17 { + affine.for %arg5 = 0 to 21 { + %ld1 = affine.load %18[%arg3+4*%arg4+3] : memref + %ld2 = affine.load %20[7*%arg4+%arg5+2] : memref + %ld3 = affine.load %20[%arg5+2*%arg3] : memref + %mul = arith.mulf %ld1, %ld2 : f32 + %mul2 = arith.mulf %mul, %ld3 : f32 + affine.store %mul2, %19[%arg4] : memref + } + } + } + return + } +} + +//Intermediate raising +#map = affine_map<(d0)[s0] -> (s0)> +#map1 = affine_map<(d0) -> (d0)> +module @for_within_for2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.for %arg4 = 0 to 21 { + %3 = "polygeist.submap"(%arg2, %arg4, %c17) <{map = #map}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17) <{map = #map1}> : (memref, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17) <{map = #map1}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + } + return + } +} + +// //Parallel fors inside for +// module @parallel_fors_inside_for { +// func.func @main(%12 : i1, %14 : i32, %18 : memref, %20 : memref) { +// %c0 = arith.constant 0 : index +// %c4 = arith.constant 4 : index +// %c1 = arith.constant 1 : index +// %15 = arith.index_cast %14 : i32 to index +// %16 = arith.muli %15, %c4 : index +// %17 = arith.divui %16, %c4 : index +// %19 = memref.alloca(%17) : memref +// affine.for %arg3 = 0 to 17 { +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %mul = arith.mulf %ld1, %ld2 : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// affine.for %arg4 = 0 to 17 { +// %ld1 = affine.load %18[%arg3] : memref +// %ld2 = affine.load %20[%arg4] : memref +// %add = arith.addf %ld1, %ld2 : f32 +// %mul = arith.mulf %add, %add : f32 +// affine.store %mul, %19[%arg4] : memref +// } +// } +// return +// } +// } + //matrix-mul iter arg module @matmul_1 { memref.global @out : memref<32x8xi32> = uninitialized @@ -346,31 +671,31 @@ module @matmul_1 { } } -//matrix-mul alias issue -module @matmul_2 { - memref.global @out : memref<128x32xi32> = uninitialized - memref.global @im2 : memref<64x32xi32> = uninitialized - memref.global @im1 : memref<128x64xi32> = uninitialized - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { - %c0_i32 = arith.constant 0 : i32 - %0 = memref.get_global @im1 : memref<128x64xi32> - %1 = memref.get_global @im2 : memref<64x32xi32> - %2 = memref.get_global @out : memref<128x32xi32> - affine.for %arg0 = 0 to 128 { - affine.for %arg1 = 0 to 32 { - affine.for %arg2 = 0 to 64 { - %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> - %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> - %5 = arith.muli %3, %4 : i32 - %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> - %7 = arith.addi %6, %5 : i32 - affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> - } - } - } - return %c0_i32 : i32 - } -} +//matrix-mul extra load-store variant + module @matmul_2 { + memref.global @out : memref<128x32xi32> = uninitialized + memref.global @im2 : memref<64x32xi32> = uninitialized + memref.global @im1 
: memref<128x64xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im1 : memref<128x64xi32> + %1 = memref.get_global @im2 : memref<64x32xi32> + %2 = memref.get_global @out : memref<128x32xi32> + affine.for %arg0 = 0 to 128 { + affine.for %arg1 = 0 to 32 { + affine.for %arg2 = 0 to 64 { + %3 = affine.load %0[%arg0, %arg2] : memref<128x64xi32> + %4 = affine.load %1[%arg2, %arg1] : memref<64x32xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<128x32xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<128x32xi32> + } + } + } + return %c0_i32 : i32 + } + } //conv (with inner loop accumulate) //How to deal with IR in outer loops as well? @@ -402,25 +727,519 @@ module @conv_1{ } } -//conv (direct store) -module @conv_2{ +module @conv_1_reduction_test{ memref.global @out : memref<512x64xi32> = uninitialized memref.global @filter : memref<4x4xi32> = uninitialized memref.global @im : memref<515x67xi32> = uninitialized - func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + func.func @main(%arg0 : index, %arg1 : index) -> i32 attributes {llvm.linkage = #llvm.linkage} { %c0_i32 = arith.constant 0 : i32 %0 = memref.get_global @im : memref<515x67xi32> - %1 = memref.get_global @out : memref<512x64xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = affine.for %arg2 = 0 to 4 iter_args(%arg3 = %c0_i32) -> (i32) { + %4 = affine.for %arg4 = 0 to 4 iter_args(%arg5 = %arg3) -> (i32) { + %5 = affine.load %0[%arg0 + %arg2, %arg1 + %arg4] : memref<515x67xi32> + %6 = affine.load %1[%arg2, %arg4] : memref<4x4xi32> + %7 = arith.muli %5, %6 : i32 + %8 = arith.addi %arg5, %7 : i32 + affine.yield %8 : i32 + } + affine.yield %4 : i32 + } + affine.store %3, %2[%arg0, %arg1] : memref<512x64xi32> + return %c0_i32 : i32 + } +} + +//conv (direct store) + module @conv_2 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 64 { + affine.for %arg2 = 0 to 4 { + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + } + } + } + } + return %c0_i32 : i32 + } + } + +//box_filter (direct store) + module @box_filter { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %2 = memref.get_global @out : memref<512x64xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 64 { + affine.for %arg2 = 0 to 4 { + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %6 = 
affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %3 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + } + } + } + } + return %c0_i32 : i32 + } + } + + module @conv_loop1_test { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<4x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<512x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<512x64xi32> + } + return %c0_i32 : i32 + } + } + + module @submap_test { + memref.global @out : memref<511x64xi32> = uninitialized + memref.global @filter : memref<5x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main(%arg0 : index, %arg1 : index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<5x4xi32> + %2 = memref.get_global @out : memref<511x64xi32> + affine.for %arg2 = 0 to 5 { + affine.for %arg3 = 0 to 4 { + %3 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> + %4 = affine.load %1[%arg2, %arg3] : memref<5x4xi32> + %5 = arith.muli %3, %4 : i32 + %6 = affine.load %2[%arg0, %arg1] : memref<511x64xi32> + %7 = arith.addi %6, %5 : i32 + affine.store %7, %2[%arg0, %arg1] : memref<511x64xi32> + } + } + return %c0_i32 : i32 + } + } + + +module @harris_score_1{ + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + affine.for %arg2 = 0 to 3 { + affine.for %arg5 = 0 to 3 { + %gx = affine.load %0[%arg0, %arg1] : memref<516x516xi32> + %gy = affine.load %1[%arg0, %arg1] : memref<516x516xi32> + %11 = affine.load %2[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %12 = affine.load %3[%arg5 + %arg2 * 3] : memref<9xi32> + %13 = arith.muli %11, %12 : i32 + %14 = arith.addi %gx, %13 : i32 + %15 = affine.load %4[%arg5 + %arg2 * 3] : memref<9xi32> + %16 = arith.muli %11, %15 : i32 + %17 = arith.addi %gy, %16 : i32 
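+              // Descriptive note (editor-added): @harris_score_1 keeps its gradient accumulators in memory — each inner-loop iteration reloads %gx/%gy from img_gx/img_gy, adds the new coeffs_x/coeffs_y products (%14, %17), and stores the sums back just below. Contrast with @harris_score_2, which carries the same accumulators as iter_args instead.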
+ affine.store %14, %0[%arg0, %arg1] : memref<516x516xi32> + affine.store %17, %1[%arg0, %arg1] : memref<516x516xi32> + } + } + } + } + %5 = memref.get_global @img_ixx : memref<512x512xi32> + %6 = memref.get_global @img_iyy : memref<512x512xi32> + %7 = memref.get_global @img_ixy : memref<512x512xi32> affine.for %arg0 = 0 to 512 { - affine.for %arg1 = 0 to 64 { - affine.for %arg2 = 0 to 4 { - affine.for %arg3 = 0 to 4 { - %2 = affine.load %0[%arg0 + %arg2, %arg1 + %arg3] : memref<515x67xi32> - %3 = affine.load %1[%arg0, %arg1] : memref<512x64xi32> - %4 = arith.addi %3, %2 : i32 - affine.store %4, %1[%arg0, %arg1] : memref<512x64xi32> + affine.for %arg1 = 0 to 512 { + affine.for %arg2 = 0 to 5 { + affine.for %arg6 = 0 to 5 { + %ixx = affine.load %5[%arg0, %arg1] : memref<512x512xi32> + %iyy = affine.load %6[%arg0, %arg1] : memref<512x512xi32> + %ixy = affine.load %7[%arg0, %arg1] : memref<512x512xi32> + %11 = affine.load %0[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %12 = affine.load %1[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %13 = arith.muli %11, %11 : i32 + %14 = arith.addi %ixx, %13 : i32 + %15 = arith.muli %12, %12 : i32 + %16 = arith.addi %iyy, %15 : i32 + %17 = arith.muli %11, %12 : i32 + %18 = arith.addi %ixy, %17 : i32 + affine.store %14, %5[%arg0, %arg1] : memref<512x512xi32> + affine.store %16, %6[%arg0, %arg1] : memref<512x512xi32> + affine.store %18, %7[%arg0, %arg1] : memref<512x512xi32> + } + } + } + } + %8 = memref.get_global @score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %9 = affine.load %5[%arg0, %arg1] : memref<512x512xi32> + %10 = affine.load %6[%arg0, %arg1] : memref<512x512xi32> + %11 = affine.load %7[%arg0, %arg1] : memref<512x512xi32> + %12 = arith.muli %9, %10 : i32 + %13 = arith.muli %11, %11 : i32 + %14 = arith.subi %12, %13 : i32 + %15 = arith.addi %9, %10 : i32 + %16 = arith.muli %15, %c4_i32 : i32 + %17 = arith.muli %16, %15 : i32 + %18 = arith.subi %14, %17 : i32 + affine.store %18, %8[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} + +module @harris_score_2 { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + %9:2 = affine.for %arg2 = 0 to 3 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32) -> (i32, i32) { + %10:2 = affine.for %arg5 = 0 to 3 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (i32, i32) { + %11 = affine.load %2[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %12 = affine.load %3[%arg5 + %arg2 * 3] : memref<9xi32> + %13 = arith.muli %11, %12 : i32 + 
%14 = arith.addi %arg7, %13 : i32 + %15 = affine.load %4[%arg5 + %arg2 * 3] : memref<9xi32> + %16 = arith.muli %11, %15 : i32 + %17 = arith.addi %arg6, %16 : i32 + affine.yield %17, %14 : i32, i32 + } + affine.yield %10#0, %10#1 : i32, i32 + } + affine.store %9#1, %0[%arg0, %arg1] : memref<516x516xi32> + affine.store %9#0, %1[%arg0, %arg1] : memref<516x516xi32> + } + } + %5 = memref.get_global @img_ixx : memref<512x512xi32> + %6 = memref.get_global @img_iyy : memref<512x512xi32> + %7 = memref.get_global @img_ixy : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %9:3 = affine.for %arg2 = 0 to 5 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32) -> (i32, i32, i32) { + %10:3 = affine.for %arg6 = 0 to 5 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (i32, i32, i32) { + %11 = affine.load %0[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %12 = affine.load %1[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %13 = arith.muli %11, %11 : i32 + %14 = arith.addi %arg9, %13 : i32 + %15 = arith.muli %12, %12 : i32 + %16 = arith.addi %arg8, %15 : i32 + %17 = arith.muli %11, %12 : i32 + %18 = arith.addi %arg7, %17 : i32 + affine.yield %18, %16, %14 : i32, i32, i32 + } + affine.yield %10#0, %10#1, %10#2 : i32, i32, i32 + } + affine.store %9#2, %5[%arg0, %arg1] : memref<512x512xi32> + affine.store %9#1, %6[%arg0, %arg1] : memref<512x512xi32> + affine.store %9#0, %7[%arg0, %arg1] : memref<512x512xi32> + } + } + %8 = memref.get_global @score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %9 = affine.load %5[%arg0, %arg1] : memref<512x512xi32> + %10 = affine.load %6[%arg0, %arg1] : memref<512x512xi32> + %11 = affine.load %7[%arg0, %arg1] : memref<512x512xi32> + %12 = arith.muli %9, %10 : i32 + %13 = arith.muli %11, %11 : i32 + %14 = arith.subi %12, %13 : i32 + %15 = arith.addi %9, %10 : i32 + %16 = arith.muli %15, %c4_i32 : i32 + %17 = arith.muli %16, %15 : i32 + %18 = arith.subi %14, %17 : i32 + affine.store %18, %8[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} + +module @harris_score_local { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @coeffs_x : memref<9xi32> + %1 = memref.get_global @coeffs_y : memref<9xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + affine.for %arg2 = 0 to 3 { + affine.for %arg5 = 0 to 3 { + %gx = affine.load %alloca_3[%arg0, %arg1] : memref<516x516xi32> + %gy = affine.load %alloca_2[%arg0, %arg1] : memref<516x516xi32> + %5 = affine.load %alloca_4[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %6 = affine.load %0[%arg5 + %arg2 * 3] : memref<9xi32> + %7 = arith.muli %5, %6 : i32 + %8 = arith.addi %gx, %7 : i32 + %9 = affine.load %1[%arg5 + %arg2 * 3] : memref<9xi32> + %10 = arith.muli %5, %9 : i32 + %11 = arith.addi %gy, %10 : i32 + affine.store %8, %alloca_3[%arg0, %arg1] : 
memref<516x516xi32> + affine.store %11, %alloca_2[%arg0, %arg1] : memref<516x516xi32> + } + } + } + } + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %3:3 = affine.for %arg2 = 0 to 5 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32) -> (i32, i32, i32) { + %4:3 = affine.for %arg6 = 0 to 5 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (i32, i32, i32) { + %ixx = affine.load %alloca_1[%arg0, %arg1] : memref<512x512xi32> + %iyy = affine.load %alloca_0[%arg0, %arg1] : memref<512x512xi32> + %ixy = affine.load %alloca[%arg0, %arg1] : memref<512x512xi32> + %5 = affine.load %alloca_3[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %6 = affine.load %alloca_2[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %7 = arith.muli %5, %5 : i32 + %8 = arith.addi %arg9, %7 : i32 + %9 = arith.muli %6, %6 : i32 + %10 = arith.addi %arg8, %9 : i32 + %11 = arith.muli %5, %6 : i32 + %12 = arith.addi %arg7, %11 : i32 + affine.yield %12, %10, %8 : i32, i32, i32 + } + affine.yield %4#0, %4#1, %4#2 : i32, i32, i32 + } + affine.store %3#2, %alloca_1[%arg0, %arg1] : memref<512x512xi32> + affine.store %3#1, %alloca_0[%arg0, %arg1] : memref<512x512xi32> + affine.store %3#0, %alloca[%arg0, %arg1] : memref<512x512xi32> + } + } + %2 = memref.get_global @score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %3 = affine.load %alloca_1[%arg0, %arg1] : memref<512x512xi32> + %4 = affine.load %alloca_0[%arg0, %arg1] : memref<512x512xi32> + %5 = affine.load %alloca[%arg0, %arg1] : memref<512x512xi32> + %6 = arith.muli %3, %4 : i32 + %7 = arith.muli %5, %5 : i32 + %8 = arith.subi %6, %7 : i32 + %9 = arith.addi %3, %4 : i32 + %10 = arith.muli %9, %c4_i32 : i32 + %11 = arith.muli %10, %9 : i32 + %12 = arith.subi %8, %11 : i32 + affine.store %12, %2[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} + +module @harris_score_2d_kernel { + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + %3:2 = affine.for %arg2 = 0 to 3 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32) -> (i32, i32) { + %4:2 = affine.for %arg5 = 0 to 3 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (i32, i32) { + %5 = affine.load %alloca_4[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %6 = affine.load %0[%arg2, %arg5] : memref<3x3xi32> + %7 = arith.muli %5, %6 : i32 + %8 = arith.addi %arg7, %7 : i32 + %9 = affine.load %1[%arg2, %arg5] : memref<3x3xi32> + %10 = arith.muli %5, %9 : i32 + %11 = arith.addi %arg6, %10 : i32 + affine.yield %11, %8 : i32, i32 + } + affine.yield %4#0, %4#1 : i32, i32 + } + affine.store %3#1, %alloca_3[%arg0, %arg1] : memref<516x516xi32> + affine.store %3#0, %alloca_2[%arg0, %arg1] : memref<516x516xi32> 
+ } + } + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %3:3 = affine.for %arg2 = 0 to 5 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, %arg5 = %c0_i32) -> (i32, i32, i32) { + %4:3 = affine.for %arg6 = 0 to 5 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (i32, i32, i32) { + %5 = affine.load %alloca_3[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %6 = affine.load %alloca_2[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %7 = arith.muli %5, %5 : i32 + %8 = arith.addi %arg9, %7 : i32 + %9 = arith.muli %6, %6 : i32 + %10 = arith.addi %arg8, %9 : i32 + %11 = arith.muli %5, %6 : i32 + %12 = arith.addi %arg7, %11 : i32 + affine.yield %12, %10, %8 : i32, i32, i32 } + affine.yield %4#0, %4#1, %4#2 : i32, i32, i32 } + affine.store %3#2, %alloca_1[%arg0, %arg1] : memref<512x512xi32> + affine.store %3#1, %alloca_0[%arg0, %arg1] : memref<512x512xi32> + affine.store %3#0, %alloca[%arg0, %arg1] : memref<512x512xi32> + } + } + %2 = memref.get_global @score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %3 = affine.load %alloca_1[%arg0, %arg1] : memref<512x512xi32> + %4 = affine.load %alloca_0[%arg0, %arg1] : memref<512x512xi32> + %5 = affine.load %alloca[%arg0, %arg1] : memref<512x512xi32> + %6 = arith.muli %3, %4 : i32 + %7 = arith.muli %5, %5 : i32 + %8 = arith.subi %6, %7 : i32 + %9 = arith.addi %3, %4 : i32 + %10 = arith.muli %9, %c4_i32 : i32 + %11 = arith.muli %10, %9 : i32 + %12 = arith.subi %8, %11 : i32 + affine.store %12, %2[%arg0, %arg1] : memref<512x512xi32> + } + } + return %c0_i32 : i32 + } +} + +module @harris_score_with_gradient_extra_kernel { + memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + affine.for %arg0 = 0 to 516 { + affine.for %arg1 = 0 to 516 { + %4:2 = affine.for %arg2 = 0 to 3 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32) -> (i32, i32) { + %5:2 = affine.for %arg5 = 0 to 3 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (i32, i32) { + %6 = affine.load %alloca_4[%arg0 + %arg2, %arg1 + %arg5] : memref<518x518xi32> + %7 = affine.load %0[%arg2, %arg5] : memref<3x3xi32> + %8 = arith.muli %6, %7 : i32 + %9 = arith.addi %arg7, %8 : i32 + %10 = affine.load %1[%arg2, %arg5] : memref<3x3xi32> + %11 = arith.muli %6, %10 : i32 + %12 = arith.addi %arg6, %11 : i32 + affine.yield %12, %9 : i32, i32 + } + affine.yield %5#0, %5#1 : i32, i32 + } + affine.store %4#1, %alloca_3[%arg0, %arg1] : memref<516x516xi32> + affine.store %4#0, %alloca_2[%arg0, %arg1] : memref<516x516xi32> + } + } + %2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %4:3 = affine.for %arg2 = 0 to 5 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32, 
%arg5 = %c0_i32) -> (i32, i32, i32) { + %5:3 = affine.for %arg6 = 0 to 5 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (i32, i32, i32) { + %6 = affine.load %alloca_3[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %7 = affine.load %alloca_2[%arg0 + %arg2, %arg1 + %arg6] : memref<516x516xi32> + %8 = arith.muli %6, %6 : i32 + %9 = affine.load %2[%arg2, %arg6] : memref<5x5xi32> + %10 = arith.muli %8, %9 : i32 + %11 = arith.addi %arg9, %10 : i32 + %12 = arith.muli %7, %7 : i32 + %13 = arith.muli %12, %9 : i32 + %14 = arith.addi %arg8, %13 : i32 + %15 = arith.muli %6, %7 : i32 + %16 = arith.muli %15, %9 : i32 + %17 = arith.addi %arg7, %16 : i32 + affine.yield %17, %14, %11 : i32, i32, i32 + } + affine.yield %5#0, %5#1, %5#2 : i32, i32, i32 + } + affine.store %4#2, %alloca_1[%arg0, %arg1] : memref<512x512xi32> + affine.store %4#1, %alloca_0[%arg0, %arg1] : memref<512x512xi32> + affine.store %4#0, %alloca[%arg0, %arg1] : memref<512x512xi32> + } + } + %3 = memref.get_global @score : memref<512x512xi32> + affine.for %arg0 = 0 to 512 { + affine.for %arg1 = 0 to 512 { + %4 = affine.load %alloca_1[%arg0, %arg1] : memref<512x512xi32> + %5 = affine.load %alloca_0[%arg0, %arg1] : memref<512x512xi32> + %6 = affine.load %alloca[%arg0, %arg1] : memref<512x512xi32> + %7 = arith.muli %4, %5 : i32 + %8 = arith.muli %6, %6 : i32 + %9 = arith.subi %7, %8 : i32 + %10 = arith.addi %4, %5 : i32 + %11 = arith.muli %10, %c4_i32 : i32 + %12 = arith.muli %11, %10 : i32 + %13 = arith.subi %9, %12 : i32 + affine.store %13, %3[%arg0, %arg1] : memref<512x512xi32> } } return %c0_i32 : i32 diff --git a/test/polygeist-opt/raised_with_submap.mlir b/test/polygeist-opt/raised_with_submap.mlir new file mode 100644 index 000000000000..f126b738d0f1 --- /dev/null +++ b/test/polygeist-opt/raised_with_submap.mlir @@ -0,0 +1,1097 @@ +#map = affine_map<(d0) -> (d0)> +#map1 = affine_map<(d0) -> (d0 * 3)> +#map2 = affine_map<(d0)[s0] -> (s0)> +#map3 = affine_map<(d0) -> (0)> +#map4 = affine_map<(d0, d1) -> (d1)> +#map5 = affine_map<(d0, d1) -> (d0)> +#map6 = affine_map<(d0, d1) -> (d0, d1)> +#map7 = affine_map<(d0, d1) -> (d0 * 2 + d1)> +#map8 = affine_map<(d0, d1) -> (d0 + d1 * 2)> +#map9 = affine_map<(d0, d1, d2) -> (d2)> +#map10 = affine_map<(d0, d1, d2) -> (d1)> +#map11 = affine_map<(d0, d1, d2) -> (d0)> +#map12 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map13 = affine_map<(d0, d1, d2) -> (d1 * 4 + d2 + 3)> +#map14 = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + 2)> +#map15 = affine_map<(d0, d1, d2) -> (d0 + d2 * 2)> +#map16 = affine_map<(d0, d1, d2) -> (d2, d0)> +#map17 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map18 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map19 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)> +#map20 = affine_map<(d0, d1, d2, d3) -> (d1, d0)> +#map21 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +#map22 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map23 = affine_map<(d0, d1)[s0, s1] -> (d1 + s0, d0 + s1)> +#map24 = affine_map<(d0, d1) -> (d1, d0)> +#map25 = affine_map<(d0, d1)[s0, s1] -> (s0, s1)> +#map26 = affine_map<(d0)[s0, s1, s2] -> (s0 + s1, d0 + s2)> +#map27 = affine_map<(d0)[s0] -> (s0, d0)> +#map28 = affine_map<(d0)[s0, s1] -> (s0, s1)> +#map29 = affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 3)> +module { + module @constant_access { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 4.000000e+00 : f32 + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = 
arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = arith.mulf %in, %cst : f32 + linalg.yield %5 : f32 + } + return + } + } +// module @constant_mem_access { +// func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { +// %c13 = arith.constant 13 : index +// %c4 = arith.constant 4 : index +// %0 = arith.index_cast %arg1 : i32 to index +// %1 = arith.muli %0, %c4 : index +// %2 = arith.divui %1, %c4 : index +// %alloca = memref.alloca(%2) : memref +// %3 = "polygeist.submap"(%arg2, %c13) <{map = #map1}> : (memref, index) -> memref +// %4 = "polygeist.submap"(%arg2, %c4, %c13) <{map = #map2}> : (memref, index, index) -> memref +// %5 = "polygeist.submap"(%alloca, %c13) <{map = #map}> : (memref, index) -> memref +// linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { +// ^bb0(%in: f32, %in_0: f32, %out: f32): +// %6 = arith.mulf %in, %in_0 : f32 +// linalg.yield %6 : f32 +// } +// return +// } +// } + module @no_if { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + return + } + } + module @arith_mul { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = arith.mulf %in, %in : f32 + linalg.yield %5 : f32 + } + return + } + } + module @arith_add { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17) <{map = #map}> : (memref, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.addf %in, %in_0 : f32 + %7 = arith.mulf %6, %6 : f32 + linalg.yield %7 : f32 + } + return + } + } + module @cond_arith { + func.func 
@main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = scf.if %arg0 -> (f32) { + %6 = arith.mulf %in, %in : f32 + scf.yield %6 : f32 + } else { + scf.yield %in : f32 + } + linalg.yield %5 : f32 + } + return + } + } + module @reduction { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map3}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %out, %in : f32 + linalg.yield %5 : f32 + } + return + } + } + module @reduction_transformed { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %cst = arith.constant 0.000000e+00 : f32 + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %alloca_0 = memref.alloca() : memref<1xf32> + affine.store %cst, %alloca_0[0] : memref<1xf32> + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca_0, %c17) <{map = #map3}> : (memref<1xf32>, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %6 = arith.addf %out, %in : f32 + linalg.yield %6 : f32 + } + %5 = affine.load %alloca_0[0] : memref<1xf32> + affine.store %5, %alloca[0] : memref + return + } + } + module @reduction_transformed_simplified { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c17 = arith.constant 17 : index + %cst = arith.constant 0.000000e+00 : f32 + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.store %cst, %alloca[0] : memref + %3 = "polygeist.submap"(%arg2, %c17) <{map = #map}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c17) <{map = #map3}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %out, %in : f32 + linalg.yield %5 : f32 + } + return + } + } + module @cond_store_1 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.for %arg3 = 0 to 17 { + %3 = affine.load %arg2[%arg3] 
: memref + %4 = arith.mulf %3, %3 : f32 + scf.if %arg0 { + affine.store %4, %alloca[%arg3] : memref + } + } + return + } + } + module @cond_store_2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref) { + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + affine.for %arg3 = 0 to 17 { + %3 = affine.load %arg2[%arg3] : memref + scf.if %arg0 { + %4 = arith.mulf %3, %3 : f32 + affine.store %4, %alloca[%arg3] : memref + } else { + affine.store %3, %alloca[%arg3] : memref + } + } + return + } + } + module @for_within_for { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map4}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_within_for_2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map7}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_within_for_3 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map7}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map4}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %6 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4, %5 : memref, memref, memref) outs(%6 : memref) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %7 = arith.mulf %in, %in_0 : f32 + %8 = arith.mulf %7, %in_1 : f32 + linalg.yield %8 : f32 + } + return + } + } + 
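// Editor note on the raised form below: @for_within_for_4 differs from @for_within_for_2 only in the access map of its first input — #map8 ((d0, d1) -> (d0 + d1 * 2)) instead of #map7 ((d0, d1) -> (d0 * 2 + d1)) — i.e. the transposed strided load of the corresponding affine test earlier in this patch.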
module @for_within_for_4 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map8}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_no_loop_dependency { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c15 = arith.constant 15 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c15) <{map = #map3}> : (memref, index) -> memref + %4 = "polygeist.submap"(%alloca, %c15) <{map = #map3}> : (memref, index) -> memref + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + return + } + } + module @for_2_levels_no_loop_dependency { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c17 = arith.constant 17 : index + %c15 = arith.constant 15 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c15, %c17) <{map = #map4}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%alloca, %c15, %c17) <{map = #map4}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "reduction"]} ins(%3 : memref) outs(%4 : memref) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + return + } + } + module @for_3_levels_0 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c15 = arith.constant 15 : index + %c17 = arith.constant 17 : index + %c21 = arith.constant 21 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c21, %c17, %c15) <{map = #map9}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c21, %c17, %c15) <{map = #map10}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c21, %c17, %c15) <{map = #map11}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12], iterator_types = ["reduction", "reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_3_levels_1 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c21 
= arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21, %c21) <{map = #map11}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21, %c21) <{map = #map11}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12], iterator_types = ["reduction", "reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @for_3_levels_2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref, %arg4: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c21, %c17, %c21) <{map = #map9}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%arg4, %c21, %c17, %c21) <{map = #map11}> : (memref, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12, #map12], iterator_types = ["reduction", "parallel", "reduction"]} ins(%3, %4, %5 : memref, memref, memref) outs(%6 : memref) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %7 = arith.mulf %in, %in_0 : f32 + %8 = arith.mulf %7, %in_1 : f32 + linalg.yield %8 : f32 + } + return + } + } + module @for_3_levels_3 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c21, %c17, %c21) <{map = #map9}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map11}> : (memref, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12, #map12], iterator_types = ["reduction", "parallel", "reduction"]} ins(%3, %4, %5 : memref, memref, memref) outs(%6 : memref) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %7 = arith.mulf %in, %in_0 : f32 + %8 = arith.mulf %7, %in_1 : f32 + linalg.yield %8 : f32 + } + return + } + } + module @for_3_levels_4 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = 
"polygeist.submap"(%arg2, %c21, %c17, %c21) <{map = #map13}> : (memref, index, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map14}> : (memref, index, index, index) -> memref + %5 = "polygeist.submap"(%arg3, %c21, %c17, %c21) <{map = #map15}> : (memref, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca, %c21, %c17, %c21) <{map = #map10}> : (memref, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12, #map12], iterator_types = ["reduction", "parallel", "reduction"]} ins(%3, %4, %5 : memref, memref, memref) outs(%6 : memref) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %7 = arith.mulf %in, %in_0 : f32 + %8 = arith.mulf %7, %in_1 : f32 + linalg.yield %8 : f32 + } + return + } + } + module @for_within_for2 { + func.func @main(%arg0: i1, %arg1: i32, %arg2: memref, %arg3: memref) { + %c21 = arith.constant 21 : index + %c17 = arith.constant 17 : index + %c4 = arith.constant 4 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.muli %0, %c4 : index + %2 = arith.divui %1, %c4 : index + %alloca = memref.alloca(%2) : memref + %3 = "polygeist.submap"(%arg2, %c17, %c21) <{map = #map4}> : (memref, index, index) -> memref + %4 = "polygeist.submap"(%arg3, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + %5 = "polygeist.submap"(%alloca, %c17, %c21) <{map = #map5}> : (memref, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "parallel"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %6 = arith.mulf %in, %in_0 : f32 + linalg.yield %6 : f32 + } + return + } + } + module @matmul_1 { + memref.global @out : memref<32x8xi32> = uninitialized + memref.global @im2 : memref<8x8xi32> = uninitialized + memref.global @im1 : memref<32x8xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im1 : memref<32x8xi32> + %1 = memref.get_global @im2 : memref<8x8xi32> + %2 = memref.get_global @out : memref<32x8xi32> + %3 = "polygeist.submap"(%0, %c8, %c8, %c32) <{map = #map16}> : (memref<32x8xi32>, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c8, %c8, %c32) <{map = #map17}> : (memref<8x8xi32>, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c8, %c8, %c32) <{map = #map18}> : (memref<32x8xi32>, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @matmul_2 { + memref.global @out : memref<128x32xi32> = uninitialized + memref.global @im2 : memref<64x32xi32> = uninitialized + memref.global @im1 : memref<128x64xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im1 : memref<128x64xi32> + %1 = memref.get_global @im2 : memref<64x32xi32> + %2 = memref.get_global @out : memref<128x32xi32> + %3 = "polygeist.submap"(%0, %c64, %c32, %c128) <{map = #map16}> : (memref<128x64xi32>, index, index, 
index) -> memref + %4 = "polygeist.submap"(%1, %c64, %c32, %c128) <{map = #map17}> : (memref<64x32xi32>, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c64, %c32, %c128) <{map = #map18}> : (memref<128x32xi32>, index, index, index) -> memref + linalg.generic {indexing_maps = [#map12, #map12, #map12], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @conv_1 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = "polygeist.submap"(%0, %c4, %c4, %c64, %c512) <{map = #map19}> : (memref<515x67xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c4, %c4, %c64, %c512) <{map = #map20}> : (memref<4x4xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c4, %c4, %c64, %c512) <{map = #map21}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + // module @conv_1_reduction_test { + // memref.global @out : memref<512x64xi32> = uninitialized + // memref.global @filter : memref<4x4xi32> = uninitialized + // memref.global @im : memref<515x67xi32> = uninitialized + // func.func @main(%arg0: index, %arg1: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { + // %c4 = arith.constant 4 : index + // %c0_i32 = arith.constant 0 : i32 + // %0 = memref.get_global @im : memref<515x67xi32> + // %1 = memref.get_global @filter : memref<4x4xi32> + // %2 = memref.get_global @out : memref<512x64xi32> + // %3 = "polygeist.submap"(%0, %arg0, %arg1, %c4, %c4) <{map = #map23}> : (memref<515x67xi32>, index, index, index, index) -> memref + // %4 = "polygeist.submap"(%1, %c4, %c4) <{map = #map24}> : (memref<4x4xi32>, index, index) -> memref + // %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4, %c4) <{map = #map25}> : (memref<512x64xi32>, index, index, index, index) -> memref + // linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + // ^bb0(%in: i32, %in_0: i32, %out: i32): + // %6 = arith.muli %in, %in_0 : i32 + // %7 = arith.addi %out, %6 : i32 + // linalg.yield %7 : i32 + // } + // return %c0_i32 : i32 + // } + // } + module @conv_2 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : 
i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = "polygeist.submap"(%0, %c4, %c4, %c64, %c512) <{map = #map19}> : (memref<515x67xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c4, %c4, %c64, %c512) <{map = #map20}> : (memref<4x4xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c4, %c4, %c64, %c512) <{map = #map21}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } + } + module @box_filter { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @out : memref<512x64xi32> + %2 = "polygeist.submap"(%0, %c4, %c4, %c64, %c512) <{map = #map19}> : (memref<515x67xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%1, %c4, %c4, %c64, %c512) <{map = #map21}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2 : memref) outs(%3 : memref) { + ^bb0(%in: i32, %out: i32): + %4 = arith.addi %out, %in : i32 + linalg.yield %4 : i32 + } + return %c0_i32 : i32 + } + } +// module @conv_loop1_test { +// memref.global @out : memref<512x64xi32> = uninitialized +// memref.global @filter : memref<4x4xi32> = uninitialized +// memref.global @im : memref<515x67xi32> = uninitialized +// func.func @main(%arg0: index, %arg1: index, %arg2: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c4 = arith.constant 4 : index +// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im : memref<515x67xi32> +// %1 = memref.get_global @filter : memref<4x4xi32> +// %2 = memref.get_global @out : memref<512x64xi32> +// %3 = "polygeist.submap"(%0, %arg0, %arg2, %arg1, %c4) <{map = #map26}> : (memref<515x67xi32>, index, index, index, index) -> memref +// %4 = "polygeist.submap"(%1, %arg2, %c4) <{map = #map27}> : (memref<4x4xi32>, index, index) -> memref +// %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4) <{map = #map28}> : (memref<512x64xi32>, index, index, index) -> memref +// linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { +// ^bb0(%in: i32, %in_0: i32, %out: i32): +// %6 = arith.muli %in, %in_0 : i32 +// %7 = arith.addi %out, %6 : i32 +// linalg.yield %7 : i32 +// } +// return %c0_i32 : i32 +// } +// } +// module @submap_test { +// memref.global @out : memref<511x64xi32> = uninitialized +// memref.global @filter : memref<5x4xi32> = uninitialized +// memref.global @im : memref<515x67xi32> = uninitialized +// func.func @main(%arg0: index, %arg1: index) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// %c5 = arith.constant 5 : index +// %c4 = arith.constant 4 : index 
+// %c0_i32 = arith.constant 0 : i32 +// %0 = memref.get_global @im : memref<515x67xi32> +// %1 = memref.get_global @filter : memref<5x4xi32> +// %2 = memref.get_global @out : memref<511x64xi32> +// %3 = "polygeist.submap"(%0, %arg0, %arg1, %c4, %c5) <{map = #map23}> : (memref<515x67xi32>, index, index, index, index) -> memref +// %4 = "polygeist.submap"(%1, %c4, %c5) <{map = #map24}> : (memref<5x4xi32>, index, index) -> memref +// %5 = "polygeist.submap"(%2, %arg0, %arg1, %c4, %c5) <{map = #map25}> : (memref<511x64xi32>, index, index, index, index) -> memref +// linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["reduction", "reduction"]} ins(%3, %4 : memref, memref) outs(%5 : memref) { +// ^bb0(%in: i32, %in_0: i32, %out: i32): +// %6 = arith.muli %in, %in_0 : i32 +// %7 = arith.addi %out, %6 : i32 +// linalg.yield %7 : i32 +// } +// return %c0_i32 : i32 +// } +// } + module @harris_score_1 { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + %5 = "polygeist.submap"(%2, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%3, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %7 = "polygeist.submap"(%4, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%5, %6, %7 : memref, memref, memref) outs(%8, %9 : memref, memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32, %out_2: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.addi %out, %23 : i32 + %25 = arith.muli %in, %in_1 : i32 + %26 = arith.addi %out_2, %25 : i32 + linalg.yield %24, %26 : i32, i32 + } + %10 = memref.get_global @img_ixx : memref<512x512xi32> + %11 = memref.get_global @img_iyy : memref<512x512xi32> + %12 = memref.get_global @img_ixy : memref<512x512xi32> + %13 = "polygeist.submap"(%0, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %14 = "polygeist.submap"(%1, %c5, %c5, %c512, 
%c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %15 = "polygeist.submap"(%10, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %16 = "polygeist.submap"(%11, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %17 = "polygeist.submap"(%12, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%13, %14 : memref, memref) outs(%15, %16, %17 : memref, memref, memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32, %out_1: i32, %out_2: i32): + %23 = arith.muli %in, %in : i32 + %24 = arith.addi %out, %23 : i32 + %25 = arith.muli %in_0, %in_0 : i32 + %26 = arith.addi %out_1, %25 : i32 + %27 = arith.muli %in, %in_0 : i32 + %28 = arith.addi %out_2, %27 : i32 + linalg.yield %24, %26, %28 : i32, i32, i32 + } + %18 = memref.get_global @score : memref<512x512xi32> + %19 = "polygeist.submap"(%10, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %20 = "polygeist.submap"(%11, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %21 = "polygeist.submap"(%12, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %22 = "polygeist.submap"(%18, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%19, %20, %21 : memref, memref, memref) outs(%22 : memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.muli %in_1, %in_1 : i32 + %25 = arith.subi %23, %24 : i32 + %26 = arith.addi %in, %in_0 : i32 + %27 = arith.muli %26, %c4_i32 : i32 + %28 = arith.muli %27, %26 : i32 + %29 = arith.subi %25, %28 : i32 + linalg.yield %29 : i32 + } + return %c0_i32 : i32 + } + } + module @harris_score_2 { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + %5 = "polygeist.submap"(%2, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%3, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %7 = "polygeist.submap"(%4, %c3, 
%c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%5, %6, %7 : memref, memref, memref) outs(%8, %9 : memref, memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32, %out_2: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.addi %out_2, %23 : i32 + %25 = arith.muli %in, %in_1 : i32 + %26 = arith.addi %out, %25 : i32 + linalg.yield %26, %24 : i32, i32 + } + %10 = memref.get_global @img_ixx : memref<512x512xi32> + %11 = memref.get_global @img_iyy : memref<512x512xi32> + %12 = memref.get_global @img_ixy : memref<512x512xi32> + %13 = "polygeist.submap"(%0, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %14 = "polygeist.submap"(%1, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %15 = "polygeist.submap"(%12, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %16 = "polygeist.submap"(%11, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %17 = "polygeist.submap"(%10, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%13, %14 : memref, memref) outs(%15, %16, %17 : memref, memref, memref) { + ^bb0(%in: i32, %in_0: i32, %out: i32, %out_1: i32, %out_2: i32): + %23 = arith.muli %in, %in : i32 + %24 = arith.addi %out_2, %23 : i32 + %25 = arith.muli %in_0, %in_0 : i32 + %26 = arith.addi %out_1, %25 : i32 + %27 = arith.muli %in, %in_0 : i32 + %28 = arith.addi %out, %27 : i32 + linalg.yield %28, %26, %24 : i32, i32, i32 + } + %18 = memref.get_global @score : memref<512x512xi32> + %19 = "polygeist.submap"(%10, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %20 = "polygeist.submap"(%11, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %21 = "polygeist.submap"(%12, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %22 = "polygeist.submap"(%18, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%19, %20, %21 : memref, memref, memref) outs(%22 : memref) { + ^bb0(%in: i32, %in_0: i32, %in_1: i32, %out: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.muli %in_1, %in_1 : i32 + %25 = arith.subi %23, %24 : i32 + %26 = arith.addi %in, %in_0 : i32 + %27 = arith.muli %26, %c4_i32 : i32 + %28 = arith.muli %27, %26 : i32 + %29 = arith.subi %25, %28 : i32 + linalg.yield %29 : i32 + } + return %c0_i32 : i32 + } + } + module @harris_score_local { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 
attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @coeffs_x : memref<9xi32> + %1 = memref.get_global @coeffs_y : memref<9xi32> + %2 = "polygeist.submap"(%alloca_4, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca_3, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca_2, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2, %3, %4 : memref, memref, memref) outs(%5, %6 : memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.addi %out, %17 : i32 + %19 = arith.muli %in, %in_6 : i32 + %20 = arith.addi %out_7, %19 : i32 + linalg.yield %18, %20 : i32, i32 + } + %7 = "polygeist.submap"(%alloca_3, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%alloca_2, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%alloca, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %10 = "polygeist.submap"(%alloca_0, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %11 = "polygeist.submap"(%alloca_1, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%7, %8 : memref, memref) outs(%9, %10, %11 : memref, memref, memref) { + ^bb0(%in: i32, %in_5: i32, %out: i32, %out_6: i32, %out_7: i32): + %17 = arith.muli %in, %in : i32 + %18 = arith.addi %out_7, %17 : i32 + %19 = arith.muli %in_5, %in_5 : i32 + %20 = arith.addi %out_6, %19 : i32 + %21 = arith.muli %in, %in_5 : i32 + %22 = arith.addi %out, %21 : i32 + linalg.yield %22, %20, %18 : i32, i32, i32 + } + %12 = memref.get_global @score : memref<512x512xi32> + %13 = "polygeist.submap"(%alloca_1, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %14 = "polygeist.submap"(%alloca_0, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %15 = "polygeist.submap"(%alloca, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %16 = "polygeist.submap"(%12, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, 
index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%13, %14, %15 : memref, memref, memref) outs(%16 : memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.muli %in_6, %in_6 : i32 + %19 = arith.subi %17, %18 : i32 + %20 = arith.addi %in, %in_5 : i32 + %21 = arith.muli %20, %c4_i32 : i32 + %22 = arith.muli %21, %20 : i32 + %23 = arith.subi %19, %22 : i32 + linalg.yield %23 : i32 + } + return %c0_i32 : i32 + } + } +} + +module @harris_score_2d_kernel { + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + %2 = "polygeist.submap"(%alloca_4, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca_2, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca_3, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2, %3, %4 : memref, memref, memref) outs(%5, %6 : memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.addi %out_7, %17 : i32 + %19 = arith.muli %in, %in_6 : i32 + %20 = arith.addi %out, %19 : i32 + linalg.yield %20, %18 : i32, i32 + } + %7 = "polygeist.submap"(%alloca_3, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%alloca_2, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%alloca, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %10 = "polygeist.submap"(%alloca_0, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %11 = "polygeist.submap"(%alloca_1, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", 
"reduction"]} ins(%7, %8 : memref, memref) outs(%9, %10, %11 : memref, memref, memref) { + ^bb0(%in: i32, %in_5: i32, %out: i32, %out_6: i32, %out_7: i32): + %17 = arith.muli %in, %in : i32 + %18 = arith.addi %out_7, %17 : i32 + %19 = arith.muli %in_5, %in_5 : i32 + %20 = arith.addi %out_6, %19 : i32 + %21 = arith.muli %in, %in_5 : i32 + %22 = arith.addi %out, %21 : i32 + linalg.yield %22, %20, %18 : i32, i32, i32 + } + %12 = memref.get_global @score : memref<512x512xi32> + %13 = "polygeist.submap"(%alloca_1, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %14 = "polygeist.submap"(%alloca_0, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %15 = "polygeist.submap"(%alloca, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %16 = "polygeist.submap"(%12, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%13, %14, %15 : memref, memref, memref) outs(%16 : memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.muli %in_6, %in_6 : i32 + %19 = arith.subi %17, %18 : i32 + %20 = arith.addi %in, %in_5 : i32 + %21 = arith.muli %20, %c4_i32 : i32 + %22 = arith.muli %21, %20 : i32 + %23 = arith.subi %19, %22 : i32 + linalg.yield %23 : i32 + } + return %c0_i32 : i32 + } +} + +module @harris_score_gradient_1d_kernel { + memref.global @coeffs_y : memref<9xi32> = dense<[-3, -10, -3, 0, 0, 0, 3, 10, 3]> + memref.global @coeffs_x : memref<9xi32> = dense<[-3, 0, 3, -10, 0, 10, -3, 0, 3]> + memref.global @score : memref<512x512xi32> = uninitialized + memref.global @img_ixy : memref<512x512xi32> = uninitialized + memref.global @img_iyy : memref<512x512xi32> = uninitialized + memref.global @img_ixx : memref<512x512xi32> = uninitialized + memref.global @img_in : memref<518x518xi32> = uninitialized + memref.global @img_gy : memref<516x516xi32> = uninitialized + memref.global @img_gx : memref<516x516xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @img_gx : memref<516x516xi32> + %1 = memref.get_global @img_gy : memref<516x516xi32> + %2 = memref.get_global @img_in : memref<518x518xi32> + %3 = memref.get_global @coeffs_x : memref<9xi32> + %4 = memref.get_global @coeffs_y : memref<9xi32> + %5 = "polygeist.submap"(%2, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%3, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %7 = "polygeist.submap"(%4, %c3, %c3, %c516, %c516) <{map = #map29}> : (memref<9xi32>, index, index, index, index) -> memref + %8 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%5, %6, %7 : memref, memref, memref) outs(%8, %9 : memref, memref) { + ^bb0(%in: 
i32, %in_0: i32, %in_1: i32, %out: i32, %out_2: i32): + %23 = arith.muli %in, %in_0 : i32 + %24 = arith.addi %out, %23 : i32 + %25 = arith.muli %in, %in_1 : i32 + %26 = arith.addi %out_2, %25 : i32 + linalg.yield %24, %26 : i32, i32 + } + return %c0_i32 : i32 + } +} + +module @harris_score_gradient_2d_kernel { + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + %2 = "polygeist.submap"(%alloca_4, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca_2, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca_3, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2, %3, %4 : memref, memref, memref) outs(%5, %6 : memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %17 = arith.muli %in, %in_5 : i32 + %18 = arith.addi %out_7, %17 : i32 + %19 = arith.muli %in, %in_6 : i32 + %20 = arith.addi %out, %19 : i32 + linalg.yield %20, %18 : i32, i32 + } + return %c0_i32 : i32 + } +} + +module @harris_score_with_gradient_extra_kernel { + memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + memref.global @score : memref<512x512xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c516 = arith.constant 516 : index + %c3 = arith.constant 3 : index + %c512 = arith.constant 512 : index + %c5 = arith.constant 5 : index + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + %alloca_4 = memref.alloca() : memref<518x518xi32> + %0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + %1 = memref.get_global @_ZL8coeffs_y : 
memref<3x3xi32> + %2 = "polygeist.submap"(%alloca_4, %c3, %c3, %c516, %c516) <{map = #map19}> : (memref<518x518xi32>, index, index, index, index) -> memref + %3 = "polygeist.submap"(%0, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c3, %c3, %c516, %c516) <{map = #map20}> : (memref<3x3xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%alloca_2, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + %6 = "polygeist.submap"(%alloca_3, %c3, %c3, %c516, %c516) <{map = #map21}> : (memref<516x516xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%2, %3, %4 : memref, memref, memref) outs(%5, %6 : memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %19 = arith.muli %in, %in_5 : i32 + %20 = arith.addi %out_7, %19 : i32 + %21 = arith.muli %in, %in_6 : i32 + %22 = arith.addi %out, %21 : i32 + linalg.yield %22, %20 : i32, i32 + } + %7 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + %8 = "polygeist.submap"(%alloca_3, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %9 = "polygeist.submap"(%alloca_2, %c5, %c5, %c512, %c512) <{map = #map19}> : (memref<516x516xi32>, index, index, index, index) -> memref + %10 = "polygeist.submap"(%7, %c5, %c5, %c512, %c512) <{map = #map20}> : (memref<5x5xi32>, index, index, index, index) -> memref + %11 = "polygeist.submap"(%alloca, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %12 = "polygeist.submap"(%alloca_0, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + %13 = "polygeist.submap"(%alloca_1, %c5, %c5, %c512, %c512) <{map = #map21}> : (memref<512x512xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%8, %9, %10 : memref, memref, memref) outs(%11, %12, %13 : memref, memref, memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): + %19 = arith.muli %in, %in : i32 + %20 = arith.muli %19, %in_6 : i32 + %21 = arith.addi %out_8, %20 : i32 + %22 = arith.muli %in_5, %in_5 : i32 + %23 = arith.muli %22, %in_6 : i32 + %24 = arith.addi %out_7, %23 : i32 + %25 = arith.muli %in, %in_5 : i32 + %26 = arith.muli %25, %in_6 : i32 + %27 = arith.addi %out, %26 : i32 + linalg.yield %27, %24, %21 : i32, i32, i32 + } + %14 = memref.get_global @score : memref<512x512xi32> + %15 = "polygeist.submap"(%alloca_1, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %16 = "polygeist.submap"(%alloca_0, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %17 = "polygeist.submap"(%alloca, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + %18 = "polygeist.submap"(%14, %c512, %c512) <{map = #map24}> : (memref<512x512xi32>, index, index) -> memref + linalg.generic {indexing_maps = [#map6, #map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%15, %16, %17 : memref, memref, memref) outs(%18 : memref) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %19 = arith.muli %in, %in_5 : i32 + %20 = arith.muli %in_6, %in_6 : 
i32 + %21 = arith.subi %19, %20 : i32 + %22 = arith.addi %in, %in_5 : i32 + %23 = arith.muli %22, %c4_i32 : i32 + %24 = arith.muli %23, %22 : i32 + %25 = arith.subi %21, %24 : i32 + linalg.yield %25 : i32 + } + return %c0_i32 : i32 + } +} diff --git a/test/polygeist-opt/submapcanonicalize.mlir b/test/polygeist-opt/submapcanonicalize.mlir new file mode 100644 index 000000000000..21f3e72fb5a1 --- /dev/null +++ b/test/polygeist-opt/submapcanonicalize.mlir @@ -0,0 +1,71 @@ +// RUN: polygeist-opt -canonicalize %s | FileCheck %s +#map = affine_map<(d0)[s0, s1] -> (d0 * s0, d0 * s1)> +module @submap_to_load__store{ + func.func private @use(i32) + func.func @f(%arg0: memref, %arg1 : index, %arg2 : index, %arg3 : index) { + + %submap = "polygeist.submap"(%arg0, %arg1, %arg2) <{map = #map}> : (memref, index, index) -> memref + + affine.for %arg4 = 0 to 10 { + %l = affine.load %submap[5 + %arg4 + symbol(%arg3)] : memref + func.call @use(%l) : (i32) -> () + affine.yield + } + return + } + + func.func @g(%arg0: memref, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : i32) { + %submap = "polygeist.submap"(%arg0, %arg1, %arg2) <{map = #map}> : (memref, index, index) -> memref + affine.for %arg5 = 0 to 10 { + affine.store %arg4, %submap[5 + %arg5 + symbol(%arg3)] : memref + affine.yield + } + return + } +} + + +// CHECK: func.func @f(%arg0: memref, %arg1: index, %arg2: index, %arg3: index) { +// CHECK-NEXT: affine.for %arg4 = 0 to 10 { +// CHECK-NEXT: %0 = affine.load %arg0[(%arg4 + symbol(%arg3) + 5) * symbol(%arg1), (%arg4 + symbol(%arg3) + 5) * symbol(%arg2)] : memref +// CHECK-NEXT: func.call @use(%0) : (i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + +// CHECK: func.func @g(%arg0: memref, %arg1: index, %arg2: index, %arg3: index, %arg4: i32) { +// CHECK-NEXT: affine.for %arg5 = 0 to 10 { +// CHECK-NEXT: affine.store %arg4, %arg0[(%arg5 + symbol(%arg3) + 5) * symbol(%arg1), (%arg5 + symbol(%arg3) + 5) * symbol(%arg2)] : memref +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + +#map19 = affine_map<(d0, d1, d2, d3) -> (d1 + d3, d0 + d2)> +#map20 = affine_map<(d0, d1, d2, d3) -> (d1, d0)> +#map21 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +#map22 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +module @conv_1 { + memref.global @out : memref<512x64xi32> = uninitialized + memref.global @filter : memref<4x4xi32> = uninitialized + memref.global @im : memref<515x67xi32> = uninitialized + func.func @main() -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c512 = arith.constant 512 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %0 = memref.get_global @im : memref<515x67xi32> + %1 = memref.get_global @filter : memref<4x4xi32> + %2 = memref.get_global @out : memref<512x64xi32> + %3 = "polygeist.submap"(%0, %c4, %c4, %c64, %c512) <{map = #map19}> : (memref<515x67xi32>, index, index, index, index) -> memref<4x4x64x512xi32> + %ssmap = "polygeist.submap"(%3, %c4, %c4, %c64, %c512) <{map = #map22}> : (memref<4x4x64x512xi32>, index, index, index, index) -> memref + %4 = "polygeist.submap"(%1, %c4, %c4, %c64, %c512) <{map = #map20}> : (memref<4x4xi32>, index, index, index, index) -> memref + %5 = "polygeist.submap"(%2, %c4, %c4, %c64, %c512) <{map = #map21}> : (memref<512x64xi32>, index, index, index, index) -> memref + linalg.generic {indexing_maps = [#map22, #map22, #map22], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%ssmap, %4 : memref, memref) outs(%5 : memref) { + ^bb0(%in: 
i32, %in_0: i32, %out: i32): + %6 = arith.muli %in, %in_0 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7 : i32 + } + return %c0_i32 : i32 + } +} \ No newline at end of file diff --git a/tools/polygeist-opt/CMakeLists.txt b/tools/polygeist-opt/CMakeLists.txt index ccfebd421d81..7a61d5b3b7af 100644 --- a/tools/polygeist-opt/CMakeLists.txt +++ b/tools/polygeist-opt/CMakeLists.txt @@ -5,6 +5,7 @@ set(LIBS ${conversion_libs} MLIROptLib MLIRPolygeist + MLIRPolygeistKernel MLIRPolygeistTransforms MLIRFuncAllExtensions ) diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 95fe1b1fc4a4..d653d835ab45 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -15,12 +15,14 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" @@ -32,6 +34,8 @@ #include "polygeist/Dialect.h" #include "polygeist/Passes/Passes.h" +#include "polygeist/Kernel/KernelDialect.h" +#include "polygeist/Kernel/KernelOps.h" using namespace mlir; @@ -59,7 +63,10 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); + registry.insert(); registry.insert(); + registry.insert(); registry.insert(); mlir::registerpolygeistPasses(); @@ -75,6 +82,7 @@ int main(int argc, char **argv) { mlir::registerLoopInvariantCodeMotionPass(); mlir::registerConvertSCFToOpenMPPass(); mlir::affine::registerAffinePasses(); + mlir::registerLinalgPasses(); registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { LLVM::LLVMFunctionType::attachInterface(*ctx);
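// Note: a minimal sketch of the dialect registrations that the newly added
// headers suggest (Bufferization, Linalg, and the Polygeist Kernel dialect).
// The class names and the `mlir::kernel` namespace are assumptions inferred
// from the #include lines added above, not taken verbatim from this patch;
// verify against polygeist/Kernel/KernelDialect.h before relying on them.
registry.insert<mlir::bufferization::BufferizationDialect>();
registry.insert<mlir::linalg::LinalgDialect>();
registry.insert<mlir::kernel::KernelDialect>();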