Skip to content

Commit fd8ae2c

Browse files
Add constant-folding for unary NVVM intrinsics (#141233)
Add support for constant-folding numerous NVVM unary arithmetic intrinsics (including f, d, and ftz_f variants): - nvvm.ceil.* - nvvm.fabs.* - nvvm.floor.* - nvvm.rcp.* - nvvm.round.* - nvvm.saturate.* - nvvm.sqrt.f - nvvm.sqrt.rn.*
1 parent 5c7c855 commit fd8ae2c

File tree

3 files changed

+886
-11
lines changed

3 files changed

+886
-11
lines changed

llvm/include/llvm/IR/NVVMIntrinsicUtils.h

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
112112
return false;
113113
}
114114
llvm_unreachable("Checking FTZ flag for invalid f2i/d2i intrinsic");
115-
return false;
116115
}
117116

118117
inline bool FPToIntegerIntrinsicResultIsSigned(Intrinsic::ID IntrinsicID) {
@@ -179,7 +178,6 @@ inline bool FPToIntegerIntrinsicResultIsSigned(Intrinsic::ID IntrinsicID) {
179178
}
180179
llvm_unreachable(
181180
"Checking invalid f2i/d2i intrinsic for signed int conversion");
182-
return false;
183181
}
184182

185183
inline APFloat::roundingMode
@@ -250,7 +248,6 @@ GetFPToIntegerRoundingMode(Intrinsic::ID IntrinsicID) {
250248
return APFloat::rmTowardZero;
251249
}
252250
llvm_unreachable("Checking rounding mode for invalid f2i/d2i intrinsic");
253-
return APFloat::roundingMode::Invalid;
254251
}
255252

256253
inline bool FMinFMaxShouldFTZ(Intrinsic::ID IntrinsicID) {
@@ -280,7 +277,6 @@ inline bool FMinFMaxShouldFTZ(Intrinsic::ID IntrinsicID) {
280277
return false;
281278
}
282279
llvm_unreachable("Checking FTZ flag for invalid fmin/fmax intrinsic");
283-
return false;
284280
}
285281

286282
inline bool FMinFMaxPropagatesNaNs(Intrinsic::ID IntrinsicID) {
@@ -310,7 +306,6 @@ inline bool FMinFMaxPropagatesNaNs(Intrinsic::ID IntrinsicID) {
310306
return false;
311307
}
312308
llvm_unreachable("Checking NaN flag for invalid fmin/fmax intrinsic");
313-
return false;
314309
}
315310

