Skip to content

Commit 9af6239

Browse files
Merge pull request #197 from agerlach/divonne_options
Add missing Divonne options v2
2 parents 6cea3a8 + ba6119c commit 9af6239

File tree

8 files changed

+78
-71
lines changed

8 files changed

+78
-71
lines changed

ext/IntegralsCubaExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module IntegralsCubaExt
22

33
using Integrals, Cuba
4-
import Integrals: transformation_if_inf, scale_x, scale_x!, CubaVegas, AbstractCubaAlgorithm,
5-
CubaSUAVE, CubaDivonne, CubaCuhre
4+
import Integrals: transformation_if_inf,
5+
scale_x, scale_x!, CubaVegas, AbstractCubaAlgorithm,
6+
CubaSUAVE, CubaDivonne, CubaCuhre
67

78
function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgorithm,
89
sensealg,
@@ -116,4 +117,4 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori
116117
chi = out.probability, retcode = ReturnCode.Success)
117118
end
118119

119-
end
120+
end

ext/IntegralsCubatureExt.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,23 @@ module IntegralsCubatureExt
22

33
using Integrals, Cubature
44

5-
import Integrals: transformation_if_inf, scale_x, scale_x!, CubatureJLh, CubatureJLp,
6-
AbstractCubatureJLAlgorithm
5+
import Integrals: transformation_if_inf,
6+
scale_x, scale_x!, CubatureJLh, CubatureJLp,
7+
AbstractCubatureJLAlgorithm
78
import Cubature: INDIVIDUAL, PAIRED, L1, L2, LINF
89

910
Integrals.CubatureJLh(; error_norm = Cubature.INDIVIDUAL) = CubatureJLh(error_norm)
1011
Integrals.CubatureJLp(; error_norm = Cubature.INDIVIDUAL) = CubatureJLp(error_norm)
1112

1213
function Integrals.__solvebp_call(prob::IntegralProblem,
13-
alg::AbstractCubatureJLAlgorithm,
14-
sensealg, domain, p;
15-
reltol = 1e-8, abstol = 1e-8,
16-
maxiters = typemax(Int))
17-
14+
alg::AbstractCubatureJLAlgorithm,
15+
sensealg, domain, p;
16+
reltol = 1e-8, abstol = 1e-8,
17+
maxiters = typemax(Int))
1818
lb, ub = domain
1919
mid = (lb + ub) / 2
2020

21-
# we get to pick fdim or not based on the IntegralFunction and its output dimensions
21+
# we get to pick fdim or not based on the IntegralFunction and its output dimensions
2222
y = if prob.f isa BatchIntegralFunction
2323
isinplace(prob.f) ? prob.f.integrand_prototype :
2424
mid isa Number ? prob.f(eltype(mid)[], p) :
@@ -176,12 +176,12 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
176176
if prob.batch == 0
177177
if isinplace(prob)
178178
dx = zeros(eltype(lb), prob.nout)
179-
@@ -181,6 +334,7 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
179+
@@ -181,6 +334,7 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
180180
end
181181
end
182182
end
183183
=#
184184
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
185185
end
186186

187-
end
187+
end

