@@ -19,53 +19,6 @@ func.func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
19
19
return %0 : tensor <f32 >
20
20
}
21
21
22
- // Test a binary elementwise op with a tensor and a scalar operand.
23
- // CHECK-LABEL: func @addf_tensor_plus_scalar_rank1
24
- // CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
25
- func.func @addf_tensor_plus_scalar_rank1 (%t: tensor <?xf32 >, %s: f32 ) -> tensor <?xf32 > {
26
- %c0 = arith.constant 0 : index
27
- %d0 = tensor.dim %t , %c0 : tensor <?xf32 >
28
- %init = tensor.empty (%d0 ) : tensor <?xf32 >
29
- %splat = linalg.fill ins (%s : f32 ) outs (%init : tensor <?xf32 >) -> tensor <?xf32 >
30
- // CHECK: linalg.generic
31
- // CHECK-SAME: iterator_types = ["parallel"]
32
- // CHECK-SAME: ins(%[[T]], %{{.*}}
33
- %0 = arith.addf %t , %splat : tensor <?xf32 >
34
- return %0 : tensor <?xf32 >
35
- }
36
-
37
- // Test a comparison op between a tensor and a scalar.
38
- // CHECK-LABEL: func @cmpf_tensor_scalar
39
- // CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
40
- func.func @cmpf_tensor_scalar (%a: tensor <?xf32 >, %s: f32 ) -> tensor <?xi1 > {
41
- %c0 = arith.constant 0 : index
42
- %d0 = tensor.dim %a , %c0 : tensor <?xf32 >
43
- %initS = tensor.empty (%d0 ) : tensor <?xf32 >
44
- %splat = linalg.fill ins (%s : f32 ) outs (%initS : tensor <?xf32 >) -> tensor <?xf32 >
45
-
46
- %init = tensor.empty (%d0 ) : tensor <?xi1 >
47
- // CHECK: %[[INIT:.*]] = tensor.empty
48
- // CHECK: linalg.generic
49
- // CHECK-SAME: ins(%[[A]], %{{.*}}
50
- %0 = arith.cmpf olt , %a , %splat : tensor <?xf32 >
51
- return %0 : tensor <?xi1 >
52
- }
53
-
54
- // Test a binary elementwise op with a tensor and a zero-dimensional
55
- // (rank-0) tensor.
56
- // CHECK-LABEL: func @addf_tensor_plus_rank0_tensor
57
- // CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor<4xf32>, %[[R0:[0-9a-zA-Z]*]]: tensor<f32>
58
- func.func @addf_tensor_plus_rank0_tensor (%t: tensor <4 xf32 >, %r0: tensor <f32 >) -> tensor <4 xf32 > {
59
- %c = tensor.extract %r0 [] : tensor <f32 >
60
- %init = tensor.empty () : tensor <4 xf32 >
61
- %splat = linalg.fill ins (%c : f32 ) outs (%init : tensor <4 xf32 >) -> tensor <4 xf32 >
62
- // CHECK: linalg.generic
63
- // CHECK-SAME: ins(%[[T]], %{{.*}}
64
- %0 = arith.addf %t , %splat : tensor <4 xf32 >
65
- return %0 : tensor <4 xf32 >
66
- }
67
-
68
-
69
22
// -----
70
23
71
24
// Check indexing maps and iterator types for the rank > 0 case.
@@ -155,3 +108,61 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>)
155
108
return %0 : tensor <4 x?x?x8 x2 x?xi1 >
156
109
}
157
110
111
+ // -----
112
+
113
+ // Check a mix of scalar and tensor input.
114
+ // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
115
+ // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
116
+ // CHECK-LABEL: func @scalar_plus_tensor
117
+ // CHECK: %[[GEN:.*]] = linalg.generic
118
+ // CHECK-SAME: iterator_types = ["parallel", "parallel"]
119
+ // CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
120
+ // CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
121
+ // CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
122
+ // CHECK: "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
123
+ // CHECK: linalg.yield {{.*}} : f32
124
+ // CHECK: } -> tensor<?x?xf32>
125
+ func.func @scalar_plus_tensor (%arg0: f32 , %arg1: tensor <?x?xf32 >) -> tensor <?x?xf32 > {
126
+ %0 = " test.elementwise_mappable" (%arg0 , %arg1 )
127
+ : (f32 , tensor <?x?xf32 >) -> tensor <?x?xf32 >
128
+ return %0 : tensor <?x?xf32 >
129
+ }
130
+
131
+ // -----
132
+ // This test exercises the case where an elementwise op has two scalar-like
133
+ // operands and one ranked tensor operand. In this example, we chain two
134
+ // `test.elementwise_mappable` calls:
135
+ // %0 = f(%s1, %t)
136
+ // %1 = f(%s2, %0)
137
+ // CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()>
138
+ // CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)>
139
+ // CHECK-LABEL: func @scalar_tensor_scalar
140
+ // First generic.
141
+ // CHECK: %[[GEN0:.*]] = linalg.generic
142
+ // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
143
+ // CHECK-SAME: iterator_types = ["parallel", "parallel"]
144
+ // CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
145
+ // CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
146
+ // CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
147
+ // CHECK: %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
148
+ // CHECK: linalg.yield %[[APPLY0]] : f32
149
+ // CHECK: } -> tensor<?x?xf32>
150
+
151
+ // Second generic.
152
+ // CHECK: %[[GEN1:.*]] = linalg.generic
153
+ // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
154
+ // CHECK-SAME: iterator_types = ["parallel", "parallel"]
155
+ // CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
156
+ // CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
157
+ // CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
158
+ // CHECK: %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
159
+ // CHECK: linalg.yield %[[APPLY1]] : f32
160
+ // CHECK: } -> tensor<?x?xf32>
161
+ // CHECK: return %[[GEN1]] : tensor<?x?xf32>
162
+ func.func @scalar_tensor_scalar (%s1: f32 , %t: tensor <?x?xf32 >, %s2: f32 ) -> tensor <?x?xf32 > {
163
+ %0 = " test.elementwise_mappable" (%s1 , %t )
164
+ : (f32 , tensor <?x?xf32 >) -> tensor <?x?xf32 >
165
+ %1 = " test.elementwise_mappable" (%s2 , %0 )
166
+ : (f32 , tensor <?x?xf32 >) -> tensor <?x?xf32 >
167
+ return %1 : tensor <?x?xf32 >
168
+ }
0 commit comments