Skip to content

Commit 3022a9f

Browse files
committed
Increase type stability
1 parent 7e6da68 commit 3022a9f

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/fwd_back.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,21 @@ unflatten_bc_tangent(x_orig::Any, dx_flat::AbstractArray{<:ChainRulesCore.Abstra
9494
ConstructionBase.constructorof(eltype(dx_flat))()
9595

9696

97-
struct FwddiffFwd{F<:Function,i} <: Function
97+
struct FwddiffVJPSingleArg{F<:Function,i} <: Function
9898
f::F
9999
end
100-
FwddiffFwd(f::F, ::Val{i}) where {F<:Function,i} = FwddiffFwd{F,i}(f)
100+
FwddiffVJPSingleArg(f::F, ::Val{i}) where {F<:Function,i} = FwddiffVJPSingleArg{F,i}(f)
101101

102-
(fwd::FwddiffFwd{F,i})(xs...) where {F<:Function,i} = forwarddiff_fwd(fwd.f, xs, Val(i))
102+
# Need to use Vararg{Any,N} to force specialization:
103+
function (fwdback::FwddiffVJPSingleArg{F,i})(ΔΩ, xs::Vararg{Any,N}) where {F,i,N}
104+
y_dual = forwarddiff_fwd(fwdback.f, xs, Val(i))
105+
svec_tangent_dual_product(ΔΩ, y_dual)
106+
end
103107

104108

105109
function forwarddiff_bc_vjp_impl(f::F, Xs::Tuple, ::Val{i}, ΔΩA::Any) where {F<:Function,i}
106110
# @info "RUN forwarddiff_bc_vjp_impl(f, Xs, Val($i), ΔΩA)"
107-
fwd = FwddiffFwd(f, Val(i))
108-
dx_flat = svec_tangent_dual_product.(ΔΩA, fwd.(Xs...)) # ToDo: Use Base.Broadcast.broadcasted
109-
unflatten_bc_tangent(Xs[i], dx_flat)
111+
vjp_i = FwddiffVJPSingleArg(f, Val(i))
112+
dXi_flat = broadcast(vjp_i, ΔΩA, Xs...) # ToDo: Use Base.Broadcast.broadcasted
113+
unflatten_bc_tangent(Xs[i], dXi_flat)
110114
end

0 commit comments

Comments
 (0)