Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 5afc4ba

Browse files
committed
Test the new NonlinearSolveBase.jl
1 parent fd7d216 commit 5afc4ba

14 files changed

+644
-99
lines changed

Manifest.toml

Lines changed: 574 additions & 0 deletions
Large diffs are not rendered by default.

Project.toml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,42 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.5.0"
4+
version = "1.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
10-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1110
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1211
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1312
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1413
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1514
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1615
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
16+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1717
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1818
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1919
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2020
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2121

2222
[weakdeps]
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
24+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2425
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2526
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2627
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2728

2829
[extensions]
29-
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
30+
SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt = ["ChainRulesCore", "DiffEqBase"]
3031
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
3132
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
3233
SimpleNonlinearSolveZygoteExt = "Zygote"
3334

3435
[compat]
3536
ADTypes = "0.2.6"
3637
AllocCheck = "0.1.1"
37-
ArrayInterface = "7.7"
3838
Aqua = "0.8"
39+
ArrayInterface = "7.7"
3940
CUDA = "5.2"
4041
ChainRulesCore = "1.22"
4142
ConcreteStructs = "0.2.3"
@@ -48,6 +49,7 @@ LinearAlgebra = "1.10"
4849
LinearSolve = "2.25"
4950
MaybeInplace = "0.1.1"
5051
NonlinearProblemLibrary = "0.1.2"
52+
NonlinearSolveBase = "1"
5153
Pkg = "1.10"
5254
PolyesterForwardDiff = "0.1.1"
5355
PrecompileTools = "1.2"
@@ -66,12 +68,12 @@ julia = "1.10"
6668
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
6769
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6870
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
69-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
7071
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
7172
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7273
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7374
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
7475
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
76+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
7577
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7678
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
7779
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -83,4 +85,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8385
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8486

8587
[targets]
86-
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]
88+
test = ["Aqua", "AllocCheck", "NonlinearSolveBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]

ext/SimpleNonlinearSolveChainRulesCoreExt.jl renamed to ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module SimpleNonlinearSolveChainRulesCoreExt
1+
module SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt
22

33
using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
44

src/SimpleNonlinearSolve.jl

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@ module SimpleNonlinearSolve
33
import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
44

55
@recompile_invalidations begin
6-
using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff,
7-
ForwardDiff, Reexport, LinearAlgebra, SciMLBase
8-
9-
import DiffEqBase: AbstractNonlinearTerminationMode,
10-
AbstractSafeNonlinearTerminationMode,
11-
AbstractSafeBestNonlinearTerminationMode,
12-
NonlinearSafeTerminationReturnCode, get_termination_mode,
13-
NONLINEARSOLVE_DEFAULT_NORM
6+
using ADTypes, ArrayInterface, FiniteDiff, ForwardDiff, NonlinearSolveBase, Reexport,
7+
LinearAlgebra, SciMLBase
8+
9+
import ConcreteStructs: @concrete
1410
import DiffResults
11+
import FastClosures: @closure
1512
import ForwardDiff: Dual
1613
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
14+
import NonlinearSolveBase: AbstractNonlinearTerminationMode,
15+
AbstractSafeNonlinearTerminationMode,
16+
AbstractSafeBestNonlinearTerminationMode,
17+
get_termination_mode, NONLINEARSOLVE_DEFAULT_NORM
1718
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
1819
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size
1920
end
2021

21-
@reexport using ADTypes, SciMLBase
22+
@reexport using ADTypes, SciMLBase # TODO: Reexport NonlinearSolveBase after the situation with NonlinearSolve.jl is resolved
2223

2324
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
2425
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
@@ -58,23 +59,28 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...;
5859
end
5960

6061
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
61-
function SciMLBase.solve(
62-
prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
63-
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
64-
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
65-
sensealg = prob.kwargs[:sensealg]
66-
end
67-
new_u0 = u0 !== nothing ? u0 : prob.u0
68-
new_p = p !== nothing ? p : prob.p
69-
return __internal_solve_up(
70-
prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing,
71-
alg, args...; prob.kwargs..., kwargs...)
72-
end
62+
# Using eval to prevent ambiguity
63+
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
64+
@eval begin
65+
function SciMLBase.solve(
66+
prob::$(pType), alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
67+
sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
68+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
69+
sensealg = prob.kwargs[:sensealg]
70+
end
71+
new_u0 = u0 !== nothing ? u0 : prob.u0
72+
new_p = p !== nothing ? p : prob.p
73+
return __internal_solve_up(
74+
prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing,
75+
alg, args...; prob.kwargs..., kwargs...)
76+
end
7377

74-
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, p,
75-
p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
76-
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
77-
return SciMLBase.__solve(prob, alg, args...; kwargs...)
78+
function __internal_solve_up(_prob::$(pType), sensealg, u0, u0_changed, p,
79+
p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
80+
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
81+
return SciMLBase.__solve(prob, alg, args...; kwargs...)
82+
end
83+
end
7884
end
7985

8086
@setup_workload begin

src/bracketing/bisection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
2626
left, right = prob.tspan
2727
fl, fr = f(left), f(right)
2828

29-
abstol = __get_tolerance(nothing, abstol,
29+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
3030
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
3131

3232
if iszero(fl)

src/bracketing/brent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
1313
fl, fr = f(left), f(right)
1414
ϵ = eps(convert(typeof(fl), 1))
1515

16-
abstol = __get_tolerance(nothing, abstol,
16+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
1717
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1818

1919
if iszero(fl)

src/bracketing/falsi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = __get_tolerance(nothing, abstol,
15+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/bracketing/itp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
5858
left, right = prob.tspan
5959
fl, fr = f(left), f(right)
6060

61-
abstol = __get_tolerance(nothing, abstol,
61+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
6262
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
6363

6464
if iszero(fl)

src/bracketing/ridder.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = __get_tolerance(nothing, abstol,
15+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/linesearch.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function (cache::LiFukushimaLineSearchCache)(u, δu)
7373
fx_norm = ϕ(T(0))
7474

7575
# Non-Blocking exit if the norm is NaN or Inf
76-
DiffEqBase.NAN_CHECK(fx_norm) && return cache.α
76+
NonlinearSolveBase.NAN_CHECK(fx_norm) && return cache.α
7777

7878
# Early Terminate based on Eq. 2.7
7979
du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu)
@@ -84,12 +84,12 @@ function (cache::LiFukushimaLineSearchCache)(u, δu)
8484
fxλp_norm = ϕ(λ₂)
8585

8686
if cache.nan_maxiters !== nothing
87-
if DiffEqBase.NAN_CHECK(fxλp_norm)
87+
if NonlinearSolveBase.NAN_CHECK(fxλp_norm)
8888
nan_converged = false
8989
for _ in 1:(cache.nan_maxiters)
9090
λ₁, λ₂ = λ₂, cache.β * λ₂
9191
fxλp_norm = ϕ(λ₂)
92-
nan_converged = DiffEqBase.NAN_CHECK(fxλp_norm)::Bool
92+
nan_converged = NonlinearSolveBase.NAN_CHECK(fxλp_norm)::Bool
9393
nan_converged && break
9494
end
9595
nan_converged || return cache.α

0 commit comments

Comments
 (0)