@@ -1046,16 +1046,48 @@ LogicalResult cir::VecCreateOp::verify() {
1046
1046
// VecTernaryOp
1047
1047
// ===----------------------------------------------------------------------===//
1048
1048
1049
+ OpFoldResult cir::VecTernaryOp::fold (FoldAdaptor adaptor) {
1050
+ mlir::Attribute cond = adaptor.getCond ();
1051
+ mlir::Attribute lhs = adaptor.getLhs ();
1052
+ mlir::Attribute rhs = adaptor.getRhs ();
1053
+
1054
+ if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
1055
+ !mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
1056
+ !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
1057
+ return {};
1058
+ auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
1059
+ auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
1060
+ auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
1061
+
1062
+ mlir::ArrayAttr condElts = condVec.getElts ();
1063
+
1064
+ SmallVector<mlir::Attribute, 16 > elements;
1065
+ elements.reserve (condElts.size ());
1066
+
1067
+ for (const auto &[idx, condAttr] :
1068
+ llvm::enumerate (condElts.getAsRange <cir::IntAttr>())) {
1069
+ if (condAttr.getSInt ()) {
1070
+ elements.push_back (lhsVec.getElts ()[idx]);
1071
+ } else {
1072
+ elements.push_back (rhsVec.getElts ()[idx]);
1073
+ }
1074
+ }
1075
+
1076
+ cir::VectorType vecTy = getLhs ().getType ();
1077
+ return cir::ConstVectorAttr::get (
1078
+ vecTy, mlir::ArrayAttr::get (getContext (), elements));
1079
+ }
1080
+
1049
1081
LogicalResult cir::VecTernaryOp::verify () {
1050
1082
// Verify that the condition operand has the same number of elements as the
1051
1083
// other operands. (The automatic verification already checked that all
1052
1084
// operands are vector types and that the second and third operands are the
1053
1085
// same type.)
1054
1086
if (mlir::cast<cir::VectorType>(getCond ().getType ()).getSize () !=
1055
- getVec1 ().getType ().getSize ()) {
1087
+ getLhs ().getType ().getSize ()) {
1056
1088
return emitOpError () << " : the number of elements in "
1057
- << getCond ().getType () << " and "
1058
- << getVec1 (). getType () << " don't match" ;
1089
+ << getCond ().getType () << " and " << getLhs (). getType ()
1090
+ << " don't match" ;
1059
1091
}
1060
1092
return success ();
1061
1093
}
0 commit comments