Skip to content

Commit c07f8dd

Browse files
[fixup] Rename a member function and chanege some allocs to allocas
1 parent 442e29a commit c07f8dd

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ class VectorContractRewriter {
203203
}
204204

205205
public:
206-
void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
206+
void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
207207
// Create some convenience types.
208208
auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
209209
auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
@@ -462,7 +462,7 @@ class LowerContractionToNeonI8MMPattern
462462
VectorContractRewriterI8MM vcr;
463463
if (failed(vcr.matchAndInit(op, rewriter)))
464464
return failure();
465-
vcr.rewrite(op, rewriter);
465+
vcr.lower(op, rewriter);
466466

467467
return success();
468468
}
@@ -478,7 +478,7 @@ class LowerContractionToNeonBFMMLAPattern
478478
VectorContractRewriterBFMMLA vcr;
479479
if (failed(vcr.matchAndInit(op, rewriter)))
480480
return failure();
481-
vcr.rewrite(op, rewriter);
481+
vcr.lower(op, rewriter);
482482

483483
return success();
484484
}

mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
5858
[ 0.5, -1.3, -2.2, 0.1],
5959
[-0.7, 1.0, 1.7, -1.0]]> : vector<4x4xf32>
6060

61-
%acc_mem = memref.alloc() : memref<4x4xf32>
61+
%acc_mem = memref.alloca() : memref<4x4xf32>
6262
vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
6363
%acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
6464

@@ -68,7 +68,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
6868
[-0.4, 0.6, 0.8, -0.5],
6969
[-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
7070

71-
%lhs_mem = memref.alloc() : memref<4x4xbf16>
71+
%lhs_mem = memref.alloca() : memref<4x4xbf16>
7272
vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
7373
%lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
7474

@@ -78,7 +78,7 @@ func.func @matrix_by_matrix_mul_and_acc() {
7878
[-0.2, 0.4, 1.0, 0.4],
7979
[-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
8080

81-
%rhs_mem = memref.alloc() : memref<4x4xbf16>
81+
%rhs_mem = memref.alloca() : memref<4x4xbf16>
8282
vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
8383
%rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
8484

@@ -121,14 +121,14 @@ func.func @vector_by_matrix_mul_and_acc() {
121121
// Accumulator test data
122122
%acc_cst = arith.constant dense<[0.7, 1.0, -0.1, 1.8]> : vector<4xf32>
123123

124-
%acc_mem = memref.alloc() : memref<4xf32>
124+
%acc_mem = memref.alloca() : memref<4xf32>
125125
vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
126126
%acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
127127

128128
// LHS test data
129129
%lhs_cst = arith.constant dense<[0.1, 0.7, -0.9, 1.3]> : vector<4xbf16>
130130

131-
%lhs_mem = memref.alloc() : memref<4xbf16>
131+
%lhs_mem = memref.alloca() : memref<4xbf16>
132132
vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
133133
%lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
134134

@@ -138,7 +138,7 @@ func.func @vector_by_matrix_mul_and_acc() {
138138
[-0.2, 0.4, 1.0, 0.4],
139139
[-1.3, -0.2, -2.2, 0.3]]> : vector<4x4xbf16>
140140

141-
%rhs_mem = memref.alloc() : memref<4x4xbf16>
141+
%rhs_mem = memref.alloca() : memref<4x4xbf16>
142142
vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
143143
%rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
144144

0 commit comments

Comments
 (0)