Skip to content

Commit 7671a4f

Browse files
Update ForwardDiffExt.jl
1 parent d6e39b1 commit 7671a4f

File tree

1 file changed

+1
-39
lines changed

1 file changed

+1
-39
lines changed

ext/ForwardDiffExt.jl

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module ForwardDiffExt
22
import ForwardDiff, ChainRulesCore
3-
using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff, NNlib
3+
using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff
44
using SLEEFPirates: tanh_fast, sigmoid_fast
55

66
import IfElse: ifelse
@@ -141,44 +141,6 @@ end
141141
end
142142
end
143143

144-
@generated function NNlib.relu(
145-
x::ForwardDiff.Dual{T,<:LoopVectorization.AbstractSIMD,N}
146-
) where {T,S,N}
147-
quote
148-
$(Expr(:meta, :inline))
149-
v = x.value
150-
z = zero(v)
151-
cmp = v < z
152-
r = ifelse(cmp, z, v)
153-
p = x.partials
154-
ForwardDiff.Dual{T}(
155-
r,
156-
ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, z, p[n]))
157-
)
158-
end
159-
end
160-
161-
@generated function NNlib.leakyrelu(
162-
x::ForwardDiff.Dual{T,<:LoopVectorization.AbstractSIMD,N},
163-
a = 0.01
164-
) where {T,S,N}
165-
quote
166-
$(Expr(:meta, :inline))
167-
v = x.value
168-
z = zero(v)
169-
170-
α = convert(typeof(v), a)
171-
cmp = v < z
172-
r = ifelse(cmp, α * v, v)
173-
p = x.partials
174-
ForwardDiff.Dual{T}(
175-
r,
176-
ForwardDiff.Partials(Base.Cartesian.@ntuple $N n -> ifelse(cmp, α * p[n], p[n]))
177-
)
178-
end
179-
end
180-
181-
182144
@generated function _ifelse(
183145
m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}},
184146
x::ForwardDiff.Dual{TAG,V,P},

0 commit comments

Comments
 (0)