316311
inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
@@ -340,7 +335,83 @@ inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
340335
return false;
341336
}
342337
llvm_unreachable("Checking XorSignAbs flag for invalid fmin/fmax intrinsic");
343-
return false;
338+
}
339+
340+
/// Reports whether \p IntrinsicID is one of the "ftz"-named unary NVVM math
/// intrinsics (ceil, fabs, floor, round, saturate, sqrt.rn).
///
/// \returns true for the *_ftz* variants listed below and false for their
/// non-ftz counterparts. Passing any other intrinsic ID is a programming
/// error and ends in llvm_unreachable.
inline bool UnaryMathIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Variants whose name carries no ftz marker.
  case Intrinsic::nvvm_ceil_d:
  case Intrinsic::nvvm_ceil_f:
  case Intrinsic::nvvm_fabs:
  case Intrinsic::nvvm_floor_d:
  case Intrinsic::nvvm_floor_f:
  case Intrinsic::nvvm_round_d:
  case Intrinsic::nvvm_round_f:
  case Intrinsic::nvvm_saturate_d:
  case Intrinsic::nvvm_saturate_f:
  case Intrinsic::nvvm_sqrt_f:
  case Intrinsic::nvvm_sqrt_rn_d:
  case Intrinsic::nvvm_sqrt_rn_f:
    return false;

  // Variants whose name carries the ftz marker.
  case Intrinsic::nvvm_ceil_ftz_f:
  case Intrinsic::nvvm_fabs_ftz:
  case Intrinsic::nvvm_floor_ftz_f:
  case Intrinsic::nvvm_round_ftz_f:
  case Intrinsic::nvvm_saturate_ftz_f:
  case Intrinsic::nvvm_sqrt_rn_ftz_f:
    return true;
  }
  llvm_unreachable("Checking FTZ flag for invalid unary intrinsic");
}
365+
366+
/// Reports whether \p IntrinsicID is one of the "ftz"-named nvvm.rcp
/// intrinsics.
///
/// \returns true for the rcp.{rm,rn,rp,rz}.ftz.f variants and false for the
/// plain f/d variants. Any non-rcp intrinsic ID ends in llvm_unreachable.
inline bool RCPShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Plain float/double reciprocals, grouped by rounding suffix.
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_d:
    return false;

  // ftz-named float reciprocals.
  case Intrinsic::nvvm_rcp_rm_ftz_f:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return true;
  }
  llvm_unreachable("Checking FTZ flag for invalid rcp intrinsic");
}
385+
386+
/// Translates the rounding suffix of an nvvm.rcp intrinsic into the matching
/// APFloat rounding mode.
///
/// \returns rmNearestTiesToEven for rn, rmTowardZero for rz,
/// rmTowardNegative for rm and rmTowardPositive for rp. Any non-rcp
/// intrinsic ID ends in llvm_unreachable.
inline APFloat::roundingMode GetRCPRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // rn
  case Intrinsic::nvvm_rcp_rn_d:
  case Intrinsic::nvvm_rcp_rn_f:
  case Intrinsic::nvvm_rcp_rn_ftz_f:
    return APFloat::rmNearestTiesToEven;

  // rz
  case Intrinsic::nvvm_rcp_rz_d:
  case Intrinsic::nvvm_rcp_rz_f:
  case Intrinsic::nvvm_rcp_rz_ftz_f:
    return APFloat::rmTowardZero;

  // rm
  case Intrinsic::nvvm_rcp_rm_d:
  case Intrinsic::nvvm_rcp_rm_f:
  case Intrinsic::nvvm_rcp_rm_ftz_f:
    return APFloat::rmTowardNegative;

  // rp
  case Intrinsic::nvvm_rcp_rp_d:
  case Intrinsic::nvvm_rcp_rp_f:
  case Intrinsic::nvvm_rcp_rp_ftz_f:
    return APFloat::rmTowardPositive;
  }
  llvm_unreachable("Checking rounding mode for invalid rcp intrinsic");
}
410+
411+
/// Maps an FTZ flag to the denormal mode used when folding NVVM intrinsics.
///
/// \param ShouldFTZ whether the intrinsic is an ftz variant.
/// \returns DenormalMode::getPreserveSign() when \p ShouldFTZ is true,
/// DenormalMode::getIEEE() otherwise.
inline DenormalMode GetNVVMDenormMode(bool ShouldFTZ) {
  if (ShouldFTZ)
    return DenormalMode::getPreserveSign();
  return DenormalMode::getIEEE();
}

/// Misspelled ("Denrom") name of GetNVVMDenormMode, retained so existing
/// callers keep compiling. Forwards unchanged; prefer GetNVVMDenormMode in new
/// code.
inline DenormalMode GetNVVMDenromMode(bool ShouldFTZ) {
  return GetNVVMDenormMode(ShouldFTZ);
}
345416

346417
} // namespace nvvm

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 163 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,6 +1801,44 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
18011801
case Intrinsic::nvvm_d2ull_rn:
18021802
case Intrinsic::nvvm_d2ull_rp:
18031803
case Intrinsic::nvvm_d2ull_rz:
1804+
1805+
// NVVM math intrinsics:
1806+
case Intrinsic::nvvm_ceil_d:
1807+
case Intrinsic::nvvm_ceil_f:
1808+
case Intrinsic::nvvm_ceil_ftz_f:
1809+
1810+
case Intrinsic::nvvm_fabs:
1811+
case Intrinsic::nvvm_fabs_ftz:
1812+
1813+
case Intrinsic::nvvm_floor_d:
1814+
case Intrinsic::nvvm_floor_f:
1815+
case Intrinsic::nvvm_floor_ftz_f:
1816+
1817+
case Intrinsic::nvvm_rcp_rm_d:
1818+
case Intrinsic::nvvm_rcp_rm_f:
1819+
case Intrinsic::nvvm_rcp_rm_ftz_f:
1820+
case Intrinsic::nvvm_rcp_rn_d:
1821+
case Intrinsic::nvvm_rcp_rn_f:
1822+
case Intrinsic::nvvm_rcp_rn_ftz_f:
1823+
case Intrinsic::nvvm_rcp_rp_d:
1824+
case Intrinsic::nvvm_rcp_rp_f:
1825+
case Intrinsic::nvvm_rcp_rp_ftz_f:
1826+
case Intrinsic::nvvm_rcp_rz_d:
1827+
case Intrinsic::nvvm_rcp_rz_f:
1828+
case Intrinsic::nvvm_rcp_rz_ftz_f:
1829+
1830+
case Intrinsic::nvvm_round_d:
1831+
case Intrinsic::nvvm_round_f:
1832+
case Intrinsic::nvvm_round_ftz_f:
1833+
1834+
case Intrinsic::nvvm_saturate_d:
1835+
case Intrinsic::nvvm_saturate_f:
1836+
case Intrinsic::nvvm_saturate_ftz_f:
1837+
1838+
case Intrinsic::nvvm_sqrt_f:
1839+
case Intrinsic::nvvm_sqrt_rn_d:
1840+
case Intrinsic::nvvm_sqrt_rn_f:
1841+
case Intrinsic::nvvm_sqrt_rn_ftz_f:
18041842
return !Call->isStrictFP();
18051843

