Skip to content

Commit 19548de

Browse files
committed
generalize constant folder
1 parent 08c5944 commit 19548de

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5916,14 +5916,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
59165916
}
59175917

59185918
// shape_cast(constant) -> constant
5919-
if (auto splatAttr =
5920-
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
5921-
return splatAttr.reshape(getType());
5919+
if (auto denseAttr =
5920+
dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
5921+
return denseAttr.reshape(getType());
59225922

59235923
// shape_cast(poison) -> poison
5924-
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
5924+
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
59255925
return ub::PoisonAttr::get(getContext());
5926-
}
59275926

59285927
return {};
59295928
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
13301330

13311331
// -----
13321332

1333-
// CHECK-LABEL: shape_cast_constant
1333+
// CHECK-LABEL: shape_cast_splat_constant
13341334
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
13351335
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
13361336
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
1337-
func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
1337+
func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
13381338
%cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
13391339
%cst_1 = arith.constant dense<1> : vector<12x2xi32>
13401340
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
13441344

13451345
// -----
13461346

1347+
// Test of shape_cast's fold method:
1348+
// shape_cast(constant) -> constant.
1349+
//
1350+
// CHECK-LABEL: @shape_cast_dense_int_constant
1351+
// CHECK: %[[CST:.*]] = arith.constant
1352+
// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
1353+
// CHECK: return %[[CST]] : vector<2x3xi8>
1354+
func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
1355+
%cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
1356+
%0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
1357+
return %0 : vector<2x3xi8>
1358+
}
1359+
1360+
// -----
1361+
1362+
// Test of shape_cast fold's method:
1363+
// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
1364+
//
1365+
// CHECK-LABEL: @shape_cast_dense_float_constant
1366+
// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
1367+
// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
1368+
// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
1369+
func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){
1370+
%cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
1371+
%0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
1372+
return %0, %cst : vector<2xf32>, vector<1x2xf32>
1373+
}
1374+
1375+
// -----
1376+
13471377
// CHECK-LABEL: shape_cast_poison
13481378
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
13491379
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>

0 commit comments

Comments
 (0)