Skip to content

Commit cccb273

Browse files
Merge pull request #161 from lxvm/pr_init
draft init interface
2 parents ef0edaf + 60e6d98 commit cccb273

File tree

2 files changed

+69
-34
lines changed

2 files changed

+69
-34
lines changed

src/Integrals.jl

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Reexport, MonteCarloIntegration, QuadGK, HCubature
88
@reexport using SciMLBase
99
using LinearAlgebra
1010

11+
include("common.jl")
1112
include("init.jl")
1213
include("algorithms.jl")
1314
include("infinity_handling.jl")
@@ -56,40 +57,6 @@ function checkkwargs(kwargs...)
5657
end
5758
return nothing
5859
end
59-
"""
60-
```julia
61-
solve(prob::IntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm; kwargs...)
62-
```
63-
64-
## Keyword Arguments
65-
66-
The arguments to `solve` are common across all of the quadrature methods.
67-
These common arguments are:
68-
69-
- `maxiters` (the maximum number of iterations)
70-
- `abstol` (absolute tolerance in changes of the objective value)
71-
- `reltol` (relative tolerance in changes of the objective value)
72-
"""
73-
function SciMLBase.solve(prob::IntegralProblem,
74-
alg::SciMLBase.AbstractIntegralAlgorithm;
75-
sensealg = ReCallVJP(ZygoteVJP()),
76-
do_inf_transformation = nothing, kwargs...)
77-
checkkwargs(kwargs...)
78-
prob = transformation_if_inf(prob, do_inf_transformation)
79-
__solvebp(prob, alg, sensealg, prob.lb, prob.ub, prob.p; kwargs...)
80-
end
81-
# Throw error if alg is not provided, as defaults are not implemented.
82-
function SciMLBase.solve(::IntegralProblem; kwargs...)
83-
checkkwargs(kwargs...)
84-
throw(ArgumentError("""
85-
No integration algorithm `alg` was supplied as the second positional argument.
86-
Reccomended integration algorithms are:
87-
For scalar functions: QuadGKJL()
88-
For ≤ 8 dimensional vector functions: HCubatureJL()
89-
For > 8 dimensional vector functions: MonteCarloIntegration.vegas(f, st, en, kwargs...)
90-
See the docstrings of the different algorithms for more detail.
91-
"""))
92-
end
9360

9461
# Give a layer to intercept with AD
9562
__solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...)

src/common.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
struct IntegralCache{P, A, S, K, Tc}
2+
prob::P
3+
alg::A
4+
sensealg::S
5+
kwargs::K
6+
# cache for algorithm goes here (currently unused)
7+
cacheval::Tc
8+
isfresh::Bool
9+
end
10+
11+
function SciMLBase.init(prob::IntegralProblem,
12+
alg::SciMLBase.AbstractIntegralAlgorithm;
13+
sensealg = ReCallVJP(ZygoteVJP()),
14+
do_inf_transformation = nothing, kwargs...)
15+
checkkwargs(kwargs...)
16+
prob = transformation_if_inf(prob, do_inf_transformation)
17+
cacheval = nothing
18+
isfresh = true
19+
20+
IntegralCache{typeof(prob),
21+
typeof(alg),
22+
typeof(sensealg),
23+
typeof(kwargs),
24+
typeof(cacheval)}(prob,
25+
alg,
26+
sensealg,
27+
kwargs,
28+
cacheval,
29+
isfresh)
30+
end
31+
32+
# Throw error if alg is not provided, as defaults are not implemented.
33+
function SciMLBase.solve(::IntegralProblem; kwargs...)
34+
checkkwargs(kwargs...)
35+
throw(ArgumentError("""
36+
No integration algorithm `alg` was supplied as the second positional argument.
37+
Reccomended integration algorithms are:
38+
For scalar functions: QuadGKJL()
39+
For ≤ 8 dimensional vector functions: HCubatureJL()
40+
For > 8 dimensional vector functions: MonteCarloIntegration.vegas(f, st, en, kwargs...)
41+
See the docstrings of the different algorithms for more detail.
42+
"""))
43+
end
44+
45+
"""
46+
```julia
47+
solve(prob::IntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm; kwargs...)
48+
```
49+
50+
## Keyword Arguments
51+
52+
The arguments to `solve` are common across all of the quadrature methods.
53+
These common arguments are:
54+
55+
- `maxiters` (the maximum number of iterations)
56+
- `abstol` (absolute tolerance in changes of the objective value)
57+
- `reltol` (relative tolerance in changes of the objective value)
58+
"""
59+
function SciMLBase.solve(prob::IntegralProblem,
60+
alg::SciMLBase.AbstractIntegralAlgorithm;
61+
kwargs...)
62+
solve!(init(prob, alg; kwargs...))
63+
end
64+
65+
function SciMLBase.solve!(cache::IntegralCache)
66+
prob = cache.prob
67+
__solvebp(prob, cache.alg, cache.sensealg, prob.lb, prob.ub, prob.p; cache.kwargs...)
68+
end

0 commit comments

Comments
 (0)