ext/IntegralsForwardDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ function Integrals.__solvebp(cache, alg, sensealg, domain,
8585
res = reinterpret(reshape, DT, dual.u)
8686
# unwrap the dual when the primal would return a scalar
8787
out = if (cache.f isa BatchIntegralFunction && y isa AbstractVector) ||
88-
!(y isa AbstractArray)
88+
!(y isa AbstractArray)
8989
only(res)
9090
else
9191
res

src/Integrals.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
169169
out = vegas(f, lb, ub, rtol = reltol, atol = abstol,
170170
maxiter = maxiters, nbins = alg.nbins, debug = alg.debug,
171171
ncalls = ncalls, batch = prob.f isa BatchIntegralFunction)
172-
val, err, chi = out isa Tuple ? out : (out.integral_estimate, out.standard_deviation, out.chi_squared_average)
172+
val, err, chi = out isa Tuple ? out :
173+
(out.integral_estimate, out.standard_deviation, out.chi_squared_average)
173174
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
174175
end
175176

src/algorithms.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ year={1981},
246246
publisher={ACM New York, NY, USA}
247247
}
248248
"""
249-
struct CubaDivonne{R1, R2, R3} <:
250-
AbstractCubaAlgorithm where {R1 <: Real, R2 <: Real, R3 <: Real}
249+
struct CubaDivonne{R1, R2, R3, R4} <:
250+
AbstractCubaAlgorithm where {R1 <: Real, R2 <: Real, R3 <: Real, R4 <: Real}
251251
flags::Int
252252
seed::Int
253253
minevals::Int
@@ -258,6 +258,9 @@ struct CubaDivonne{R1, R2, R3} <:
258258
border::R1
259259
maxchisq::R2
260260
mindeviation::R3
261+
xgiven::Matrix{R4}
262+
nextra::Int
263+
peakfinder::Ptr{Cvoid}
261264
end
262265
"""
263266
CubaCuhre()
@@ -293,9 +296,11 @@ function CubaSUAVE(; flags = 0, seed = 0, minevals = 0, nnew = 1000, nmin = 2,
293296
end
294297
function CubaDivonne(; flags = 0, seed = 0, minevals = 0,
295298
key1 = 47, key2 = 1, key3 = 1, maxpass = 5, border = 0.0,
296-
maxchisq = 10.0, mindeviation = 0.25)
299+
maxchisq = 10.0, mindeviation = 0.25,
300+
xgiven = zeros(Cdouble, 0, 0),
301+
nextra = 0, peakfinder = C_NULL)
297302
CubaDivonne(flags, seed, minevals, key1, key2, key3, maxpass, border, maxchisq,
298-
mindeviation)
303+
mindeviation, xgiven, nextra, peakfinder)
299304
end
300305
CubaCuhre(; flags = 0, minevals = 0, key = 0) = CubaCuhre(flags, minevals, key)
301306

@@ -325,7 +330,6 @@ struct CubatureJLh <: AbstractCubatureJLAlgorithm
325330
error_norm::Int32
326331
end
327332

