Skip to content
5 changes: 3 additions & 2 deletions docs/src/examples/optimal_control/optimal_control.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ Now let's see what we received:
```@example neuraloptimalcontrol
l = loss_adjoint(res3.u)
cb(res3, l)
p = Plots.plot(ODE.solve(ODE.remake(prob, p = res3.u), ODE.Tsit5(), saveat = 0.01), ylim = (
-6, 6), lw = 3)
p = Plots.plot(
ODE.solve(ODE.remake(prob, p = res3.u), ODE.Tsit5(), saveat = 0.01), ylim = (
-6, 6), lw = 3)
Plots.plot!(p, ts, [first(first(ann([t], CA.ComponentArray(res3.u, ax), st))) for t in ts],
label = "u(t)", lw = 3)
```
22 changes: 19 additions & 3 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,20 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
if isinplace &&
!(p === nothing || p === SciMLBase.NullParameters())
if !isRODE
pf = SciMLBase.ParamJacobianWrapper(unwrappedf, _t, y)
if isscimlstructure(p)
pf = SciMLBase.ParamJacobianWrapper(
(
du, u, p, t)->unwrappedf(du, u, repack(p), t), _t, y)
else
pf = SciMLBase.ParamJacobianWrapper(unwrappedf, _t, y)
end
else
pf = RODEParamJacobianWrapper(unwrappedf, _t, y, _W)
if isscimlstructure(p)
pf = RODEParamJacobianWrapper(
(du, u, p, t, W)->unwrappedf(du, u, repack(p), t, W), _t, y, _W)
else
pf = RODEParamJacobianWrapper(unwrappedf, _t, y, _W)
end
end
paramjac_config = build_param_jac_config(
sensealg, pf, y, SciMLStructures.replace(Tunable(), p, tunables))
Expand Down Expand Up @@ -317,7 +328,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
elseif autojacvec isa Bool
if isinplace
if SciMLBase.is_diagonal_noise(prob)
pf = SciMLBase.ParamJacobianWrapper(unwrappedf, _t, y)
if isscimlstructure(p)
pf = SciMLBase.ParamJacobianWrapper(
(du, u, p, t)->unwrappedf(du, u, repack(p), t), _t, y)
else
pf = SciMLBase.ParamJacobianWrapper(unwrappedf, _t, y)
end
if isnoisemixing(sensealg)
uf = SciMLBase.UJacobianWrapper(unwrappedf, _t, p)
jac_noise_config = build_jac_config(sensealg, uf, u0)
Expand Down
24 changes: 19 additions & 5 deletions src/backsolve_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ end
u0 = state_values(sol.prob)
if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
else
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
else
throw(SciMLStructuresCompatibilityError())
end

## Force recompile mode until vjps are specialized to handle this!!!
Expand Down Expand Up @@ -263,7 +265,13 @@ end
(; f, tspan) = sol.prob
p = parameter_values(sol)
u0 = state_values(sol.prob)
tunables, repack, _ = canonicalize(Tunable(), p)
if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
else
throw(SciMLStructuresCompatibilityError())
end

# check if solution was terminated, then use reduced time span
terminated = false
Expand All @@ -283,7 +291,7 @@ end
error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!")

numstates = length(u0)
numparams = length(tunables)
numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables)

len = length(u0) + numparams
λ = one(eltype(u0)) .* similar(tunables, len)
Expand Down Expand Up @@ -386,7 +394,13 @@ end
(; f, tspan) = sol.prob
p = parameter_values(sol)
u0 = state_values(sol.prob)
tunables, repack, _ = canonicalize(Tunable(), p)
if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
else
throw(SciMLStructuresCompatibilityError())
end
# check if solution was terminated, then use reduced time span
terminated = false
if hasfield(typeof(sol), :retcode)
Expand All @@ -404,7 +418,7 @@ end
error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!")

numstates = length(u0)
numparams = length(tunables)
numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables)

