Skip to content

Commit 6b46c42

Browse files
Merge pull request #168 from lxvm/cache_quadgk
Cache for QuadGKJL
2 parents 646abf1 + 28a7789 commit 6b46c42

16 files changed

+415
-258
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Distributions = "0.23, 0.24, 0.25"
2020
ForwardDiff = "0.10"
2121
HCubature = "1.4"
2222
MonteCarloIntegration = "0.0.1, 0.0.2, 0.0.3"
23-
QuadGK = "2.1"
23+
QuadGK = "2.5"
2424
Reexport = "0.2, 1.0"
2525
Requires = "1"
2626
SciMLBase = "1.70"

docs/make.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
66
include("pages.jl")
77

88
makedocs(sitename = "Integrals.jl",
9-
authors = "Chris Rackauckas",
10-
modules = [Integrals, Integrals.SciMLBase],
11-
clean = true, doctest = false,
12-
strict = [
13-
:doctest,
14-
:linkcheck,
15-
:parse_error,
16-
:example_block,
17-
# Other available options are
18-
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
19-
],
20-
format = Documenter.HTML(analytics = "UA-90474609-3",
21-
assets = ["assets/favicon.ico"],
22-
canonical = "https://docs.sciml.ai/Integrals/stable/"),
23-
pages = pages)
9+
authors = "Chris Rackauckas",
10+
modules = [Integrals, Integrals.SciMLBase],
11+
clean = true, doctest = false,
12+
strict = [
13+
:doctest,
14+
:linkcheck,
15+
:parse_error,
16+
:example_block,
17+
# Other available options are
18+
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
19+
],
20+
format = Documenter.HTML(analytics = "UA-90474609-3",
21+
assets = ["assets/favicon.ico"],
22+
canonical = "https://docs.sciml.ai/Integrals/stable/"),
23+
pages = pages)
2424

2525
deploydocs(repo = "github.com/SciML/Integrals.jl.git";
26-
push_preview = true)
26+
push_preview = true)

docs/pages.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
pages = ["index.md",
22
"Tutorials" => Any["tutorials/numerical_integrals.md",
3-
"tutorials/differentiating_integrals.md"],
3+
"tutorials/differentiating_integrals.md"],
44
"Basics" => Any["basics/IntegralProblem.md",
5-
"basics/solve.md",
6-
"basics/FAQ.md"],
5+
"basics/solve.md",
6+
"basics/FAQ.md"],
77
"Solvers" => Any["solvers/IntegralSolvers.md"],
88
]

