Skip to content

Commit e74a9ce

Browse files
committed
Fix tests
1 parent b15039a commit e74a9ce

File tree

1 file changed

+58
-47
lines changed

1 file changed

+58
-47
lines changed

mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,6 @@ func.func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
1919
return %0 : tensor<f32>
2020
}
2121

22-
// Test a binary elementwise op with a tensor and a scalar operand.
23-
// CHECK-LABEL: func @addf_tensor_plus_scalar_rank1
24-
// CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
25-
func.func @addf_tensor_plus_scalar_rank1(%t: tensor<?xf32>, %s: f32) -> tensor<?xf32> {
26-
%c0 = arith.constant 0 : index
27-
%d0 = tensor.dim %t, %c0 : tensor<?xf32>
28-
%init = tensor.empty(%d0) : tensor<?xf32>
29-
%splat = linalg.fill ins(%s : f32) outs(%init : tensor<?xf32>) -> tensor<?xf32>
30-
// CHECK: linalg.generic
31-
// CHECK-SAME: iterator_types = ["parallel"]
32-
// CHECK-SAME: ins(%[[T]], %{{.*}}
33-
%0 = arith.addf %t, %splat : tensor<?xf32>
34-
return %0 : tensor<?xf32>
35-
}
36-
37-
// Test a comparison op between a tensor and a scalar.
38-
// CHECK-LABEL: func @cmpf_tensor_scalar
39-
// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
40-
func.func @cmpf_tensor_scalar(%a: tensor<?xf32>, %s: f32) -> tensor<?xi1> {
41-
%c0 = arith.constant 0 : index
42-
%d0 = tensor.dim %a, %c0 : tensor<?xf32>
43-
%initS = tensor.empty(%d0) : tensor<?xf32>
44-
%splat = linalg.fill ins(%s : f32) outs(%initS : tensor<?xf32>) -> tensor<?xf32>
45-
46-
%init = tensor.empty(%d0) : tensor<?xi1>
47-
// CHECK: %[[INIT:.*]] = tensor.empty
48-
// CHECK: linalg.generic
49-
// CHECK-SAME: ins(%[[A]], %{{.*}}
50-
%0 = arith.cmpf olt, %a, %splat : tensor<?xf32>
51-
return %0 : tensor<?xi1>
52-
}
53-
54-
// Test a binary elementwise op with a tensor and a zero-dimensional
55-
// (rank-0) tensor.
56-
// CHECK-LABEL: func @addf_tensor_plus_rank0_tensor
57-
// CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor<4xf32>, %[[R0:[0-9a-zA-Z]*]]: tensor<f32>
58-
func.func @addf_tensor_plus_rank0_tensor(%t: tensor<4xf32>, %r0: tensor<f32>) -> tensor<4xf32> {
59-
%c = tensor.extract %r0[] : tensor<f32>
60-
%init = tensor.empty() : tensor<4xf32>
61-
%splat = linalg.fill ins(%c : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
62-
// CHECK: linalg.generic
63-
// CHECK-SAME: ins(%[[T]], %{{.*}}
64-
%0 = arith.addf %t, %splat : tensor<4xf32>
65-
return %0 : tensor<4xf32>
66-
}
67-
68-
6922
// -----
7023

7124
// Check indexing maps and iterator types for the rank > 0 case.
@@ -155,3 +108,61 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>)
155108
return %0 : tensor<4x?x?x8x2x?xi1>
156109
}
157110

111+
// -----
112+
113+
// Check a mix of scalar and tensor input.
114+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
115+
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
116+
// CHECK-LABEL: func @scalar_plus_tensor
117+
// CHECK: %[[GEN:.*]] = linalg.generic
118+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
119+
// CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
120+
// CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
121+
// CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
122+
// CHECK: "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
123+
// CHECK: linalg.yield {{.*}} : f32
124+
// CHECK: } -> tensor<?x?xf32>
125+
func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
126+
%0 = "test.elementwise_mappable"(%arg0, %arg1)
127+
: (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
128+
return %0 : tensor<?x?xf32>
129+
}
130+
131+
// -----
132+
// This test exercises the case where an elementwise op has two scalar-like
133+
// operands and one ranked tensor operand. In this example, we chain two
134+
// `test.elementwise_mappable` calls:
135+
// %0 = f(%s1, %t)
136+
// %1 = f(%s2, %0)
137+
// CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()>
138+
// CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)>
139+
// CHECK-LABEL: func @scalar_tensor_scalar
140+
// First generic.
141+
// CHECK: %[[GEN0:.*]] = linalg.generic
142+
// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
143+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
144+
// CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
145+
// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
146+
// CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
147+
// CHECK: %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
148+
// CHECK: linalg.yield %[[APPLY0]] : f32
149+
// CHECK: } -> tensor<?x?xf32>
150+
151+
// Second generic.
152+
// CHECK: %[[GEN1:.*]] = linalg.generic
153+
// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
154+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
155+
// CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
156+
// CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
157+
// CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
158+
// CHECK: %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
159+
// CHECK: linalg.yield %[[APPLY1]] : f32
160+
// CHECK: } -> tensor<?x?xf32>
161+
// CHECK: return %[[GEN1]] : tensor<?x?xf32>
162+
func.func @scalar_tensor_scalar(%s1: f32, %t: tensor<?x?xf32>, %s2: f32) -> tensor<?x?xf32> {
163+
%0 = "test.elementwise_mappable"(%s1, %t)
164+
: (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
165+
%1 = "test.elementwise_mappable"(%s2, %0)
166+
: (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
167+
return %1 : tensor<?x?xf32>
168+
}

0 commit comments

Comments
 (0)