len = length(u0) + numparams
λ = one(eltype(u0)) .* similar(tunables, len)
Expand Down
8 changes: 6 additions & 2 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,12 @@ function DiffEqBase._concrete_solve_adjoint(
save_idxs = nothing,
initializealg_default = SciMLBase.OverrideInit(; abstol = 1e-6, reltol = 1e-3),
kwargs...)
if !(sensealg isa GaussAdjoint) &&
!(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
# Check parameter compatibility for adjoint methods
if !((p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
(sensealg isa
Union{GaussAdjoint, BacksolveAdjoint, InterpolatingAdjoint, QuadratureAdjoint} &&
isscimlstructure(p)) ||
(sensealg isa Union{GaussAdjoint, QuadratureAdjoint} && isfunctor(p))) ||
(p isa AbstractArray && !Base.isconcretetype(eltype(p)))
throw(AdjointSensitivityParameterCompatibilityError())
end
Expand Down
42 changes: 36 additions & 6 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,19 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy,
pf.t = t
pf.u = y
if inplace_sensitivity(S)
jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
if isscimlstructure(p)
tunables, _, _ = canonicalize(Tunable(), p)
jacobian!(pJ, pf, tunables, f_cache, sensealg, paramjac_config)
else
jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
end
else
temp = jacobian(pf, p, sensealg)
if isscimlstructure(p)
tunables, _, _ = canonicalize(Tunable(), p)
temp = jacobian(pf, tunables, sensealg)
else
temp = jacobian(pf, p, sensealg)
end
pJ .= temp
end
end
Expand All @@ -319,9 +329,19 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy,
pf.u = y
pf.W = W
if inplace_sensitivity(S)
jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
if isscimlstructure(p)
tunables, _, _ = canonicalize(Tunable(), p)
jacobian!(pJ, pf, tunables, f_cache, sensealg, paramjac_config)
else
jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_config)
end
else
temp = jacobian(pf, p, sensealg)
if isscimlstructure(p)
tunables, _, _ = canonicalize(Tunable(), p)
temp = jacobian(pf, tunables, sensealg)
else
temp = jacobian(pf, p, sensealg)
end
pJ .= temp
end
end
Expand Down Expand Up @@ -814,10 +834,20 @@ function _jacNoise!(λ, y, p, t, S::TS, isnoise::Bool, dgrad, dλ,
pf.t = t
pf.u = y
if inplace_sensitivity(S)
jacobian!(pJ, pf, p, nothing, sensealg, nothing)
if isscimlstructure(p)
tunables, _, _ = canonicalize(Tunable(), p)
jacobian!(pJ, pf, tunables, nothing, sensealg, nothing)
else
jacobian!(pJ, pf, p, nothing, sensealg, nothing)
end
#jacobian!(pJ, pf, p, f_cache, sensealg, paramjac_noise_config)
else
temp = jacobian(pf, p, sensealg)
if isscimlstructure(p)
tunables, _, _ = canonicalize(Tunable(), p)
temp = jacobian(pf, tunables, sensealg)
else
temp = jacobian(pf, p, sensealg)
end
pJ .= temp
end
end
Expand Down
4 changes: 3 additions & 1 deletion src/interpolating_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,10 @@ end

if p === nothing || p isa SciMLBase.NullParameters
tunables, repack = p, identity
else
elseif isscimlstructure(p)
tunables, repack, _ = canonicalize(Tunable(), p)
else
throw(SciMLStructuresCompatibilityError())
end

## Force recompile mode until vjps are specialized to handle this!!!
Expand Down
3 changes: 2 additions & 1 deletion src/sensitivity_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ function _adjoint_sensitivities(sol, sensealg, alg;
callback = nothing,
kwargs...)
mtkp = SymbolicIndexingInterface.parameter_values(sol)
if !(mtkp isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
if !((mtkp isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
isscimlstructure(mtkp) || isfunctor(mtkp)) ||
(mtkp isa AbstractArray && !Base.isconcretetype(eltype(mtkp)))
throw(AdjointSensitivityParameterCompatibilityError())
end
Expand Down
24 changes: 22 additions & 2 deletions test/scimlstructures_interface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# taken from https://github.com/SciML/SciMLStructures.jl/pull/28
using OrdinaryDiffEq, SciMLSensitivity, Zygote
using LinearAlgebra
using Test
import SciMLStructures as SS

mutable struct SubproblemParameters{P, Q, R}
Expand Down Expand Up @@ -87,6 +88,7 @@ import SciMLStructures as SS
using Zygote
using ADTypes
using Test
using Tracker, ReverseDiff

mutable struct myparam{M, P, S}
model::M
Expand Down Expand Up @@ -156,6 +158,24 @@ function run_diff(ps, sensealg)
return sol.u |> last |> sum
end

## Test all adjoints with SciMLStructures

# Test basic functionality
run_diff(initialize())
@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint())[1].ps)
@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec = false))[1].ps)

@testset "SciMLStructures Support for All Adjoints" begin
# Test GaussAdjoint (already working)
@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint())[1].ps)

# Test newly fixed BacksolveAdjoint and InterpolatingAdjoint - these are the main fixes in this PR
@test !iszero(Zygote.gradient(run_diff, initialize(), BacksolveAdjoint())[1].ps)
@test !iszero(Zygote.gradient(run_diff, initialize(), InterpolatingAdjoint())[1].ps)

# Test QuadratureAdjoint (already working)
@test !iszero(Zygote.gradient(run_diff, initialize(), QuadratureAdjoint())[1].ps)

# Test with different AD backends
@test !iszero(Zygote.gradient(run_diff, initialize(), ReverseDiffAdjoint())[1].ps)
@test !iszero(Zygote.gradient(run_diff, initialize(), TrackerAdjoint())[1].ps)
@test !iszero(Zygote.gradient(run_diff, initialize(), ZygoteAdjoint())[1].ps)
end
Loading