Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 08402db

Browse files
committed
[WIP] add vector memref support for mlir-cpu-runner func return
- Allow mlir-cpu-runner to execute functions that return memrefs whose element type is a vector of f32, in addition to memrefs of plain f32. Signed-off-by: Uday Bondhugula <[email protected]>
1 parent 39eef64 commit 08402db

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

lib/ExecutionEngine/MemRefUtils.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ allocMemRefDescriptor(Type type, bool allocateData = true,
4545
return make_string_error("memref with dynamic shapes not supported");
4646

4747
auto elementType = memRefType.getElementType();
48-
if (!elementType.isF32())
48+
VectorType vectorType = elementType.dyn_cast<VectorType>();
49+
if (!elementType.isF32() &&
50+
!(vectorType && vectorType.getElementType().isF32()))
4951
return make_string_error(
50-
"memref with element other than f32 not supported");
52+
"memref with element other than f32 or vector of f32 not supported");
5153

5254
auto *descriptor =
5355
reinterpret_cast<StaticFloatMemRef *>(malloc(sizeof(StaticFloatMemRef)));
@@ -59,6 +61,11 @@ allocMemRefDescriptor(Type type, bool allocateData = true,
5961
auto shape = memRefType.getShape();
6062
int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
6163
std::multiplies<int64_t>());
64+
if (vectorType) {
65+
auto vectorShape = vectorType.getShape();
66+
size = std::accumulate(vectorShape.begin(), vectorShape.end(), size,
67+
std::multiplies<int64_t>());
68+
}
6269
descriptor->data = reinterpret_cast<float *>(malloc(sizeof(float) * size));
6370
for (int64_t i = 0; i < size; ++i) {
6471
descriptor->data[i] = initialValue;

lib/Support/JitRunner.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ static void printOneMemRef(Type t, void *val) {
126126
auto shape = memRefType.getShape();
127127
int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
128128
std::multiplies<int64_t>());
129+
130+
if (auto vectorType = memRefType.getElementType().dyn_cast<VectorType>()) {
131+
auto vectorShape = vectorType.getShape();
132+
size = std::accumulate(vectorShape.begin(), vectorShape.end(), size,
133+
std::multiplies<int64_t>());
134+
}
135+
129136
for (int64_t i = 0; i < size; ++i) {
130137
llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
131138
}

test/mlir-cpu-runner/simple.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// RUN: mlir-cpu-runner %s -O3 | FileCheck %s
44
// RUN: mlir-cpu-runner %s -O3 -loop-distribute -loop-vectorize | FileCheck %s
55
// RUN: mlir-cpu-runner %s -loop-distribute -loop-vectorize | FileCheck %s
6+
// RUN: mlir-cpu-runner -e bar -init-value 2.0 %s | FileCheck -check-prefix=BAR %s
67

78
func @fabsf(f32) -> f32
89

@@ -31,3 +32,26 @@ func @foo(%a : memref<1x1xf32>) -> memref<1x1xf32> {
3132
}
3233
// NOMAIN: 2.234000e+03
3334
// NOMAIN-NEXT: 2.234000e+03
35+
36+
func @bar(%a : memref<16xvector<4xf32>>) -> memref<16xvector<4xf32>> {
37+
%c0 = constant 0 : index
38+
%c1 = constant 1 : index
39+
40+
%u = load %a[%c0] : memref<16xvector<4xf32>>
41+
%v = load %a[%c1] : memref<16xvector<4xf32>>
42+
%w = addf %u, %v : vector<4xf32>
43+
store %w, %a[%c0] : memref<16xvector<4xf32>>
44+
45+
return %a : memref<16xvector<4xf32>>
46+
}
47+
// BAR: 4.{{0+}}e+00 4.{{0+}}e+00 4.{{0+}}e+00 4.{{0+}}e+00 2.{{0+}}e+00
48+
// BAR-NEXT: 4.{{0+}}e+00 4.{{0+}}e+00 4.{{0+}}e+00 4.{{0+}}e+00 2.{{0+}}e+00
49+
50+
// This one crashes.
51+
func @crash(%arg2: memref<128x128xvector<8xf32>>) -> memref<128x128xvector<8xf32>> {
52+
%c0 = constant 0 : index
53+
%v = constant dense<1.0> : vector<8xf32>
54+
affine.store %v, %arg2[%c0, %c0] : memref<128x128xvector<8xf32>>
55+
return %arg2 : memref<128x128xvector<8xf32>>
56+
}
57+

0 commit comments

Comments
 (0)