docs/src/solvers/IntegralSolvers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The following algorithms are available:
77
- `VEGAS`: Uses MonteCarloIntegration.jl. Requires `nout=1`. Works only for `>1`-dimensional integrations.
88
- `CubatureJLh`: h-Cubature from Cubature.jl. Requires `using IntegralsCubature`.
99
- `CubatureJLp`: p-Cubature from Cubature.jl. Requires `using IntegralsCubature`.
10-
- `CubaVegas`: Vegas from Cuba.jl. Requires `using IntegralsCuba`, `nout=1`.
10+
- `CubaVegas`: Vegas from Cuba.jl. Requires `using IntegralsCuba`, `nout=1`.
1111
- `CubaSUAVE`: SUAVE from Cuba.jl. Requires `using IntegralsCuba`.
1212
- `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations.
1313
- `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Integrals with Caching Interface
2+
3+
Often, integral solvers allocate memory or reuse quadrature rules for solving different
4+
problems. For example, if one is going to solve the same integral for several parameters
5+
6+
```julia
7+
using Integrals
8+
9+
prob = IntegralProblem((x, p) -> sin(x * p), 0, 1, 14.0)
10+
alg = QuadGKJL()
11+
12+
solve(prob, alg)
13+
14+
prob = remake(prob, p = 15.0)
15+
solve(prob, alg)
16+
```
17+
18+
then it would be more efficient to allocate the heap used by `quadgk` across several calls,
19+
shown below by directly calling the library
20+
21+
```julia
22+
using QuadGK
23+
segbuf = QuadGK.alloc_segbuf()
24+
quadgk(x -> sin(14x), 0, 1, segbuf = segbuf)
25+
quadgk(x -> sin(15x), 0, 1, segbuf = segbuf)
26+
```
27+
28+
Integrals.jl's caching interface automates this process to reuse resources if an algorithm
29+
supports it and if the necessary types to build the cache can be inferred from `prob`. To do
30+
this with Integrals.jl, you simply `init` a cache, `solve!`, replace `p`, and solve again.
31+
This uses the [SciML `init` interface](https://docs.sciml.ai/SciMLBase/stable/interfaces/Init_Solve/#init-and-the-Iterator-Interface)
32+
33+
```@example cache1
34+
using Integrals
35+
36+
prob = IntegralProblem((x, p) -> sin(x * p), 0, 1, 14.0)
37+
alg = QuadGKJL()
38+
39+
cache = init(prob, alg)
40+
sol1 = solve!(cache)
41+
```
42+
43+
```@example cache1
44+
cache.p = 15.0
45+
sol2 = solve!(cache)
46+
```
47+
48+
The caching interface is intended for updating `p`, `lb`, `ub`, `nout`, and `batch`.
49+
Note that the types of these variables is not allowed to change.
50+
If it is necessary to change the integrand `f` instead of defining a new
51+
`IntegralProblem`, consider using
52+
[FunctionWrappers.jl](https://github.com/yuyichao/FunctionWrappers.jl).

ext/IntegralsFastGaussQuadratureExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,20 @@ function composite_gauss_legendre(f, p, lb, ub, nodes, weights, subintervals)
3030
end
3131

3232
function Integrals.__solvebp_call(prob::IntegralProblem, alg::Integrals.GaussLegendre{C},
33-
sensealg, lb, ub, p;
34-
reltol = nothing, abstol = nothing,
35-
maxiters = nothing) where {C}
33+
sensealg, lb, ub, p;
34+
reltol = nothing, abstol = nothing,
35+
maxiters = nothing) where {C}
3636
if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray
3737
error("GaussLegendre only accepts one-dimensional quadrature problems.")
3838
end
3939
@assert prob.batch == 0
4040
@assert prob.nout == 1
4141
if C
4242
val = composite_gauss_legendre(prob.f, prob.p, lb, ub,
43-
alg.nodes, alg.weights, alg.subintervals)
43+
alg.nodes, alg.weights, alg.subintervals)
4444
else
4545
val = gauss_legendre(prob.f, prob.p, lb, ub,
46-
alg.nodes, alg.weights)
46+
alg.nodes, alg.weights)
4747
end
4848
err = nothing
4949
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)

ext/IntegralsForwardDiffExt.jl

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,36 @@ isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
44
### Forward-Mode AD Intercepts
55

66
# Direct AD on solvers with QuadGK and HCubature
7-
function Integrals.__solvebp(prob, alg::QuadGKJL, sensealg, lb, ub,
8-
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
9-
kwargs...) where {T, V, P, N}
10-
Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
7+
function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, lb, ub,
8+
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
9+
kwargs...) where {T, V, P, N}
10+
Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)
1111
end
1212

13-
function Integrals.__solvebp(prob, alg::HCubatureJL, sensealg, lb, ub,
14-
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
15-
kwargs...) where {T, V, P, N}
16-
Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
13+
function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, lb, ub,
14+
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
15+
kwargs...) where {T, V, P, N}
16+
Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)
1717
end
1818

1919
# Manually split for the pushforward
20-
function Integrals.__solvebp(prob, alg, sensealg, lb, ub,
21-
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
22-
kwargs...) where {T, V, P, N}
23-
primal = Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, ForwardDiff.value.(p);
24-
kwargs...)
20+
function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
21+
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
22+
kwargs...) where {T, V, P, N}
23+
primal = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, ForwardDiff.value.(p);
24+
kwargs...)
2525

26-
nout = prob.nout * P
26+
nout = cache.nout * P
2727

28-
if isinplace(prob)
28+
if isinplace(cache)
2929
dfdp = function (out, x, p)
3030
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
31-
if prob.batch > 0
32-
dx = similar(dualp, prob.nout, size(x, 2))
31+
if cache.batch > 0
32+
dx = similar(dualp, cache.nout, size(x, 2))
3333
else
34-
dx = similar(dualp, prob.nout)
34+
dx = similar(dualp, cache.nout)
3535
end
36-
prob.f(dx, x, dualp)
36+
cache.f(dx, x, dualp)
3737

