Skip to content

Commit c75561b

Browse files
committed
fmt
1 parent b5a6e88 commit c75561b

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ struct VecOpToScalarOp : public OpConversionPattern<Op> {
3434
using OpConversionPattern<Op>::OpConversionPattern;
3535

3636
LogicalResult
37-
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
37+
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
38+
ConversionPatternRewriter &rewriter) const final;
3839
};
3940
// Pattern to promote an op of a smaller floating point type to F32.
4041
template <typename Op>
@@ -43,21 +44,23 @@ struct PromoteOpToF32 : public OpConversionPattern<Op> {
4344
using OpConversionPattern<Op>::OpConversionPattern;
4445

4546
LogicalResult
46-
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
47+
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
48+
ConversionPatternRewriter &rewriter) const final;
4749
};
4850
// Pattern to convert scalar math operations to calls to libm functions.
4951
// Additionally the libm function signatures are declared.
5052
template <typename Op>
5153
struct ScalarOpToLibmCall : public OpConversionPattern<Op> {
5254
public:
53-
using OpRewritePattern<Op>::OpRewritePattern;
55+
using OpConversionPattern<Op>::OpConversionPattern;
5456
ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
5557
StringRef floatFunc, StringRef doubleFunc)
5658
: OpConversionPattern<Op>(context, benefit), floatFunc(floatFunc),
57-
doubleFunc(doubleFunc) {};
59+
doubleFunc(doubleFunc){};
5860

5961
LogicalResult
60-
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
62+
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
63+
ConversionPatternRewriter &rewriter) const final;
6164

6265
private:
6366
std::string floatFunc, doubleFunc;
@@ -74,8 +77,9 @@ void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
7477
} // namespace
7578

7679
template <typename Op>
77-
LogicalResult
78-
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
80+
LogicalResult VecOpToScalarOp<Op>::matchAndRewrite(
81+
Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
82+
ConversionPatternRewriter &rewriter) const {
7983
auto opType = op.getType();
8084
auto loc = op.getLoc();
8185
auto vecType = dyn_cast<VectorType>(opType);
@@ -95,7 +99,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
9599
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
96100
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
97101
SmallVector<Value> operands;
98-
for (auto input : op->getOperands())
102+
for (auto input : adaptor.getOperands())
99103
operands.push_back(
100104
vector::ExtractOp::create(rewriter, loc, input, positions));
101105
Value scalarOp =
@@ -108,16 +112,17 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
108112
}
109113

110114
template <typename Op>
111-
LogicalResult
112-
PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
115+
LogicalResult PromoteOpToF32<Op>::matchAndRewrite(
116+
Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
117+
ConversionPatternRewriter &rewriter) const {
113118
auto opType = op.getType();
114119
if (!isa<Float16Type, BFloat16Type>(opType))
115120
return failure();
116121

117122
auto loc = op.getLoc();
118123
auto f32 = rewriter.getF32Type();
119124
auto extendedOperands = llvm::to_vector(
120-
llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
125+
llvm::map_range(adaptor.getOperands(), [&](Value operand) -> Value {
121126
return arith::ExtFOp::create(rewriter, loc, f32, operand);
122127
}));
123128
auto newOp = Op::create(rewriter, loc, f32, extendedOperands);
@@ -126,9 +131,9 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
126131
}
127132

128133
template <typename Op>
129-
LogicalResult
130-
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
131-
PatternRewriter &rewriter) const {
134+
LogicalResult ScalarOpToLibmCall<Op>::matchAndRewrite(
135+
Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
136+
ConversionPatternRewriter &rewriter) const {
132137
auto module = SymbolTable::getNearestSymbolTable(op);
133138
auto type = op.getType();
134139
if (!isa<Float32Type, Float64Type>(type))
@@ -158,7 +163,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
158163
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
159164

160165
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
161-
op->getOperands());
166+
adaptor.getOperands());
162167

163168
return success();
164169
}

0 commit comments

Comments
 (0)