@@ -34,7 +34,8 @@ struct VecOpToScalarOp : public OpConversionPattern<Op> {
34
34
using OpConversionPattern<Op>::OpConversionPattern;
35
35
36
36
LogicalResult
37
- matchAndRewrite (Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final ;
37
+ matchAndRewrite (Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
38
+ ConversionPatternRewriter &rewriter) const final ;
38
39
};
39
40
// Pattern to promote an op of a smaller floating point type to F32.
40
41
template <typename Op>
@@ -43,21 +44,23 @@ struct PromoteOpToF32 : public OpConversionPattern<Op> {
43
44
using OpConversionPattern<Op>::OpConversionPattern;
44
45
45
46
LogicalResult
46
- matchAndRewrite (Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final ;
47
+ matchAndRewrite (Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
48
+ ConversionPatternRewriter &rewriter) const final ;
47
49
};
48
50
// Pattern to convert scalar math operations to calls to libm functions.
49
51
// Additionally the libm function signatures are declared.
50
52
template <typename Op>
51
53
struct ScalarOpToLibmCall : public OpConversionPattern <Op> {
52
54
public:
53
- using OpRewritePattern <Op>::OpRewritePattern ;
55
+ using OpConversionPattern <Op>::OpConversionPattern ;
54
56
ScalarOpToLibmCall (MLIRContext *context, PatternBenefit benefit,
55
57
StringRef floatFunc, StringRef doubleFunc)
56
58
: OpConversionPattern<Op>(context, benefit), floatFunc(floatFunc),
57
- doubleFunc (doubleFunc) {};
59
+ doubleFunc (doubleFunc){};
58
60
59
61
LogicalResult
60
- matchAndRewrite (Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final ;
62
+ matchAndRewrite (Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
63
+ ConversionPatternRewriter &rewriter) const final ;
61
64
62
65
private:
63
66
std::string floatFunc, doubleFunc;
@@ -74,8 +77,9 @@ void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
74
77
} // namespace
75
78
76
79
template <typename Op>
77
- LogicalResult
78
- VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
80
+ LogicalResult VecOpToScalarOp<Op>::matchAndRewrite(
81
+ Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
82
+ ConversionPatternRewriter &rewriter) const {
79
83
auto opType = op.getType ();
80
84
auto loc = op.getLoc ();
81
85
auto vecType = dyn_cast<VectorType>(opType);
@@ -95,7 +99,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
95
99
for (auto linearIndex = 0 ; linearIndex < numElements; ++linearIndex) {
96
100
SmallVector<int64_t > positions = delinearize (linearIndex, strides);
97
101
SmallVector<Value> operands;
98
- for (auto input : op-> getOperands ())
102
+ for (auto input : adaptor. getOperands ())
99
103
operands.push_back (
100
104
vector::ExtractOp::create (rewriter, loc, input, positions));
101
105
Value scalarOp =
@@ -108,16 +112,17 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
108
112
}
109
113
110
114
template <typename Op>
111
- LogicalResult
112
- PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
115
+ LogicalResult PromoteOpToF32<Op>::matchAndRewrite(
116
+ Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
117
+ ConversionPatternRewriter &rewriter) const {
113
118
auto opType = op.getType ();
114
119
if (!isa<Float16Type, BFloat16Type>(opType))
115
120
return failure ();
116
121
117
122
auto loc = op.getLoc ();
118
123
auto f32 = rewriter.getF32Type ();
119
124
auto extendedOperands = llvm::to_vector (
120
- llvm::map_range (op-> getOperands (), [&](Value operand) -> Value {
125
+ llvm::map_range (adaptor. getOperands (), [&](Value operand) -> Value {
121
126
return arith::ExtFOp::create (rewriter, loc, f32 , operand);
122
127
}));
123
128
auto newOp = Op::create (rewriter, loc, f32 , extendedOperands);
@@ -126,9 +131,9 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
126
131
}
127
132
128
133
template <typename Op>
129
- LogicalResult
130
- ScalarOpToLibmCall <Op>::matchAndRewrite(Op op ,
131
- PatternRewriter &rewriter) const {
134
+ LogicalResult ScalarOpToLibmCall<Op>::matchAndRewrite(
135
+ Op op, typename OpConversionPattern <Op>::OpAdaptor adaptor ,
136
+ ConversionPatternRewriter &rewriter) const {
132
137
auto module = SymbolTable::getNearestSymbolTable (op);
133
138
auto type = op.getType ();
134
139
if (!isa<Float32Type, Float64Type>(type))
@@ -158,7 +163,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
158
163
assert (isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn (module , name)));
159
164
160
165
rewriter.replaceOpWithNewOp <func::CallOp>(op, name, op.getType (),
161
- op-> getOperands ());
166
+ adaptor. getOperands ());
162
167
163
168
return success ();
164
169
}
0 commit comments