From 4638d8ba76e7971ac64d0cc35cdfd3365185ccf4 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 24 Apr 2025 23:40:13 -0400 Subject: [PATCH 1/4] fix adjoint constructor for nonlinearsolution --- ext/SciMLBaseZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 45a8e0f63..b22a5b051 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -228,7 +228,7 @@ end uType2 } function NonlinearSolutionAdjoint(ȳ) - (ȳ, ntuple(_ -> nothing, length(args))...) + (ȳ.u, ntuple(_ -> nothing, length(args))...) end NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint end From 1b2a7bcaecd7c3ad8135dd770f2f63364dc5a705 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 24 Apr 2025 23:40:32 -0400 Subject: [PATCH 2/4] add rrule for nonlinearsolution constructor --- ext/SciMLBaseChainRulesCoreExt.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 028d505ba..3c1b34341 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -116,4 +116,15 @@ function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged out, EnsembleSolution_adjoint end +function ChainRulesCore.rrule( + ::Type{<:SciMLBase.NonlinearSolution{ + T, N, uType, R, P, A, O, uType2, S, Tr}}, u, + args...) where {T, N, uType, R, P, A, O, uType2, S, Tr} + function NonlinearSolutionAdjoint(ȳ) + (NoTangent(), ȳ.u, ntuple(_ -> NoTangent(), length(args))...) + end + SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, args...), + NonlinearSolutionAdjoint +end + end From d948f07cbc6356544f7ab6153a037cfd2563cede Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 14 May 2025 10:27:01 -0400 Subject: [PATCH 3/4] add prob tracking --- ext/SciMLBaseChainRulesCoreExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 3c1b34341..5a0376355 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -118,12 +118,12 @@ end function ChainRulesCore.rrule( ::Type{<:SciMLBase.NonlinearSolution{ - T, N, uType, R, P, A, O, uType2, S, Tr}}, u, + T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob, args...) where {T, N, uType, R, P, A, O, uType2, S, Tr} function NonlinearSolutionAdjoint(ȳ) - (NoTangent(), ȳ.u, ntuple(_ -> NoTangent(), length(args))...) + (NoTangent(), ȳ.u, NoTangent(), ŷ.prob, ntuple(_ -> NoTangent(), length(args))...) end - SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, args...), + SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, resid, prob, args...), NonlinearSolutionAdjoint end From 1d7f7b95cb45fe1bee912deec28df23cd4d9cc1f Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 14 May 2025 13:07:55 -0400 Subject: [PATCH 4/4] fix hat vs bar --- ext/SciMLBaseChainRulesCoreExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index 5a0376355..71d14c8a9 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -121,7 +121,7 @@ function ChainRulesCore.rrule( T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob, args...) where {T, N, uType, R, P, A, O, uType2, S, Tr} function NonlinearSolutionAdjoint(ȳ) - (NoTangent(), ȳ.u, NoTangent(), ŷ.prob, ntuple(_ -> NoTangent(), length(args))...) + (NoTangent(), ȳ.u, NoTangent(), ȳ.prob, ntuple(_ -> NoTangent(), length(args))...) end SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, resid, prob, args...), NonlinearSolutionAdjoint