18061844
// Sign operations are actually bitwise operations, they do not raise
@@ -1818,6 +1856,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
18181856
case Intrinsic::nearbyint:
18191857
case Intrinsic::rint:
18201858
case Intrinsic::canonicalize:
1859+
18211860
// Constrained intrinsics can be folded if FP environment is known
18221861
// to compiler.
18231862
case Intrinsic::experimental_constrained_fma:
@@ -1965,22 +2004,56 @@ inline bool llvm_fenv_testexcept() {
19652004
return false;
19662005
}
19672006

1968-
static APFloat FTZPreserveSign(const APFloat &V) {
2007+
/// Flush-to-zero that keeps the sign: a denormal \p V becomes a zero of the
/// same semantics and the same sign; every other value is returned unchanged.
static const APFloat FTZPreserveSign(const APFloat &V) {
  if (!V.isDenormal())
    return V;
  return APFloat::getZero(V.getSemantics(), /*Negative=*/V.isNegative());
}
19732012

1974-
Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
1975-
Type *Ty) {
2013+
/// Flush-to-zero that discards the sign: a denormal \p V becomes +0 of the
/// same semantics; every other value is returned unchanged.
static const APFloat FlushToPositiveZero(const APFloat &V) {
  if (!V.isDenormal())
    return V;
  return APFloat::getZero(V.getSemantics(), /*Negative=*/false);
}
2018+
2019+
/// Applies the denormal treatment selected by \p DenormKind to \p V.
///
/// IEEE returns \p V untouched, PreserveSign defers to FTZPreserveSign and
/// PositiveZero defers to FlushToPositiveZero. Invalid and Dynamic are not
/// accepted: they trip the assertion and otherwise reach llvm_unreachable.
static const APFloat
FlushWithDenormKind(const APFloat &V,
                    DenormalMode::DenormalModeKind DenormKind) {
  assert(DenormKind != DenormalMode::DenormalModeKind::Invalid &&
         DenormKind != DenormalMode::DenormalModeKind::Dynamic);

  if (DenormKind == DenormalMode::DenormalModeKind::IEEE)
    return V;
  if (DenormKind == DenormalMode::DenormalModeKind::PreserveSign)
    return FTZPreserveSign(V);
  if (DenormKind == DenormalMode::DenormalModeKind::PositiveZero)
    return FlushToPositiveZero(V);

  // Only Invalid/Dynamic (or a corrupted value) can get here.
  llvm_unreachable("Invalid denormal mode!");
}
2035+
2036+
/// Folds a unary floating-point operation by evaluating the host routine
/// \p NativeFP on the constant \p V.
///
/// \param NativeFP   host function applied to the (possibly flushed) input
///                   after conversion to double.
/// \param V          constant operand.
/// \param Ty         type of the constant to produce.
/// \param DenormMode denormal handling; DenormMode.Input is applied to \p V
///                   before the call and DenormMode.Output to the result.
///                   Defaults to IEEE, which leaves both untouched.
/// \returns the folded constant, or nullptr when \p DenormMode is invalid or
/// dynamic on either side, or when llvm_fenv_testexcept() reports a problem
/// after the host call.
Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty,
                         DenormalMode DenormMode = DenormalMode::getIEEE()) {
  // A dynamic (or invalid) mode is not known at compile time, so refuse to
  // fold rather than guess.
  if (!DenormMode.isValid() ||
      DenormMode.Input == DenormalMode::DenormalModeKind::Dynamic ||
      DenormMode.Output == DenormalMode::DenormalModeKind::Dynamic)
    return nullptr;

  llvm_fenv_clearexcept();
  const APFloat Input = FlushWithDenormKind(V, DenormMode.Input);
  double Result = NativeFP(Input.convertToDouble());
  if (llvm_fenv_testexcept()) {
    llvm_fenv_clearexcept();
    return nullptr;
  }

  Constant *Output = GetConstantFoldFPValue(Result, Ty);
  if (DenormMode.Output == DenormalMode::DenormalModeKind::IEEE)
    return Output;

  // Use the checked LLVM cast instead of an unchecked static_cast so that an
  // unexpected non-ConstantFP result asserts in +Asserts builds.
  const auto *CFP = cast<ConstantFP>(Output);
  const APFloat Res =
      FlushWithDenormKind(CFP->getValueAPF(), DenormMode.Output);
  return ConstantFP::get(Ty->getContext(), Res);
}
19852058

