Skip to content

Commit e7ede36

Browse files
committed
initial commit
1 parent 6a7f572 commit e7ede36

File tree

5 files changed

+95
-3
lines changed

5 files changed

+95
-3
lines changed

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,7 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
14381438
.Case("popc.ll", true)
14391439
.Case("h2f", true)
14401440
.Case("swap.lo.hi.b64", true)
1441+
.Case("tanh.approx.f32", true)
14411442
.Default(false);
14421443

14431444
if (Expand) {
@@ -2532,6 +2533,12 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
25322533
MDNode *MD = MDNode::get(Builder.getContext(), {});
25332534
LD->setMetadata(LLVMContext::MD_invariant_load, MD);
25342535
return LD;
2536+
} else if (Name == "tanh.approx.f32") {
2537+
// nvvm.tanh.approx.f32 -> afn llvm.tanh.f32
2538+
FastMathFlags FMF;
2539+
FMF.setApproxFunc();
2540+
Rep = Builder.CreateUnaryIntrinsic(Intrinsic::tanh, CI->getArgOperand(0),
2541+
FMF);
25352542
} else if (Name == "barrier0" || Name == "barrier.n" || Name == "bar.sync") {
25362543
Value *Arg =
25372544
Name.ends_with('0') ? Builder.getInt32(0) : CI->getArgOperand(0);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -950,10 +950,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
950950
// promoted to f32. v2f16 is expanded to f16, which is then promoted
951951
// to f32.
952952
for (const auto &Op :
953-
{ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) {
953+
{ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
954954
setOperationAction(Op, MVT::f16, Promote);
955955
setOperationAction(Op, MVT::f32, Legal);
956-
setOperationAction(Op, MVT::f64, Legal);
956+
// fsin, fcos, and ftanh are not supported on f64
957+
if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
958+
setOperationAction(Op, MVT::f64, Legal);
959+
}
957960
setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand);
958961
setOperationAction(Op, MVT::bf16, Promote);
959962
AddPromotedToType(Op, MVT::bf16, MVT::f32);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1233,7 +1233,7 @@ defm FMA_F32 : FMA<F32RT, allow_ftz = true>;
12331233
defm FMA_F32x2 : FMA<F32X2RT, allow_ftz = true, preds = [hasF32x2Instructions]>;
12341234
defm FMA_F64 : FMA<F64RT, allow_ftz = false>;
12351235

1236-
// sin/cos
1236+
// sin/cos/tanh
12371237

12381238
class UnaryOpAllowsApproxFn<SDPatternOperator operator>
12391239
: PatFrag<(ops node:$A),
@@ -1249,6 +1249,9 @@ def COS_APPROX_f32 :
12491249
BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz),
12501250
"cos.approx$ftz.f32",
12511251
[(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
1252+
// f32 tanh: emits PTX "tanh.approx.f32" for ftanh nodes accepted by
// UnaryOpAllowsApproxFn. Unlike COS_APPROX_f32 above, this def uses
// BasicNVPTXInst and takes no FTZ flag operand.
// NOTE(review): the PTX ISA documents tanh.approx.f32 as needing PTX 7.0+ and
// sm_75+, and no Requires<> predicate is attached here -- confirm whether one
// is needed.
def TANH_APPROX_f32 :
1253+
BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "tanh.approx.f32",
1254+
[(set f32:$dst, (UnaryOpAllowsApproxFn<ftanh> f32:$src))]>;
12521255

12531256
//-----------------------------------
12541257
// Bitwise operations

llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ declare float @llvm.nvvm.fabs.f(float)
1717
declare float @llvm.nvvm.fabs.ftz.f(float)
1818
declare double @llvm.nvvm.fabs.d(double)
1919

20+
declare float @llvm.nvvm.tanh.approx.f32(float)
21+
2022
declare i16 @llvm.nvvm.max.s(i16, i16)
2123
declare i32 @llvm.nvvm.max.i(i32, i32)
2224
declare i64 @llvm.nvvm.max.ll(i64, i64)
@@ -138,6 +140,13 @@ define void @fabs(float %a, double %b) {
138140
ret void
139141
}
140142

143+
; CHECK-LABEL: @tanh
144+
; The legacy @llvm.nvvm.tanh.approx.f32 call below is expected to be
; auto-upgraded to an afn-flagged @llvm.tanh.f32 call.
define void @tanh(float %a) {
145+
; CHECK: call afn float @llvm.tanh.f32(float %a)
146+
%r1 = call float @llvm.nvvm.tanh.approx.f32(float %a)
147+
ret void
148+
}
149+
141150
; CHECK-LABEL: @min_max
142151
define void @min_max(i16 %a1, i16 %a2, i32 %b1, i32 %b2, i64 %c1, i64 %c2) {
143152
; CHECK: [[maxs:%[a-zA-Z0-9.]+]] = icmp sge i16 %a1, %a2

llvm/test/CodeGen/NVPTX/tanhf.ll

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; tanh.approx.f32 needs PTX ISA 7.0+ and sm_75+, so pin the CPU and PTX
; version instead of relying on the llc/ptxas defaults.
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_75 -mattr=+ptx70 | FileCheck %s
3+
; RUN: %if ptxas-11.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %}
4+
5+
6+
; f32 llvm.tanh called with the afn flag: expected to select a single
; tanh.approx.f32 instruction.
define float @test1(float %in) local_unnamed_addr {
7+
; CHECK-LABEL: test1(
8+
; CHECK: {
9+
; CHECK-NEXT: .reg .b32 %r<3>;
10+
; CHECK-EMPTY:
11+
; CHECK-NEXT: // %bb.0:
12+
; CHECK-NEXT: ld.param.b32 %r1, [test1_param_0];
13+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
14+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
15+
; CHECK-NEXT: ret;
16+
%call = call afn float @llvm.tanh.f32(float %in)
17+
ret float %call
18+
}
19+
20+
; Same as test1 but with a `tail call`: the expected PTX is identical, a
; single tanh.approx.f32.
define float @test2(float %in) local_unnamed_addr {
21+
; CHECK-LABEL: test2(
22+
; CHECK: {
23+
; CHECK-NEXT: .reg .b32 %r<3>;
24+
; CHECK-EMPTY:
25+
; CHECK-NEXT: // %bb.0:
26+
; CHECK-NEXT: ld.param.b32 %r1, [test2_param_0];
27+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
28+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
29+
; CHECK-NEXT: ret;
30+
%call = tail call afn float @llvm.tanh.f32(float %in)
31+
ret float %call
32+
}
33+
34+
; f16 llvm.tanh called with the afn flag: the expected PTX converts the input
; to f32 (cvt.f32.f16), applies tanh.approx.f32, and converts the result back
; to f16 (cvt.rn.f16.f32).
define half @test3(half %in) local_unnamed_addr {
35+
; CHECK-LABEL: test3(
36+
; CHECK: {
37+
; CHECK-NEXT: .reg .b16 %rs<3>;
38+
; CHECK-NEXT: .reg .b32 %r<3>;
39+
; CHECK-EMPTY:
40+
; CHECK-NEXT: // %bb.0:
41+
; CHECK-NEXT: ld.param.b16 %rs1, [test3_param_0];
42+
; CHECK-NEXT: cvt.f32.f16 %r1, %rs1;
43+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
44+
; CHECK-NEXT: cvt.rn.f16.f32 %rs2, %r2;
45+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs2;
46+
; CHECK-NEXT: ret;
47+
%call = call afn half @llvm.tanh.f16(half %in)
48+
ret half %call
49+
}
50+
51+
; Same as test3 but with a `tail call`: the expected PTX is the identical
; f16 -> f32 -> tanh.approx.f32 -> f16 sequence.
define half @test4(half %in) local_unnamed_addr {
52+
; CHECK-LABEL: test4(
53+
; CHECK: {
54+
; CHECK-NEXT: .reg .b16 %rs<3>;
55+
; CHECK-NEXT: .reg .b32 %r<3>;
56+
; CHECK-EMPTY:
57+
; CHECK-NEXT: // %bb.0:
58+
; CHECK-NEXT: ld.param.b16 %rs1, [test4_param_0];
59+
; CHECK-NEXT: cvt.f32.f16 %r1, %rs1;
60+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
61+
; CHECK-NEXT: cvt.rn.f16.f32 %rs2, %r2;
62+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs2;
63+
; CHECK-NEXT: ret;
64+
%call = tail call afn half @llvm.tanh.f16(half %in)
65+
ret half %call
66+
}
67+
68+
declare float @llvm.tanh.f32(float)
69+
declare half @llvm.tanh.f16(half)
70+

0 commit comments

Comments
 (0)