Skip to content

Commit d6e39b1

Browse files
Create ForwardDiffNNlibExt.jl
1 parent f07def8 commit d6e39b1

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

ext/ForwardDiffNNlibExt.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
module ForwardDiffNNlibExt
2+
import ForwardDiff
3+
using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff, NNlib
4+
5+
@generated function NNlib.relu(
6+
x::ForwardDiff.Dual{T,<:LoopVectorization.AbstractSIMD,N}
7+
) where {T,S,N}
8+
quote
9+
$(Expr(:meta, :inline))
10+
v = x.value
11+
z = zero(v)
12+
cmp = v < z
13+
r = ifelse(cmp, z, v)
14+
p = x.partials
15+
ForwardDiff.Dual{T}(
16+
r,
17+
ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, z, p[n]))
18+
)
19+
end
20+
end
21+
22+
@generated function NNlib.leakyrelu(
23+
x::ForwardDiff.Dual{T,<:LoopVectorization.AbstractSIMD,N},
24+
a = 0.01
25+
) where {T,S,N}
26+
quote
27+
$(Expr(:meta, :inline))
28+
v = x.value
29+
z = zero(v)
30+
31+
α = convert(typeof(v), a)
32+
cmp = v < z
33+
r = ifelse(cmp, α * v, v)
34+
p = x.partials
35+
ForwardDiff.Dual{T}(
36+
r,
37+
ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, α * p[n], p[n]))
38+
)
39+
end
40+
end
41+
42+
end

0 commit comments

Comments
 (0)