Skip to content

Commit 222eb95

Browse files
Merge pull request #1221 from SciML/ap/revert_enz
Revert "Fix Enzyme solve_up rule signature to support DuplicatedNoNeed"
2 parents 64f65be + 6054a3b commit 222eb95

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.190.4"
4+
version = "6.190.5"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ module DiffEqBaseEnzymeExt
99

1010
function Enzyme.EnzymeRules.augmented_primal(
1111
config::Enzyme.EnzymeRules.RevConfigWidth{1},
12-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, prob,
12+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
1313
sensealg::Union{
1414
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
15-
u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}}
15+
u0, p, args...; kwargs...) where {RT}
1616
@inline function copy_or_reuse(val, idx)
1717
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
1818
return deepcopy(val)
@@ -31,18 +31,18 @@ module DiffEqBaseEnzymeExt
3131
SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...;
3232
kwargs...)
3333

34-
ResType = typeof(res[1])
35-
dres = Enzyme.make_zero(res[1])::ResType
34+
dres = Enzyme.make_zero(res[1])::RT
3635
tup = (dres, res[2])
37-
return Enzyme.EnzymeRules.AugmentedReturn{ResType, ResType, Any}(res[1], dres, tup::Any)
36+
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
3837
end
3938

4039
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
41-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, tape, prob,
40+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
4241
sensealg::Union{
4342
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
44-
u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}}
43+
u0, p, args...; kwargs...) where {RT}
4544
dres, clos = tape
45+
dres = dres::RT
4646
dargs = clos(dres)
4747
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
4848
if ptr isa Enzyme.Const

0 commit comments

Comments
 (0)