3838
ys = reinterpret(ForwardDiff.Dual{T, V, P}, dx)
3939
idx = 0
@@ -47,8 +47,8 @@ function Integrals.__solvebp(prob, alg, sensealg, lb, ub,
4747
else
4848
dfdp = function (x, p)
4949
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
50-
ys = prob.f(x, dualp)
51-
if prob.batch > 0
50+
ys = cache.f(x, dualp)
51+
if cache.batch > 0
5252
out = similar(p, V, nout, size(x, 2))
5353
else
5454
out = similar(p, V, nout)
@@ -64,12 +64,20 @@ function Integrals.__solvebp(prob, alg, sensealg, lb, ub,
6464
return out
6565
end
6666
end
67+
6768
rawp = copy(reinterpret(V, p))
6869

69-
dp_prob = IntegralProblem(dfdp, lb, ub, rawp; nout = nout, batch = prob.batch,
70-
kwargs...)
71-
dual = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, rawp; kwargs...)
72-
res = similar(p, prob.nout)
70+
prob = Integrals.build_problem(cache)
71+
dp_prob = remake(prob, f = dfdp, nout = nout, p = rawp)
72+
# the infinity transformation was already applied to f so we don't apply it to dfdp
73+
dp_cache = init(dp_prob,
74+
alg;
75+
sensealg = sensealg,
76+
do_inf_transformation = Val(false),
77+
cache.kwargs...)
78+
dual = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, rawp; kwargs...)
79+
80+
res = similar(p, cache.nout)
7381
partials = reinterpret(typeof(first(res).partials), dual.u)
7482
for idx in eachindex(res)
7583
res[idx] = ForwardDiff.Dual{T, V, P}(primal.u[idx], partials[idx])

ext/IntegralsZygoteExt.jl

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,21 @@ else
1111
end
1212
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)
1313

14-
function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg, lb, ub, p;
15-
kwargs...)
16-
out = Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
14+
function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub,
15+
p;
16+
kwargs...)
17+
out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)
18+
1719
function quadrature_adjoint(Δ)
1820
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ
19-
if isinplace(prob)
20-
dx = zeros(prob.nout)
21-
_f = x -> prob.f(dx, x, p)
21+
if isinplace(cache)
22+
dx = zeros(cache.nout)
23+
_f = x -> cache.f(dx, x, p)
2224
if sensealg.vjp isa Integrals.ZygoteVJP
2325
dfdp = function (dx, x, p)
2426
_, back = Zygote.pullback(p) do p
25-
_dx = Zygote.Buffer(x, prob.nout, size(x, 2))
26-
prob.f(_dx, x, p)
27+
_dx = Zygote.Buffer(x, cache.nout, size(x, 2))
28+
cache.f(_dx, x, p)
2729
copy(_dx)
2830
end
2931

@@ -38,11 +40,11 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg
3840
error("TODO")
3941
end
4042
else
41-
_f = x -> prob.f(x, p)
43+
_f = x -> cache.f(x, p)
4244
if sensealg.vjp isa Integrals.ZygoteVJP
43-
if prob.batch > 0
45+
if cache.batch > 0
4446
dfdp = function (x, p)
45-
_, back = Zygote.pullback(p -> prob.f(x, p), p)
47+
_, back = Zygote.pullback(p -> cache.f(x, p), p)
4648

4749
out = zeros(length(p), size(x, 2))
4850
z = zeros(size(x, 2))
@@ -55,7 +57,7 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg
5557
end
5658
else
5759
dfdp = function (x, p)
58-
_, back = Zygote.pullback(p -> prob.f(x, p), p)
60+
_, back = Zygote.pullback(p -> cache.f(x, p), p)
5961
back(y)[1]
6062
end
6163
end
@@ -65,12 +67,19 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg
6567
end
6668
end
6769

68-
dp_prob = remake(prob, f = dfdp, lb = lb, ub = ub, p = p, nout = length(p))
70+
prob = Integrals.build_problem(cache)
71+
dp_prob = remake(prob, f = dfdp, nout = length(p))
72+
# the infinity transformation was already applied to f so we don't apply it to dfdp
73+
dp_cache = init(dp_prob,
74+
alg;
75+
sensealg = sensealg,
76+
do_inf_transformation = Val(false),
77+
cache.kwargs...)
6978

7079
if p isa Number
71-
dp = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, p; kwargs...)[1]
80+
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1]
7281
else
73-
dp = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, p; kwargs...).u
82+
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u
7483
end
7584

7685
if lb isa Number
@@ -79,14 +88,14 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg
7988
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp)
8089
else
8190
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(),
82-
NoTangent(), dp)
91+
NoTangent(), dp)
8392
end
8493
end
8594
out, quadrature_adjoint
8695
end
8796

8897
Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution,
89-
::Val{:u})
98+
::Val{:u})
9099
sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),)
91100
end
92101
end

0 commit comments

Comments
 (0)