328-
329333
"""
330334
CubatureJLp()
331335

test/interface_tests.jl

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,26 @@ max_nout_test = 2
77
reltol = 1e-3
88
abstol = 1e-3
99

10-
algs = [QuadGKJL(), HCubatureJL(), CubatureJLh(), CubatureJLp(), VEGAS(), #CubaVegas(),
11-
CubaSUAVE(), CubaDivonne(), CubaCuhre()]
10+
algs = [QuadGKJL, HCubatureJL, CubatureJLh, CubatureJLp, VEGAS, #CubaVegas,
11+
CubaSUAVE, CubaDivonne, CubaCuhre]
1212

13-
alg_req = Dict(QuadGKJL() => (nout = 1, allows_batch = false, min_dim = 1, max_dim = 1,
13+
alg_req = Dict(QuadGKJL => (nout = 1, allows_batch = false, min_dim = 1, max_dim = 1,
1414
allows_iip = false),
15-
HCubatureJL() => (nout = Inf, allows_batch = false, min_dim = 1,
15+
HCubatureJL => (nout = Inf, allows_batch = false, min_dim = 1,
1616
max_dim = Inf, allows_iip = true),
17-
VEGAS() => (nout = 1, allows_batch = true, min_dim = 2, max_dim = Inf,
17+
VEGAS => (nout = 1, allows_batch = true, min_dim = 2, max_dim = Inf,
1818
allows_iip = true),
19-
CubatureJLh() => (nout = Inf, allows_batch = true, min_dim = 1,
19+
CubatureJLh => (nout = Inf, allows_batch = true, min_dim = 1,
2020
max_dim = Inf, allows_iip = true),
21-
CubatureJLp() => (nout = Inf, allows_batch = true, min_dim = 1,
21+
CubatureJLp => (nout = Inf, allows_batch = true, min_dim = 1,
2222
max_dim = Inf, allows_iip = true),
23-
CubaVegas() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf,
23+
CubaVegas => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf,
2424
allows_iip = true),
25-
CubaSUAVE() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf,
25+
CubaSUAVE => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf,
2626
allows_iip = true),
27-
CubaDivonne() => (nout = Inf, allows_batch = true, min_dim = 2,
27+
CubaDivonne => (nout = Inf, allows_batch = true, min_dim = 2,
2828
max_dim = Inf, allows_iip = true),
29-
CubaCuhre() => (nout = Inf, allows_batch = true, min_dim = 2, max_dim = Inf,
29+
CubaCuhre => (nout = Inf, allows_batch = true, min_dim = 2, max_dim = Inf,
3030
allows_iip = true))
3131

3232
integrands = [
@@ -96,7 +96,7 @@ end
9696
for i in 1:length(integrands)
9797
prob = IntegralProblem(integrands[i], lb, ub)
9898
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
99-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
99+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
100100
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
101101
end
102102
end
@@ -110,11 +110,11 @@ end
110110
for dim in 1:max_dim_test
111111
lb, ub = (ones(dim), 3ones(dim))
112112
prob = IntegralProblem(integrands[i], lb, ub)
113-
if dim > req.max_dim || dim < req.min_dim || alg isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
113+
if dim > req.max_dim || dim < req.min_dim || alg() isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
114114
continue
115115
end
116116
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
117-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
117+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
118118
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
119119
end
120120
end
@@ -129,14 +129,14 @@ end
129129
for dim in 1:max_dim_test
130130
lb, ub = (ones(dim), 3ones(dim))
131131
prob = IntegralProblem(iip_integrands[i], lb, ub)
132-
if dim > req.max_dim || dim < req.min_dim || alg isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
132+
if dim > req.max_dim || dim < req.min_dim || alg() isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
133133
continue
134134
end
135135
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
136-
if alg isa HCubatureJL && dim == 1 # HCubature library requires finer tol to pass test. When requiring array outputs for iip integrands
137-
sol = solve(prob, alg, reltol = 1e-5, abstol = 1e-5)
136+
if alg() isa HCubatureJL && dim == 1 # HCubature library requires finer tol to pass test. When requiring array outputs for iip integrands
137+
sol = solve(prob, alg(), reltol = 1e-5, abstol = 1e-5)
138138
else
139-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
139+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
140140
end
141141
if sol.u isa Number
142142
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
@@ -159,7 +159,7 @@ end
159159
continue
160160
end
161161
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
162-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
162+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
163163
@test sol.u[1]exact_sol[i](dim, nout, lb, ub) rtol=1e-2
164164
end
165165
end
@@ -177,7 +177,7 @@ end
177177
continue
178178
end
179179
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
180-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
180+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
181181
if sol.u isa Number
182182
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
183183
else
@@ -201,7 +201,7 @@ end
201201
continue
202202
end
203203
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
204-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
204+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
205205
if sol.u isa Number
206206
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
207207
else
@@ -226,7 +226,7 @@ end
226226
continue
227227
end
228228
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
229-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
229+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
230230
if nout == 1
231231
@test sol.u[1]exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2
232232
else
@@ -245,14 +245,14 @@ end
245245
lb, ub = (ones(dim), 3ones(dim))
246246
for nout in 1:max_nout_test
247247
if dim > req.max_dim || dim < req.min_dim || req.nout < nout ||
248-
alg isa QuadGKJL || alg isa VEGAS
248+
alg() isa QuadGKJL || alg() isa VEGAS
249249
#QuadGKJL and VEGAS require numbers, not single element arrays
250250
continue
251251
end
252252
prob = IntegralProblem((x, p) -> integrands_v[i](x, p, nout), lb, ub,
253253
nout = nout)
254254
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
255-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
255+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
256256
if nout == 1
257257
@test sol.u[1]exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2
258258
else
@@ -274,14 +274,14 @@ end
274274
prob = IntegralProblem((dx, x, p) -> iip_integrands_v[i](dx, x, p, nout),
275275
lb, ub, nout = nout)
276276
if dim > req.max_dim || dim < req.min_dim || req.nout < nout ||
277-
alg isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
277+
alg() isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
278278
continue
279279
end
280280
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
281281
if alg isa HCubatureJL && dim == 1 # HCubature library requires finer tol to pass test. When requiring array outputs for iip integrands
282-
sol = solve(prob, alg, reltol = 1e-5, abstol = 1e-5)
282+
sol = solve(prob, alg(), reltol = 1e-5, abstol = 1e-5)
283283
else
284-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
284+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
285285
end
286286
if nout == 1
287287
@test sol.u[1]exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2
@@ -306,7 +306,7 @@ end
306306
continue
307307
end
308308
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
309-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
309+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
310310
@test sol.uexact_sol_v[i](dim, nout, lb, ub) rtol=1e-2
311311
end
312312
end
@@ -327,7 +327,7 @@ end
327327
continue
328328
end
329329
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
330-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
330+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
331331
@test sol.uexact_sol_v[i](dim, nout, lb, ub) rtol=1e-2
332332
end
333333
end
@@ -348,7 +348,7 @@ end
348348
continue
349349
end
350350
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
351-
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
351+
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
352352
@test sol.uexact_sol_v[i](dim, nout, lb, ub) rtol=1e-2
353353
end
354354
end
@@ -375,7 +375,7 @@ end
375375
end
376376
for i in 1:length(integrands)
377377
prob = IntegralProblem(integrands[i], lb, ub, p)
378-
cache = init(prob, alg, reltol = reltol, abstol = abstol)
378+
cache = init(prob, alg(), reltol = reltol, abstol = abstol)
379379
@test solve!(cache).uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
380380
cache.lb = lb = 0.5
381381
@test solve!(cache).uexact_sol[i](dim, nout, lb, ub) rtol=1e-2

test/nested_ad_tests.jl

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,42 +5,43 @@ my_function(x, p) = x^2 + p[1]^3 * x + p[2]^2
55
function my_integration(p)
66
my_problem = IntegralProblem(my_function, -1.0, 1.0, p)
77
# return solve(my_problem, HCubatureJL(), reltol=1e-3, abstol=1e-3) # Works
8-
return solve(my_problem, CubatureJLh(), reltol=1e-3, abstol=1e-3) # Errors
8+
return solve(my_problem, CubatureJLh(), reltol = 1e-3, abstol = 1e-3) # Errors
99
end
1010
my_solution = my_integration(my_parameters)
1111
@test ForwardDiff.jacobian(my_integration, my_parameters) == [0.0 8.0]
12-
@test ForwardDiff.jacobian(x->ForwardDiff.jacobian(my_integration, x), my_parameters) == [0.0 0.0
13-
0.0 4.0]
12+
@test ForwardDiff.jacobian(x -> ForwardDiff.jacobian(my_integration, x), my_parameters) ==
13+
[0.0 0.0
14+
0.0 4.0]
1415

15-
ff(x,p) = sum(sin.(x .* p))
16+
ff(x, p) = sum(sin.(x .* p))
1617
lb = ones(2)
1718
ub = 3ones(2)
18-
p = [1.5,2.0]
19+
p = [1.5, 2.0]
1920

2021
function testf(p)
21-
prob = IntegralProblem(ff,lb,ub,p)
22-
sin(solve(prob,CubaCuhre(),reltol=1e-6,abstol=1e-6)[1])
22+
prob = IntegralProblem(ff, lb, ub, p)
23+
sin(solve(prob, CubaCuhre(), reltol = 1e-6, abstol = 1e-6)[1])
2324
end
2425

25-
hp1 = FiniteDiff.finite_difference_hessian(testf,p)
26-
hp2 = ForwardDiff.hessian(testf,p)
27-
@test hp1 hp2 atol=1e-4
26+
hp1 = FiniteDiff.finite_difference_hessian(testf, p)
27+
hp2 = ForwardDiff.hessian(testf, p)
28+
@test hp1hp2 atol=1e-4
2829

29-
ff2(x,p) = x*p[1].+p[2]*p[3]
30-
lb =1.0
30+
ff2(x, p) = x * p[1] .+ p[2] * p[3]
31+
lb = 1.0
3132
ub = 3.0
3233
p = [2.0, 3.0, 4.0]
3334
_ff3 = BatchIntegralFunction(ff2)
34-
prob = IntegralProblem(_ff3,lb,ub,p)
35+
prob = IntegralProblem(_ff3, lb, ub, p)
3536

36-
function testf3(lb,ub,p; f=_ff3)
37-
prob = IntegralProblem(_ff3,lb,ub,p)
38-
solve(prob, CubatureJLh(); reltol=1e-3,abstol=1e-3)[1]
37+
function testf3(lb, ub, p; f = _ff3)
38+
prob = IntegralProblem(_ff3, lb, ub, p)
39+
solve(prob, CubatureJLh(); reltol = 1e-3, abstol = 1e-3)[1]
3940
end
4041

41-
dp1 = ForwardDiff.gradient(p->testf3(lb,ub,p),p)
42-
dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
43-
dp3 = FiniteDiff.finite_difference_gradient(p->testf3(lb,ub,p),p)
42+
dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
43+
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
44+
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)
4445

4546
@test dp1 dp3
4647
@test dp2 dp3

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ using Test
88
@time @safetestset "Gaussian Quadrature Tests" include("gaussian_quadrature_tests.jl")
99
@time @safetestset "Sampled Integration Tests" include("sampled_tests.jl")
1010
@time @safetestset "QuadratureFunction Tests" include("quadrule_tests.jl")
11-
@time @safetestset "Nested AD Tests" include("nested_ad_tests.jl")
11+
@time @safetestset "Nested AD Tests" include("nested_ad_tests.jl")

0 commit comments

Comments
 (0)