19862059
#if defined(HAS_IEE754_FLOAT128) && defined(HAS_LOGF128)
@@ -2550,6 +2623,91 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
25502623
return ConstantFoldFP(atan, APF, Ty);
25512624
case Intrinsic::sqrt:
25522625
return ConstantFoldFP(sqrt, APF, Ty);
2626+
2627+
// NVVM Intrinsics:
2628+
case Intrinsic::nvvm_ceil_ftz_f:
2629+
case Intrinsic::nvvm_ceil_f:
2630+
case Intrinsic::nvvm_ceil_d:
2631+
return ConstantFoldFP(
2632+
ceil, APF, Ty,
2633+
nvvm::GetNVVMDenromMode(
2634+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));
2635+
2636+
case Intrinsic::nvvm_fabs_ftz:
2637+
case Intrinsic::nvvm_fabs:
2638+
return ConstantFoldFP(
2639+
fabs, APF, Ty,
2640+
nvvm::GetNVVMDenromMode(
2641+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));
2642+
2643+
case Intrinsic::nvvm_floor_ftz_f:
2644+
case Intrinsic::nvvm_floor_f:
2645+
case Intrinsic::nvvm_floor_d:
2646+
return ConstantFoldFP(
2647+
floor, APF, Ty,
2648+
nvvm::GetNVVMDenromMode(
2649+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));
2650+
2651+
case Intrinsic::nvvm_rcp_rm_ftz_f:
2652+
case Intrinsic::nvvm_rcp_rn_ftz_f:
2653+
case Intrinsic::nvvm_rcp_rp_ftz_f:
2654+
case Intrinsic::nvvm_rcp_rz_ftz_f:
2655+
case Intrinsic::nvvm_rcp_rm_d:
2656+
case Intrinsic::nvvm_rcp_rm_f:
2657+
case Intrinsic::nvvm_rcp_rn_d:
2658+
case Intrinsic::nvvm_rcp_rn_f:
2659+
case Intrinsic::nvvm_rcp_rp_d:
2660+
case Intrinsic::nvvm_rcp_rp_f:
2661+
case Intrinsic::nvvm_rcp_rz_d:
2662+
case Intrinsic::nvvm_rcp_rz_f: {
2663+
APFloat::roundingMode RoundMode = nvvm::GetRCPRoundingMode(IntrinsicID);
2664+
bool IsFTZ = nvvm::RCPShouldFTZ(IntrinsicID);
2665+
2666+
auto Denominator = IsFTZ ? FTZPreserveSign(APF) : APF;
2667+
APFloat Res = APFloat::getOne(APF.getSemantics());
2668+
APFloat::opStatus Status = Res.divide(Denominator, RoundMode);
2669+
2670+
if (Status == APFloat::opOK || Status == APFloat::opInexact) {
2671+
if (IsFTZ)
2672+
Res = FTZPreserveSign(Res);
2673+
return ConstantFP::get(Ty->getContext(), Res);
2674+
}
2675+
return nullptr;
2676+
}
2677+
2678+
case Intrinsic::nvvm_round_ftz_f:
2679+
case Intrinsic::nvvm_round_f:
2680+
case Intrinsic::nvvm_round_d:
2681+
return ConstantFoldFP(
2682+
round, APF, Ty,
2683+
nvvm::GetNVVMDenromMode(
2684+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));
2685+
2686+
case Intrinsic::nvvm_saturate_ftz_f:
2687+
case Intrinsic::nvvm_saturate_d:
2688+
case Intrinsic::nvvm_saturate_f: {
2689+
bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
2690+
auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
2691+
if (V.isNegative() || V.isZero() || V.isNaN())
2692+
return ConstantFP::getZero(Ty);
2693+
APFloat One = APFloat::getOne(APF.getSemantics());
2694+
if (V > One)
2695+
return ConstantFP::get(Ty->getContext(), One);
2696+
return ConstantFP::get(Ty->getContext(), APF);
2697+
}
2698+
2699+
case Intrinsic::nvvm_sqrt_rn_ftz_f:
2700+
case Intrinsic::nvvm_sqrt_f:
2701+
case Intrinsic::nvvm_sqrt_rn_d:
2702+
case Intrinsic::nvvm_sqrt_rn_f:
2703+
if (APF.isNegative())
2704+
return nullptr;
2705+
return ConstantFoldFP(
2706+
sqrt, APF, Ty,
2707+
nvvm::GetNVVMDenromMode(
2708+
nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID)));
2709+
2710+
// AMDGCN Intrinsics:
25532711
case Intrinsic::amdgcn_cos:
25542712
case Intrinsic::amdgcn_sin: {
25552713
double V = getValueAsDouble(Op);

0 commit comments

Comments
 (0)