Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 53 additions & 12 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,26 @@ namespace mlir {

using namespace mlir;

/// Returns true if `t` is scalar-like: an integer, float, index, or complex
/// type (as opposed to a shaped/container type such as a tensor).
static inline bool isScalarLike(Type t) {
  // Integer/float/index are covered by the Type helper predicate; complex is
  // checked separately since it is not part of that predicate.
  return t.isIntOrIndexOrFloat() || isa<ComplexType>(t);
}

/// Returns true if `op` can be converted to a `linalg.generic`:
/// - it has elementwise-mappable traits,
/// - at least one operand is a ranked tensor (which fixes the loop rank), and
/// - every operand is either a ranked tensor or scalar-like (scalars are
///   broadcast into the loop nest via an empty indexing map).
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return false;

  auto types = op->getOperandTypes();

  // We need at least one ranked tensor to determine the iteration rank of the
  // generated linalg.generic.
  if (!llvm::any_of(types, llvm::IsaPred<RankedTensorType>))
    return false;

  // Every operand must be a ranked tensor or scalar-like; anything else
  // (e.g. unranked tensors or vectors) is rejected.
  return llvm::all_of(types, [](Type t) {
    return isa<RankedTensorType>(t) || isScalarLike(t);
  });
}

/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
Expand Down Expand Up @@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
return rewriter.notifyMatchFailure(
op, "requires elementwise op on ranked tensors");

auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
SmallVector<AffineMap, 3> indexingMaps(
op->getNumResults() + op->getNumOperands(),
rewriter.getMultiDimIdentityMap(rank));
SmallVector<utils::IteratorType, 6> iteratorTypes(
auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
auto rank = resTy.getRank();

// Maps: identity for tensors (rank > 0), scalar map for scalars.
AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
/*results=*/{}, rewriter.getContext());
AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);

// Match phase.
SmallVector<bool> isScalarOperand;
isScalarOperand.reserve(op->getNumOperands());
for (Type ty : op->getOperandTypes()) {
if (isScalarLike(ty))
isScalarOperand.push_back(true);
else if (auto rt = dyn_cast<RankedTensorType>(ty))
isScalarOperand.push_back(false);
else
return rewriter.notifyMatchFailure(
op,
"unsupported operand type (expected scalar-like or ranked tensor)");
}

// Create indexing maps.
SmallVector<AffineMap> indexingMaps;
indexingMaps.reserve(op->getNumOperands() + op->getNumResults());

for (bool isScalar : isScalarOperand)
indexingMaps.push_back(isScalar ? scalarMap : idMap);

indexingMaps.append(op->getNumResults(), idMap);

SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
SmallVector<Value> outputs =
getOrCreateOperandsMatchingResultTypes(rewriter, op);
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, /*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/op->getOperands(),
Expand All @@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
/*iteratorTypes=*/iteratorTypes,
/*bodyBuilder=*/
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
auto resultTypes = llvm::to_vector<6>(
SmallVector<Type> resultEltTys = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
return cast<TensorType>(type).getElementType();
}));
auto *scalarOp =
Operation *scalarOp =
builder.create(loc, op->getName().getIdentifier(),
regionArgs.take_front(op->getNumOperands()),
resultTypes, op->getAttrs());
resultEltTys, op->getAttrs());
linalg::YieldOp::create(builder, loc, scalarOp->getResults());
});
return success();
Expand Down
66 changes: 66 additions & 0 deletions mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,69 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>)
return %0 : tensor<4x?x?x8x2x?xi1>
}

// -----

// Check a mix of scalar and tensor input.
// The scalar operand gets the empty affine map (MAP1) so it is broadcast to
// every iteration; the tensor operand and the result use the identity map
// (MAP2) over the two parallel loops.
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @scalar_plus_tensor
func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[GEN:.*]] = linalg.generic
  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
  // CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
  // CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
  // CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
  // CHECK: "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
  // CHECK: linalg.yield {{.*}} : f32
  // CHECK: } -> tensor<?x?xf32>
  %0 = "test.elementwise_mappable"(%arg0, %arg1)
      : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----
// This test exercises the case where an elementwise op has two scalar-like
// operands and one ranked tensor operand. In this example, we chain two
// `test.elementwise_mappable` calls:
//   %0 = f(%s1, %t)
//   %1 = f(%s2, %0)
// Both generics should reuse the same pair of maps: the empty map (SC2) for
// the scalar input and the identity map (ID2) for the tensor input/output.
// CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()>
// CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @scalar_tensor_scalar
func.func @scalar_tensor_scalar(%s1: f32, %t: tensor<?x?xf32>, %s2: f32) -> tensor<?x?xf32> {
  // First generic.
  // CHECK: %[[GEN0:.*]] = linalg.generic
  // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
  // CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
  // CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
  // CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
  // CHECK: %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
  // CHECK: linalg.yield %[[APPLY0]] : f32
  // CHECK: } -> tensor<?x?xf32>

  // Second generic: consumes the result of the first, again with one scalar.
  // CHECK: %[[GEN1:.*]] = linalg.generic
  // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
  // CHECK-SAME: iterator_types = ["parallel", "parallel"]
  // CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
  // CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
  // CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
  // CHECK: %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
  // CHECK: linalg.yield %[[APPLY1]] : f32
  // CHECK: } -> tensor<?x?xf32>
  // CHECK: return %[[GEN1]] : tensor<?x?xf32>
  %0 = "test.elementwise_mappable"(%s1, %t)
      : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = "test.elementwise_mappable"(%s2, %0)
      : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// -----
// Negative test: with only scalar operands there is no ranked tensor to
// determine an iteration rank, so the pattern must not fire and no linalg
// op may be produced.
// NOTE: the separator above must be exactly five dashes (`// -----`);
// `mlir-opt -split-input-file` does not recognize a four-dash marker, which
// would merge this case into the previous split.
// CHECK-LABEL: func @negative_scalar_only_eltwise
// CHECK-NOT: linalg
func.func @negative_scalar_only_eltwise(%a: f32, %b: f32) -> f32 {
  %0 = arith.addf %a, %b : f32
  return %0 : f32
}