Skip to content

Commit 52fadde

Browse files
committed
initial commit
1 parent 9052a85 commit 52fadde

File tree

5 files changed

+95
-3
lines changed

5 files changed

+95
-3
lines changed

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,7 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
14501450
.Case("popc.ll", true)
14511451
.Case("h2f", true)
14521452
.Case("swap.lo.hi.b64", true)
1453+
.Case("tanh.approx.f32", true)
14531454
.Default(false);
14541455

14551456
if (Expand) {
@@ -2543,6 +2544,12 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
25432544
MDNode *MD = MDNode::get(Builder.getContext(), {});
25442545
LD->setMetadata(LLVMContext::MD_invariant_load, MD);
25452546
return LD;
2547+
} else if (Name == "tanh.approx.f32") {
2548+
// nvvm.tanh.approx.f32 -> afn llvm.tanh.f32
2549+
FastMathFlags FMF;
2550+
FMF.setApproxFunc();
2551+
Rep = Builder.CreateUnaryIntrinsic(Intrinsic::tanh, CI->getArgOperand(0),
2552+
FMF);
25462553
} else if (Name == "barrier0" || Name == "barrier.n" || Name == "bar.sync") {
25472554
Value *Arg =
25482555
Name.ends_with('0') ? Builder.getInt32(0) : CI->getArgOperand(0);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -952,10 +952,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
952952
// promoted to f32. v2f16 is expanded to f16, which is then promoted
953953
// to f32.
954954
for (const auto &Op :
955-
{ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) {
955+
{ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
956956
setOperationAction(Op, MVT::f16, Promote);
957957
setOperationAction(Op, MVT::f32, Legal);
958-
setOperationAction(Op, MVT::f64, Legal);
958+
// fsin, fcos, and ftanh are not supported on f64
959+
if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
960+
setOperationAction(Op, MVT::f64, Legal);
961+
}
959962
setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand);
960963
setOperationAction(Op, MVT::bf16, Promote);
961964
AddPromotedToType(Op, MVT::bf16, MVT::f32);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,7 @@ defm FMA_F32 : FMA<F32RT, allow_ftz = true>;
12341234
defm FMA_F32x2 : FMA<F32X2RT, allow_ftz = true, preds = [hasF32x2Instructions]>;
12351235
defm FMA_F64 : FMA<F64RT, allow_ftz = false>;
12361236

1237-
// sin/cos
1237+
// sin/cos/tanh
12381238

12391239
class UnaryOpAllowsApproxFn<SDPatternOperator operator>
12401240
: PatFrag<(ops node:$A),
@@ -1250,6 +1250,9 @@ def COS_APPROX_f32 :
12501250
BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz),
12511251
"cos.approx$ftz.f32",
12521252
[(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
1253+
// PTX `tanh.approx.f32`: selected for an f32 `ftanh` node wrapped in the
// UnaryOpAllowsApproxFn pattern fragment. Unlike COS_APPROX_f32 above, this
// def is a plain BasicNVPTXInst and models no FTZ flag operand.
// NOTE(review): the PTX ISA documents tanh.approx.f32 as requiring PTX 7.0
// and sm_75 or newer; no Requires<> predicate is attached here -- confirm
// whether one is needed (the accompanying test must then pin the target).
def TANH_APPROX_f32 :
1254+
BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "tanh.approx.f32",
1255+
[(set f32:$dst, (UnaryOpAllowsApproxFn<ftanh> f32:$src))]>;
12531256

12541257
//-----------------------------------
12551258
// Bitwise operations

llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ declare float @llvm.nvvm.fabs.f(float)
1717
declare float @llvm.nvvm.fabs.ftz.f(float)
1818
declare double @llvm.nvvm.fabs.d(double)
1919

20+
declare float @llvm.nvvm.tanh.approx.f32(float)
21+
2022
declare i16 @llvm.nvvm.max.s(i16, i16)
2123
declare i32 @llvm.nvvm.max.i(i32, i32)
2224
declare i64 @llvm.nvvm.max.ll(i64, i64)
@@ -138,6 +140,13 @@ define void @fabs(float %a, double %b) {
138140
ret void
139141
}
140142

143+
; Auto-upgrade coverage for the legacy NVVM tanh intrinsic: the call to
; @llvm.nvvm.tanh.approx.f32 below is expected to come out as a call to the
; generic @llvm.tanh.f32 carrying the `afn` fast-math flag.
; CHECK-LABEL: @tanh
144+
define void @tanh(float %a) {
145+
; CHECK: call afn float @llvm.tanh.f32(float %a)
146+
%r1 = call float @llvm.nvvm.tanh.approx.f32(float %a)
147+
ret void
148+
}
149+
141150
; CHECK-LABEL: @min_max
142151
define void @min_max(i16 %a1, i16 %a2, i32 %b1, i32 %b2, i64 %c1, i64 %c2) {
143152
; CHECK: [[maxs:%[a-zA-Z0-9.]+]] = icmp sge i16 %a1, %a2

llvm/test/CodeGen/NVPTX/tanhf.ll

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; tanh.approx.f32 needs PTX ISA 7.0 and sm_75 or newer, so pin both for llc
; and have ptxas verify against the same architecture instead of relying on
; the tools' default targets.
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_75 -mattr=+ptx70 | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %}
4+
5+
6+
; f32 case: a plain call to llvm.tanh.f32 with the `afn` flag is expected to
; lower to a single tanh.approx.f32 between the parameter load and the
; return-value store.
define float @test1(float %in) local_unnamed_addr {
7+
; CHECK-LABEL: test1(
8+
; CHECK: {
9+
; CHECK-NEXT: .reg .b32 %r<3>;
10+
; CHECK-EMPTY:
11+
; CHECK-NEXT: // %bb.0:
12+
; CHECK-NEXT: ld.param.b32 %r1, [test1_param_0];
13+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
14+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
15+
; CHECK-NEXT: ret;
16+
%call = call afn float @llvm.tanh.f32(float %in)
17+
ret float %call
18+
}
19+
20+
; f32 case with a `tail call`: the expected PTX is the same sequence as for a
; plain call (load param, tanh.approx.f32, store return value).
define float @test2(float %in) local_unnamed_addr {
21+
; CHECK-LABEL: test2(
22+
; CHECK: {
23+
; CHECK-NEXT: .reg .b32 %r<3>;
24+
; CHECK-EMPTY:
25+
; CHECK-NEXT: // %bb.0:
26+
; CHECK-NEXT: ld.param.b32 %r1, [test2_param_0];
27+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
28+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
29+
; CHECK-NEXT: ret;
30+
%call = tail call afn float @llvm.tanh.f32(float %in)
31+
ret float %call
32+
}
33+
34+
; f16 case: the expected PTX converts the half input to f32, applies
; tanh.approx.f32, and converts the result back to f16 with round-to-nearest
; (cvt.rn.f16.f32).
define half @test3(half %in) local_unnamed_addr {
35+
; CHECK-LABEL: test3(
36+
; CHECK: {
37+
; CHECK-NEXT: .reg .b16 %rs<3>;
38+
; CHECK-NEXT: .reg .b32 %r<3>;
39+
; CHECK-EMPTY:
40+
; CHECK-NEXT: // %bb.0:
41+
; CHECK-NEXT: ld.param.b16 %rs1, [test3_param_0];
42+
; CHECK-NEXT: cvt.f32.f16 %r1, %rs1;
43+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
44+
; CHECK-NEXT: cvt.rn.f16.f32 %rs2, %r2;
45+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs2;
46+
; CHECK-NEXT: ret;
47+
%call = call afn half @llvm.tanh.f16(half %in)
48+
ret half %call
49+
}
50+
51+
; f16 case with a `tail call`: the expected PTX is the same
; convert / tanh.approx.f32 / convert-back sequence as for a plain call.
define half @test4(half %in) local_unnamed_addr {
52+
; CHECK-LABEL: test4(
53+
; CHECK: {
54+
; CHECK-NEXT: .reg .b16 %rs<3>;
55+
; CHECK-NEXT: .reg .b32 %r<3>;
56+
; CHECK-EMPTY:
57+
; CHECK-NEXT: // %bb.0:
58+
; CHECK-NEXT: ld.param.b16 %rs1, [test4_param_0];
59+
; CHECK-NEXT: cvt.f32.f16 %r1, %rs1;
60+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
61+
; CHECK-NEXT: cvt.rn.f16.f32 %rs2, %r2;
62+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs2;
63+
; CHECK-NEXT: ret;
64+
%call = tail call afn half @llvm.tanh.f16(half %in)
65+
ret half %call
66+
}
67+
68+
declare float @llvm.tanh.f32(float)
69+
declare half @llvm.tanh.f16(half)
70+

0 commit comments

Comments
 (0)