|
1 | | -using NLPModels |
using NLPModels, LinearAlgebra, ExaModels, Random
2 | 2 |
|
3 | 3 | @testset "jac/jact/jvp/vjp" begin |
4 | 4 | function _check_consistency(sens; atol = 1e-8) |
@@ -91,3 +91,51 @@ using NLPModels |
91 | 91 | optimize!(model_rect) |
92 | 92 | _check_consistency(MadDiff.MadDiffSolver(unsafe_backend(model_rect).inner.solver); atol = 1e-8) |
93 | 93 | end |
| 94 | + |
| 95 | +@testset "ExaModels JVP/VJP vs FiniteDiff" begin |
| 96 | + p0 = [1.0, 3.0] |
| 97 | + h = 1e-5 |
| 98 | + atol = sqrt(h) |
| 99 | + |
| 100 | + function _make_exa(p_vals) |
| 101 | + c = ExaCore() |
| 102 | + p = ExaModels.parameter(c, p_vals) |
| 103 | + x = ExaModels.variable(c, 2) |
| 104 | + ExaModels.objective(c, x[1]^2 + x[2]^2 + p[1] * x[1]) |
| 105 | + ExaModels.constraint(c, x[1] + x[2] - p[2]) |
| 106 | + return ExaModel(c) |
| 107 | + end |
| 108 | + |
| 109 | + function _solve_exa(p_vals) |
| 110 | + m = _make_exa(p_vals) |
| 111 | + solver = MadNLP.MadNLPSolver(m; print_level = MadNLP.ERROR) |
| 112 | + MadNLP.solve!(solver) |
| 113 | + return solver |
| 114 | + end |
| 115 | + |
| 116 | + solver0 = _solve_exa(p0) |
| 117 | + sens = MadDiffSolver(solver0) |
| 118 | + x0 = Vector(MadNLP.variable(solver0.x)) |
| 119 | + y0 = Vector(solver0.y) |
| 120 | + n_p = sens.n_p |
| 121 | + |
| 122 | + for j in 1:n_p |
| 123 | + Δp = zeros(n_p); Δp[j] = 1.0 |
| 124 | + jvp = MadDiff.jacobian_vector_product!(sens, Δp) |
| 125 | + s_plus = _solve_exa(p0 .+ h .* Δp) |
| 126 | + @test isapprox(jvp.dx, (Vector(MadNLP.variable(s_plus.x)) .- x0) ./ h; atol) |
| 127 | + @test isapprox(jvp.dy, (Vector(s_plus.y) .- y0) ./ h; atol) |
| 128 | + end |
| 129 | + |
| 130 | + rng = MersenneTwister(42) |
| 131 | + dL_dx = randn(rng, length(x0)) |
| 132 | + dL_dy = randn(rng, length(y0)) |
| 133 | + vjp = MadDiff.vector_jacobian_product!(sens; dL_dx, dL_dy) |
| 134 | + for j in 1:n_p |
| 135 | + Δp = zeros(n_p); Δp[j] = 1.0 |
| 136 | + s_plus = _solve_exa(p0 .+ h .* Δp) |
| 137 | + dx_fd = (Vector(MadNLP.variable(s_plus.x)) .- x0) ./ h |
| 138 | + dy_fd = (Vector(s_plus.y) .- y0) ./ h |
| 139 | + @test isapprox(vjp.grad_p[j], dot(dL_dx, dx_fd) + dot(dL_dy, dy_fd); atol) |
| 140 | + end |
| 141 | +end |
0 commit comments