Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/from_linalg_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
}

func @reductions(%arg0: memref<2x3x4x5x6xf32>, %arg1: memref<2x4x6xf32>) {
// expected-error @+1 {{Linalg op is not compatible with Sair}}
// expected-error @+1 {{unexpected output tensor expression in indexing map #0 a.k.a 'd3' is function of reduction iterator 'd3'}}
linalg.generic #reductions_trait
ins(%arg0 : memref<2x3x4x5x6xf32>)
outs(%arg1 : memref<2x4x6xf32>) {
Expand Down
14 changes: 7 additions & 7 deletions transforms/sair_from_linalg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ void MoveBodyBlock(mlir::AffineMap linalg_to_sair_loops,
// as "types" and a hyper-rectangular domain with the given number of dimensons.
// Uses "rewriter" to construct the types.
void CreateResultTypes(mlir::Builder &rewriter, int num_dimensions,
mlir::TypeRange types,
const SmallVectorImpl<MemRefType> &types,
llvm::SmallVectorImpl<mlir::Type> &result_types) {
mlir::MLIRContext *context = rewriter.getContext();
auto result_domain_shape =
Expand Down Expand Up @@ -562,7 +562,7 @@ mlir::LogicalResult RewriteLinalgToSair(mlir::linalg::LinalgOp op,
// Linalg does not seem to restrict the output indexing to parallel dimensions
// only, but Sair does. Abort the conversion in case of incompatibility.
int num_parallel_loops = op.getNumParallelLoops();
int num_operands = op.getNumInputsAndOutputBuffers();
int num_operands = op.getNumShapedOperands();
for (int i = op.getNumInputs(); i < num_operands; ++i) {
auto mapping = operand_mappings[i].cast<MappingAttr>();
if (mlir::failed(VerifyReductionMapping(mapping, num_parallel_loops))) {
Expand All @@ -589,20 +589,20 @@ mlir::LogicalResult RewriteLinalgToSair(mlir::linalg::LinalgOp op,
// Convert input and input/output MemRefs used by Linalg to Sair values.
llvm::SmallVector<mlir::Value, 4> map_operands;
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> result_ranges;
EmitMemRefToValue(op.getInputsAndOutputBuffers(), op.getNumOutputs(), loc,
EmitMemRefToValue(op.getShapedOperands(), op.getNumOutputs(), loc,
sair_program, rewriter, map_operands, result_ranges);

// Prepare parameters of the Sair map operation.
int num_loops = op.getNumLoops();
llvm::SmallVector<LoopBound, 8> loop_bounds;
CollectLoopBounds(num_loops, subscripts_to_loops,
op.getInputsAndOutputBuffers(), loop_bounds);
CollectLoopBounds(num_loops, subscripts_to_loops, op.getShapedOperands(),
loop_bounds);
llvm::SmallVector<mlir::Value, 4> domain_ranges =
CreateSairDomain(loc, loop_bounds, sair_program, rewriter);

llvm::SmallVector<mlir::Type, 4> result_types;
CreateResultTypes(rewriter, num_parallel_loops,
op.getOutputBuffers().getTypes(), result_types);
CreateResultTypes(rewriter, num_parallel_loops, op.getOutputBufferTypes(),
result_types);

// Construct the main map or map_reduce operation.
mlir::Operation *map_op;
Expand Down