-
Notifications
You must be signed in to change notification settings - Fork 14.8k
Reapply "[WebAssembly] Constant fold wasm.dot" #153070
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Jasmine Tang (badumbatish) ChangesIn #149619, for the test of Full diff: https://github.com/llvm/llvm-project/pull/153070.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 9c1c2c6e60f02..bf291809b07a2 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::aarch64_sve_convert_from_svbool:
case Intrinsic::wasm_alltrue:
case Intrinsic::wasm_anytrue:
+ case Intrinsic::wasm_dot:
// WebAssembly float semantics are always known
case Intrinsic::wasm_trunc_signed:
case Intrinsic::wasm_trunc_unsigned:
@@ -3826,6 +3827,30 @@ static Constant *ConstantFoldFixedVectorCall(
}
return ConstantVector::get(Result);
}
+ case Intrinsic::wasm_dot: {
+ unsigned NumElements =
+ cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
+
+ assert(NumElements == 8 && Result.size() == 4 &&
+ "wasm dot takes i16x8 and produces i32x4");
+ assert(Ty->isIntegerTy());
+ int32_t MulVector[8];
+
+ for (unsigned I = 0; I < NumElements; ++I) {
+ ConstantInt *Elt0 =
+ cast<ConstantInt>(Operands[0]->getAggregateElement(I));
+ ConstantInt *Elt1 =
+ cast<ConstantInt>(Operands[1]->getAggregateElement(I));
+
+ MulVector[I] = Elt0->getSExtValue() * Elt1->getSExtValue();
+ }
+ for (unsigned I = 0; I < Result.size(); I++) {
+ int64_t IAdd = (int64_t)MulVector[I * 2] + MulVector[I * 2 + 1];
+ Result[I] = ConstantInt::get(Ty, IAdd);
+ }
+
+ return ConstantVector::get(Result);
+ }
default:
break;
}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
new file mode 100644
index 0000000000000..b537b7bccf861
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
+
+; Test that intrinsics wasm dot call are constant folded
+
+target triple = "wasm32-unknown-unknown"
+
+
+define <4 x i32> @dot_zero() {
+; CHECK-LABEL: define <4 x i32> @dot_zero() {
+; CHECK-NEXT: ret <4 x i32> zeroinitializer
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+ ret <4 x i32> %res
+}
+
+; a = 1 2 3 4 5 6 7 8
+; b = 1 2 3 4 5 6 7 8
+; k1|k2 = a * b = 1 4 9 16 25 36 49 64
+; k1 + k2 = (1+4) | (9 + 16) | (25 + 36) | (49 + 64)
+; result = 5 | 25 | 61 | 113
+define <4 x i32> @dot_nonzero() {
+; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
+; CHECK-NEXT: ret <4 x i32> <i32 5, i32 25, i32 61, i32 113>
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_doubly_negative() {
+; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
+; CHECK-NEXT: ret <4 x i32> splat (i32 2)
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
+ ret <4 x i32> %res
+}
+
+; Tests that i16 max signed values fit in i32
+define <4 x i32> @dot_follow_modulo_spec_1() {
+; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_1() {
+; CHECK-NEXT: ret <4 x i32> <i32 2147352578, i32 0, i32 0, i32 0>
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 32767, i16 32767, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>, <8 x i16> <i16 32767, i16 32767, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
+ ret <4 x i32> %res
+}
+
+; Tests that i16 min signed values fit in i32
+define <4 x i32> @dot_follow_modulo_spec_2() {
+; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_2() {
+; CHECK-NEXT: ret <4 x i32> <i32 -2147483648, i32 0, i32 0, i32 0>
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -32768, i16 -32768, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>, <8 x i16> <i16 -32768, i16 -32768, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
+ ret <4 x i32> %res
+}
+
|
@llvm/pr-subscribers-llvm-analysis Author: Jasmine Tang (badumbatish) ChangesIn #149619, for the test of Full diff: https://github.com/llvm/llvm-project/pull/153070.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 9c1c2c6e60f02..bf291809b07a2 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::aarch64_sve_convert_from_svbool:
case Intrinsic::wasm_alltrue:
case Intrinsic::wasm_anytrue:
+ case Intrinsic::wasm_dot:
// WebAssembly float semantics are always known
case Intrinsic::wasm_trunc_signed:
case Intrinsic::wasm_trunc_unsigned:
@@ -3826,6 +3827,30 @@ static Constant *ConstantFoldFixedVectorCall(
}
return ConstantVector::get(Result);
}
+ case Intrinsic::wasm_dot: {
+ unsigned NumElements =
+ cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
+
+ assert(NumElements == 8 && Result.size() == 4 &&
+ "wasm dot takes i16x8 and produces i32x4");
+ assert(Ty->isIntegerTy());
+ int32_t MulVector[8];
+
+ for (unsigned I = 0; I < NumElements; ++I) {
+ ConstantInt *Elt0 =
+ cast<ConstantInt>(Operands[0]->getAggregateElement(I));
+ ConstantInt *Elt1 =
+ cast<ConstantInt>(Operands[1]->getAggregateElement(I));
+
+ MulVector[I] = Elt0->getSExtValue() * Elt1->getSExtValue();
+ }
+ for (unsigned I = 0; I < Result.size(); I++) {
+ int64_t IAdd = (int64_t)MulVector[I * 2] + MulVector[I * 2 + 1];
+ Result[I] = ConstantInt::get(Ty, IAdd);
+ }
+
+ return ConstantVector::get(Result);
+ }
default:
break;
}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
new file mode 100644
index 0000000000000..b537b7bccf861
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
+
+; Test that intrinsics wasm dot call are constant folded
+
+target triple = "wasm32-unknown-unknown"
+
+
+define <4 x i32> @dot_zero() {
+; CHECK-LABEL: define <4 x i32> @dot_zero() {
+; CHECK-NEXT: ret <4 x i32> zeroinitializer
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+ ret <4 x i32> %res
+}
+
+; a = 1 2 3 4 5 6 7 8
+; b = 1 2 3 4 5 6 7 8
+; k1|k2 = a * b = 1 4 9 16 25 36 49 64
+; k1 + k2 = (1+4) | (9 + 16) | (25 + 36) | (49 + 64)
+; result = 5 | 25 | 61 | 113
+define <4 x i32> @dot_nonzero() {
+; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
+; CHECK-NEXT: ret <4 x i32> <i32 5, i32 25, i32 61, i32 113>
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_doubly_negative() {
+; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
+; CHECK-NEXT: ret <4 x i32> splat (i32 2)
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
+ ret <4 x i32> %res
+}
+
+; Tests that i16 max signed values fit in i32
+define <4 x i32> @dot_follow_modulo_spec_1() {
+; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_1() {
+; CHECK-NEXT: ret <4 x i32> <i32 2147352578, i32 0, i32 0, i32 0>
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 32767, i16 32767, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>, <8 x i16> <i16 32767, i16 32767, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
+ ret <4 x i32> %res
+}
+
+; Tests that i16 min signed values fit in i32
+define <4 x i32> @dot_follow_modulo_spec_2() {
+; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_2() {
+; CHECK-NEXT: ret <4 x i32> <i32 -2147483648, i32 0, i32 0, i32 0>
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -32768, i16 -32768, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>, <8 x i16> <i16 -32768, i16 -32768, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
+ ret <4 x i32> %res
+}
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just want to check that you were able to run the tests again with UB sanitizer enabled and it didn't trigger?
yep no more trigger, sanitizer checks all of WebAssembly in
|
In #149619, for the test of
@dot_follow_modulo_spec_2
, constant folding the addition of two i32 1073741824 causes an overflow from 2^32 to -2^32=-2147483648, which triggers the UB sanitizer. This PR reapplies the previous PR, explicitly casting the addition operand to int64_t first before performing the addition before producing a int32 number viaConstant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned)