Commit eb3ea6a

Merge pull request #846 from SciML/original
Tag the original solution to sol.original and simplify dependencies
2 parents 49efc20 + 81792ea commit eb3ea6a

7 files changed: +37 -36 lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
@@ -10,7 +10,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
-DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -47,7 +46,6 @@ CUDA = "5.2"
 ChainRulesCore = "1.21"
 ComponentArrays = "0.15.8"
 Cubature = "1.5"
-DiffEqBase = "6.148"
 DiffEqNoiseProcess = "5.20"
 Distributions = "0.25.107"
 DocStringExtensions = "0.9"

src/BPINN_ode.jl

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
 priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
 phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
-MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
+MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
 Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
 Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
 progress = false, verbose = false)
@@ -64,7 +64,7 @@ sol_lux_pestim = solve(prob, alg)
 
 Note that the solution is evaluated at fixed time points according to the strategy chosen.
 ensemble solution is evaluated and given at steps of `saveat`.
-Dataset should only be provided when ODE parameter Estimation is being done.
+Dataset should only be provided when ODE parameter Estimation is being done.
 The neural network is a fully continuous solution so `BPINNsolution`
 is an accurate interpolation (up to the neural network training result). In addition, the
 `BPINNstats` is returned as `sol.fullsolution` for further analysis.
@@ -170,7 +170,7 @@ struct BPINNsolution{O <: BPINNstats, E, NP, OP, P}
 end
 end
 
-function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
+function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
 alg::BNNODE,
 args...;
 dt = nothing,
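The user-facing workflow from the docstring above is unchanged by the `DiffEqBase` to `SciMLBase` swap. A minimal sketch, assuming an illustrative scalar ODE, network width, and sampler settings (none of which come from this commit):

```julia
# Sketch only: solve an ODEProblem with the Bayesian PINN solver documented above.
using NeuralPDE, Lux

f(u, p, t) = cos(2π * t)
prob = ODEProblem(f, 0.0, (0.0, 2.0))
chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))

alg = BNNODE(chain; draw_samples = 500, priorsNNw = (0.0, 2.0), progress = false)
sol = solve(prob, alg)

sol.fullsolution   # BPINNstats with the sampled chains, as described in the docstring
```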

src/NeuralPDE.jl

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,7 @@ module NeuralPDE
 
 using DocStringExtensions
 using Reexport, Statistics
-@reexport using DiffEqBase
+@reexport using SciMLBase
 @reexport using ModelingToolkit
 
 using Zygote, ForwardDiff, Random, Distributions
@@ -16,7 +16,6 @@ using Integrals, Cubature
 using QuasiMonteCarlo: LatinHypercubeSample
 import QuasiMonteCarlo
 using RuntimeGeneratedFunctions
-using SciMLBase
 using Statistics
 using ArrayInterface
 import Optim
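Since `SciMLBase` is now `@reexport`ed in place of `DiffEqBase` (making the separate `using SciMLBase` redundant), downstream scripts that only load NeuralPDE should keep seeing the problem types and `solve`. A hedged sketch of that assumption, with a throwaway problem:

```julia
# Assumes the SciMLBase reexport supplies the same user-facing names that previously
# arrived through DiffEqBase (ODEProblem, solve, ReturnCode, ...).
using NeuralPDE   # pulls in SciMLBase and ModelingToolkit via @reexport

prob = ODEProblem((u, p, t) -> 1.01 * u, 0.5, (0.0, 1.0))   # ODEProblem now resolves from SciMLBase
```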

src/advancedHMC_MCMC.jl

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
 Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
 }
 dim::Int
-prob::DiffEqBase.ODEProblem
+prob::SciMLBase.ODEProblem
 chain::C
 st::S
 strategy::ST
@@ -336,12 +336,12 @@ end
 
 """
 ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining,
-dataset = [nothing],init_params = nothing,
+dataset = [nothing],init_params = nothing,
 draw_samples = 1000, physdt = 1 / 20.0f0,l2std = [0.05],
 phystd = [0.05], priorsNNw = (0.0, 2.0),
 param = [], nchains = 1, autodiff = false, Kernel = HMC,
 Adaptorkwargs = (Adaptor = StanHMCAdaptor,
-Metric = DiagEuclideanMetric,
+Metric = DiagEuclideanMetric,
 targetacceptancerate = 0.8),
 Integratorkwargs = (Integrator = Leapfrog,),
 MCMCkwargs = (n_leapfrog = 30,),
@@ -431,7 +431,7 @@ Incase you are only solving the Equations for solution, do not provide dataset
 
 * AdvancedHMC.jl is still developing convenience structs so might need changes on new releases.
 """
-function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
+function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
 strategy = GridTraining, dataset = [nothing],
 init_params = nothing, draw_samples = 1000,
 physdt = 1 / 20.0, l2std = [0.05],
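Only the annotated problem type changes here (`SciMLBase.ODEProblem` instead of `DiffEqBase.ODEProblem`); the keyword interface in the docstring is untouched. A minimal sketch using only keywords from that docstring, with an illustrative problem and network:

```julia
# Sketch only: direct call to the sampler interface documented above. The return values
# are destructured as (MCMC chain, samples, sampler statistics), following the package tutorials.
using NeuralPDE, Lux

prob = ODEProblem((u, p, t) -> -u, 1.0, (0.0, 1.0))
chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))

mcmc_chain, samples, stats = ahmc_bayesian_pinn_ode(prob, chain;
    draw_samples = 1000, priorsNNw = (0.0, 2.0))
```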

src/dae_solve.jl

Lines changed: 8 additions & 6 deletions
@@ -31,7 +31,7 @@ of the physics-informed neural network which is used as a solver for a standard
 By default, `GridTraining` is used with `dt` if given.
 """
 struct NNDAE{C, O, P, K, S <: Union{Nothing, AbstractTrainingStrategy}
-} <: DiffEqBase.AbstractDAEAlgorithm
+} <: SciMLBase.AbstractDAEAlgorithm
 chain::C
 opt::O
 init_params::P
@@ -79,7 +79,7 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
 return loss
 end
 
-function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
+function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
 alg::NNDAE,
 args...;
 dt = nothing,
@@ -178,12 +178,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
 u = [phi(t, res.u) for t in ts]
 end
 
-sol = DiffEqBase.build_solution(prob, alg, ts, u;
+sol = SciMLBase.build_solution(prob, alg, ts, u;
 k = res, dense = true,
 calculate_error = false,
-retcode = ReturnCode.Success)
-DiffEqBase.has_analytic(prob.f) &&
-DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
+retcode = ReturnCode.Success,
+original = res,
+resid = res.objective)
+SciMLBase.has_analytic(prob.f) &&
+SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
 dense_errors = false)
 sol
 end
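The DAE path now also forwards the optimizer result (`original = res`) and its final objective (`resid = res.objective`) into `build_solution`, mirroring the ODE solver below. A minimal sketch, assuming an illustrative residual-form DAE and training setup:

```julia
# Sketch only: a small DAE solved with NNDAE; the point is the metadata now attached
# to the returned solution. Problem and hyperparameters are illustrative, not from this commit.
using NeuralPDE, Lux, OptimizationOptimisers

f(du, u, p, t) = [du[1] - cos(2π * t), u[2] - u[1]^2]   # u1 differential, u2 algebraic
prob = DAEProblem(f, [1.0, 0.0], [0.0, 0.0], (0.0, 1.0); differential_vars = [true, false])

chain = Lux.Chain(Lux.Dense(1, 15, tanh), Lux.Dense(15, 2))
alg = NNDAE(chain, OptimizationOptimisers.Adam(0.01))

sol = solve(prob, alg; dt = 1 / 100.0, maxiters = 1000)
sol.original.objective   # final training loss, also forwarded as the solution's `resid`
```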

src/ode_solve.jl

Lines changed: 12 additions & 10 deletions
@@ -1,4 +1,4 @@
-abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
+abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end
 
 """
 NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...)
@@ -14,10 +14,10 @@ of the physics-informed neural network which is used as a solver for a standard
 
 ## Positional Arguments
 
-* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
+* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
 `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`.
 * `opt`: The optimizer to train the neural network.
-* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
+* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
 which thus uses the random initialization provided by the neural network library.
 
 ## Keyword Arguments
@@ -28,8 +28,8 @@ of the physics-informed neural network which is used as a solver for a standard
 automatic differentiation (via Zygote), this is only for the derivative
 in the loss function (the derivative with respect to time).
 * `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values
-`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
-`false` means which means the application of the neural network is done at individual time points one at a time.
+`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
+`false` means which means the application of the neural network is done at individual time points one at a time.
 This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand.
 * `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
 * `strategy`: The training strategy used to choose the points for the evaluations.
@@ -339,7 +339,7 @@ end
 SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation"
 SciMLBase.allowscomplex(::NNODE) = true
 
-function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
+function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
 alg::NNODE,
 args...;
 dt = nothing,
@@ -479,13 +479,15 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
 u = [phi(t, res.u) for t in ts]
 end
 
-sol = DiffEqBase.build_solution(prob, alg, ts, u;
+sol = SciMLBase.build_solution(prob, alg, ts, u;
 k = res, dense = true,
 interp = NNODEInterpolation(phi, res.u),
 calculate_error = false,
-retcode = ReturnCode.Success)
-DiffEqBase.has_analytic(prob.f) &&
-DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
+retcode = ReturnCode.Success,
+original = res,
+resid = res.objective)
+SciMLBase.has_analytic(prob.f) &&
+SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
 dense_errors = false)
 sol
 end #solve
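This is the change the PR title refers to: `build_solution` now receives `original = res` and `resid = res.objective`, so the raw Optimization.jl result is carried on the returned `ODESolution` instead of being discarded. A minimal sketch, assuming an illustrative scalar ODE, network, and optimizer:

```julia
# Sketch only: solve a scalar ODE with NNODE and inspect the metadata added above.
using NeuralPDE, Lux, OptimizationOptimisers

prob = ODEProblem((u, p, t) -> cos(2π * t), 0.0, (0.0, 1.0))
chain = Lux.Chain(Lux.Dense(1, 16, tanh), Lux.Dense(16, 1))
alg = NNODE(chain, OptimizationOptimisers.Adam(0.05))

sol = solve(prob, alg; dt = 1 / 20.0, maxiters = 1000)

sol(0.3)                 # dense evaluation through NNODEInterpolation, as before
sol.original             # the Optimization.jl result `res`, newly tagged by this PR
sol.original.objective   # final training loss; the same value is passed as `resid`
```

Keeping `res` on the solution means training convergence can be checked after the fact without re-running the optimizer.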

src/rode_solve.jl

Lines changed: 9 additions & 9 deletions
@@ -20,7 +20,7 @@ function NNRODE(chain, W, opt = Optim.BFGS(), init_params = nothing; autodiff =
 NNRODE(chain, W, opt, init_params, autodiff, kwargs)
 end
 
-function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
+function SciMLBase.solve(prob::SciMLBase.AbstractRODEProblem,
 alg::NeuralPDEAlgorithm,
 args...;
 dt,
@@ -30,7 +30,7 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
 abstol = 1.0f-6,
 verbose = false,
 maxiters = 100)
-DiffEqBase.isinplace(prob) && error("Only out-of-place methods are allowed!")
+SciMLBase.isinplace(prob) && error("Only out-of-place methods are allowed!")
 
 u0 = prob.u0
 tspan = prob.tspan
@@ -52,24 +52,24 @@
 if u0 isa Number
 phi = (t, W, θ) -> u0 +
 (t - tspan[1]) *
-first(chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]),
+first(chain(adapt(SciMLBase.parameterless_type(θ), [t, W]),
 θ))
 else
 phi = (t, W, θ) -> u0 +
 (t - tspan[1]) *
-chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]), θ)
+chain(adapt(SciMLBase.parameterless_type(θ), [t, W]), θ)
 end
 else
 _, re = Flux.destructure(chain)
 #The phi trial solution
 if u0 isa Number
 phi = (t, W, θ) -> u0 +
 (t - t0) *
-first(re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W])))
+first(re(θ)(adapt(SciMLBase.parameterless_type(θ), [t, W])))
 else
 phi = (t, W, θ) -> u0 +
 (t - t0) *
-re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W]))
+re(θ)(adapt(SciMLBase.parameterless_type(θ), [t, W]))
 end
 end
 
@@ -108,9 +108,9 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
 u = [(phi(ts[i], W.W[i], res.minimizer)) for i in 1:length(ts)]
 end
 
-sol = DiffEqBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false)
-DiffEqBase.has_analytic(prob.f) &&
-DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
+sol = SciMLBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false)
+SciMLBase.has_analytic(prob.f) &&
+SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
 dense_errors = false)
 sol
 end #solve
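This file only swaps module prefixes; unlike the ODE and DAE paths, its `build_solution` call does not gain `original`/`resid`. The change also assumes `SciMLBase` provides the same `parameterless_type` helper that was previously reached through `DiffEqBase`. A rough sketch of the `adapt` pattern used in the trial solution above (the parameter vector is a stand-in):

```julia
# Sketch, assuming SciMLBase exposes `parameterless_type` as the code above relies on:
# strip the type parameters from θ's array type, then rebuild the [t, W] network input
# with `adapt` so it lives in the same array family as θ (e.g. a GPU array when θ is one).
using SciMLBase, Adapt

θ = Float32[0.1, 0.2, 0.3]              # stand-in for flattened network parameters
T = SciMLBase.parameterless_type(θ)     # e.g. Array, with its type parameters stripped
x = adapt(T, [0.5, 1.0])                # the [t, W] input, adapted to match θ's storage
```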
