@@ -94,17 +94,21 @@ unflatten_bc_tangent(x_orig::Any, dx_flat::AbstractArray{<:ChainRulesCore.Abstra
94
94
ConstructionBase. constructorof (eltype (dx_flat))()
95
95
96
96
97
- struct FwddiffFwd {F<: Function ,i} <: Function
97
+ struct FwddiffVJPSingleArg {F<: Function ,i} <: Function
98
98
f:: F
99
99
end
100
- FwddiffFwd (f:: F , :: Val{i} ) where {F<: Function ,i} = FwddiffFwd {F,i} (f)
100
+ FwddiffVJPSingleArg (f:: F , :: Val{i} ) where {F<: Function ,i} = FwddiffVJPSingleArg {F,i} (f)
101
101
102
- (fwd:: FwddiffFwd{F,i} )(xs... ) where {F<: Function ,i} = forwarddiff_fwd (fwd. f, xs, Val (i))
102
+ # Need to use Vararg{Any,N} to force specialization:
103
+ function (fwdback:: FwddiffVJPSingleArg{F,i} )(ΔΩ, xs:: Vararg{Any,N} ) where {F,i,N}
104
+ y_dual = forwarddiff_fwd (fwdback. f, xs, Val (i))
105
+ svec_tangent_dual_product (ΔΩ, y_dual)
106
+ end
103
107
104
108
105
109
function forwarddiff_bc_vjp_impl (f:: F , Xs:: Tuple , :: Val{i} , ΔΩA:: Any ) where {F<: Function ,i}
106
110
# @info "RUN forwarddiff_bc_vjp_impl(f, Xs, Val($i), ΔΩA)"
107
- fwd = FwddiffFwd (f, Val (i))
108
- dx_flat = svec_tangent_dual_product .( ΔΩA, fwd .( Xs... ) ) # ToDo: Use Base.Broadcast.broadcasted
109
- unflatten_bc_tangent (Xs[i], dx_flat )
111
+ vjp_i = FwddiffVJPSingleArg (f, Val (i))
112
+ dXi_flat = broadcast (vjp_i, ΔΩA, Xs... ) # ToDo: Use Base.Broadcast.broadcasted
113
+ unflatten_bc_tangent (Xs[i], dXi_flat )
110
114
end
0 commit comments