@@ -9,10 +9,10 @@ module DiffEqBaseEnzymeExt
99
1010 function Enzyme. EnzymeRules. augmented_primal (
1111 config:: Enzyme.EnzymeRules.RevConfigWidth{1} ,
12- func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{RT } , prob,
12+ func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT} } , prob,
1313 sensealg:: Union {
1414 Const{Nothing}, Const{<: DiffEqBase.AbstractSensitivityAlgorithm }},
15- u0, p, args... ; kwargs... ) where {RT <: Union{Duplicated, DuplicatedNoNeed} }
15+ u0, p, args... ; kwargs... ) where {RT}
1616 @inline function copy_or_reuse (val, idx)
1717 if Enzyme. EnzymeRules. overwritten (config)[idx] && ismutable (val)
1818 return deepcopy (val)
@@ -31,18 +31,18 @@ module DiffEqBaseEnzymeExt
3131 SciMLBase. EnzymeOriginator (), ntuple (arg_copy, Val (length (args)))... ;
3232 kwargs... )
3333
34- ResType = typeof (res[1 ])
35- dres = Enzyme. make_zero (res[1 ]):: ResType
34+ dres = Enzyme. make_zero (res[1 ]):: RT
3635 tup = (dres, res[2 ])
37- return Enzyme. EnzymeRules. AugmentedReturn {ResType, ResType , Any} (res[1 ], dres, tup:: Any )
36+ return Enzyme. EnzymeRules. AugmentedReturn {RT, RT , Any} (res[1 ], dres, tup:: Any )
3837 end
3938
4039 function Enzyme. EnzymeRules. reverse (config:: Enzyme.EnzymeRules.RevConfigWidth{1} ,
41- func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{RT } , tape, prob,
40+ func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT} } , tape, prob,
4241 sensealg:: Union {
4342 Const{Nothing}, Const{<: DiffEqBase.AbstractSensitivityAlgorithm }},
44- u0, p, args... ; kwargs... ) where {RT <: Union{Duplicated, DuplicatedNoNeed} }
43+ u0, p, args... ; kwargs... ) where {RT}
4544 dres, clos = tape
45+ dres = dres:: RT
4646 dargs = clos (dres)
4747 for (darg, ptr) in zip (dargs, (func, prob, sensealg, u0, p, args... ))
4848 if ptr isa Enzyme. Const
0 commit comments