diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 9c1c2c6e60f02..765bd3ec4b9fd 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) { case Intrinsic::aarch64_sve_convert_from_svbool: case Intrinsic::wasm_alltrue: case Intrinsic::wasm_anytrue: + case Intrinsic::wasm_dot: // WebAssembly float semantics are always known case Intrinsic::wasm_trunc_signed: case Intrinsic::wasm_trunc_unsigned: @@ -3826,6 +3827,37 @@ static Constant *ConstantFoldFixedVectorCall( } return ConstantVector::get(Result); } + case Intrinsic::wasm_dot: { + unsigned NumElements = + cast(Operands[0]->getType())->getNumElements(); + + assert(NumElements == 8 && Result.size() == 4 && + "wasm dot takes i16x8 and produces i32x4"); + assert(Ty->isIntegerTy()); + SmallVector MulVector; + + for (unsigned I = 0; I < NumElements; ++I) { + ConstantInt *Elt0 = + cast(Operands[0]->getAggregateElement(I)); + ConstantInt *Elt1 = + cast(Operands[1]->getAggregateElement(I)); + + // sext 32 first, according to specs + APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32); + + // Multiplication can never be more than 32 bit. + // We can opt to not perform modulo of imul here. + MulVector.push_back(IMul); + } + for (unsigned I = 0; I < Result.size(); I++) { + // Addition can never be more than 32 bit. + // We can opt to not perform modulo of iadd here. 
+ APInt IAdd = MulVector[I * 2] + MulVector[I * 2 + 1]; + Result[I] = ConstantInt::get(Ty, IAdd); + } + + return ConstantVector::get(Result); + } default: break; } diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll new file mode 100644 index 0000000000000..9c5dc74033f5b --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll @@ -0,0 +1,62 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 + +; RUN: opt -passes=instsimplify -S < %s | FileCheck %s + +; Test that intrinsics wasm dot call are constant folded + +target triple = "wasm32-unknown-unknown" + + +define <4 x i32> @dot_zero() { +; CHECK-LABEL: define <4 x i32> @dot_zero() { +; CHECK-NEXT: ret <4 x i32> zeroinitializer +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer) + ret <4 x i32> %res +} + +; a = 1 2 3 4 5 6 7 8 +; b = 1 2 3 4 5 6 7 8 +; k1|k2 = a * b = 1 4 9 16 25 36 49 64 +; k1 + k2 = (1+4) | (9 + 16) | (25 + 36) | (49 + 64) +; result = 5 | 25 | 61 | 113 +define <4 x i32> @dot_nonzero() { +; CHECK-LABEL: define <4 x i32> @dot_nonzero() { +; CHECK-NEXT: ret <4 x i32> +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) + ret <4 x i32> %res +} + +define <4 x i32> @dot_doubly_negative() { +; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() { +; CHECK-NEXT: ret <4 x i32> splat (i32 2) +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) + ret <4 x i32> %res +} + +; This test checks for llvm's compliance on spec's wasm.dot's imul and iadd +; Since the original number can only be i16::max == 2^15 - 1, +; subsequent modulo of 2^32 of imul and iadd +; should return the same result +; 2*(2^15 - 1)^2 % 2^32 == 2*(2^15 - 1)^2 +define <4 x i32> @dot_follow_modulo_spec_1() { +; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_1() { +; CHECK-NEXT: ret <4 x i32> 
+; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) + ret <4 x i32> %res +} + +; This test checks LLVM's compliance with the spec's imul and iadd semantics for wasm.dot +; 2 * (-2^15)^2 == 2^31, which is below 2^32, so the modulo-2^32 reduction leaves the bits unchanged +; Reinterpreted as a signed i32, the value 2^31 wraps around to -(2^31) +define <4 x i32> @dot_follow_modulo_spec_2() { +; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_2() { +; CHECK-NEXT: ret <4 x i32> +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) + ret <4 x i32> %res +} +