diff --git a/test/from_linalg_invalid.mlir b/test/from_linalg_invalid.mlir index 0c18b29b..bfb010b2 100644 --- a/test/from_linalg_invalid.mlir +++ b/test/from_linalg_invalid.mlir @@ -11,7 +11,7 @@ } func @reductions(%arg0: memref<2x3x4x5x6xf32>, %arg1: memref<2x4x6xf32>) { - // expected-error @+1 {{Linalg op is not compatible with Sair}} + // expected-error @+1 {{unexpected output tensor expression in indexing map #0 a.k.a 'd3' is function of reduction iterator 'd3'}} linalg.generic #reductions_trait ins(%arg0 : memref<2x3x4x5x6xf32>) outs(%arg1 : memref<2x4x6xf32>) { diff --git a/transforms/sair_from_linalg.cc b/transforms/sair_from_linalg.cc index be99e55c..a14eafad 100644 --- a/transforms/sair_from_linalg.cc +++ b/transforms/sair_from_linalg.cc @@ -431,7 +431,7 @@ void MoveBodyBlock(mlir::AffineMap linalg_to_sair_loops, // as "types" and a hyper-rectangular domain with the given number of dimensons. // Uses "rewriter" to construct the types. void CreateResultTypes(mlir::Builder &rewriter, int num_dimensions, - mlir::TypeRange types, + const SmallVectorImpl &types, llvm::SmallVectorImpl &result_types) { mlir::MLIRContext *context = rewriter.getContext(); auto result_domain_shape = @@ -562,7 +562,7 @@ mlir::LogicalResult RewriteLinalgToSair(mlir::linalg::LinalgOp op, // Linalg does not seem to restrict the output indexing to parallel dimensions // only, but Sair does. Abort the conversion in case of incompatibility. int num_parallel_loops = op.getNumParallelLoops(); - int num_operands = op.getNumInputsAndOutputBuffers(); + int num_operands = op.getNumShapedOperands(); for (int i = op.getNumInputs(); i < num_operands; ++i) { auto mapping = operand_mappings[i].cast(); if (mlir::failed(VerifyReductionMapping(mapping, num_parallel_loops))) { @@ -589,20 +589,20 @@ mlir::LogicalResult RewriteLinalgToSair(mlir::linalg::LinalgOp op, // Convert input and input/output MemRefs used by Linalg to Sair values. llvm::SmallVector map_operands; llvm::SmallVector, 4> result_ranges; - EmitMemRefToValue(op.getInputsAndOutputBuffers(), op.getNumOutputs(), loc, + EmitMemRefToValue(op.getShapedOperands(), op.getNumOutputs(), loc, sair_program, rewriter, map_operands, result_ranges); // Prepare parameters of the Sair map operation. int num_loops = op.getNumLoops(); llvm::SmallVector loop_bounds; - CollectLoopBounds(num_loops, subscripts_to_loops, - op.getInputsAndOutputBuffers(), loop_bounds); + CollectLoopBounds(num_loops, subscripts_to_loops, op.getShapedOperands(), + loop_bounds); llvm::SmallVector domain_ranges = CreateSairDomain(loc, loop_bounds, sair_program, rewriter); llvm::SmallVector result_types; - CreateResultTypes(rewriter, num_parallel_loops, - op.getOutputBuffers().getTypes(), result_types); + CreateResultTypes(rewriter, num_parallel_loops, op.getOutputBufferTypes(), + result_types); // Construct the main map or map_reduce operation. mlir::Operation *map_op;