Skip to content

Commit 6640e38

Browse files
committed
ExaModels support
1 parent 2cbcd33 commit 6640e38

File tree

5 files changed

+53
-4
lines changed

5 files changed

+53
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ MadDiff implements forward and reverse mode implicit differentiation for MadSuit
1212
## NLPModels interface
1313

1414
> [!NOTE]
15-
> The [NLPModels](https://github.com/JuliaSmoothOptimizers/NLPModels.jl) interface requires that your `AbstractNLPModel` implementation includes the [`ParametricNLPModels`](https://github.com/klamike/ParametricNLPModels.jl/tree/mk/pnlpm) API. Currently, this is automated only for the case when using MadNLP through JuMP, but support for ExaModels, ADNLPModels, and NLPModelsJuMP is planned.
15+
> The [NLPModels](https://github.com/JuliaSmoothOptimizers/NLPModels.jl) interface requires that your `AbstractNLPModel` implementation includes the [`ParametricNLPModels`](https://github.com/klamike/ParametricNLPModels.jl/tree/mk/pnlpm) API. Currently, this is automated only for the case when using MadNLP through JuMP or when using ExaModels; support for other solvers and frameworks is planned.
1616
1717

1818
```julia

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ MadDiff implements forward and reverse mode implicit differentiation for MadSuit
1010
1111
## NLPModels interface
1212

13-
> The [NLPModels](https://github.com/JuliaSmoothOptimizers/NLPModels.jl) interface requires that your `AbstractNLPModel` implementation includes the [`ParametricNLPModels`](https://github.com/klamike/ParametricNLPModels.jl/tree/mk/pnlpm) API. Currently, this is automated only for the case when using MadNLP through JuMP, but support for ExaModels, ADNLPModels, and NLPModelsJuMP is planned.
13+
> The [NLPModels](https://github.com/JuliaSmoothOptimizers/NLPModels.jl) interface requires that your `AbstractNLPModel` implementation includes the [`ParametricNLPModels`](https://github.com/klamike/ParametricNLPModels.jl/tree/mk/pnlpm) API. Currently, this is automated only for the case when using MadNLP through JuMP or when using ExaModels; support for other solvers and frameworks is planned.
1414
1515

1616
```julia

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2222

2323
[sources]
2424
DiffOpt = {rev = "mk/allow_obj_and_sol", url = "https://github.com/klamike/DiffOpt.jl.git"}
25+
ExaModels = {rev = "mk/param_ad", url = "https://github.com/klamike/ExaModels.jl"}
2526
HybridKKT = {rev = "mk/latest", url = "https://github.com/klamike/HybridKKT.jl.git"}
2627
MadDiff = {path = ".."}
2728
MadIPM = {rev = "mk/flip", url = "https://github.com/klamike/MadIPM.jl.git"}

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test, Random, LinearAlgebra
22
using MadDiff
33
using MadNLP, MadIPM, MadNCL, HybridKKT
4-
using NLPModels, CUDA, MadNLPGPU, MadNLPTests, QuadraticModels
4+
using NLPModels, CUDA, MadNLPGPU, MadNLPTests, QuadraticModels, ExaModels
55
using JuMP, DiffOpt, MathOptInterface
66
const MOI = MathOptInterface
77

test/test_jacobian.jl

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NLPModels
1+
using NLPModels, LinearAlgebra, ExaModels
22

33
@testset "jac/jact/jvp/vjp" begin
44
function _check_consistency(sens; atol = 1e-8)
@@ -91,3 +91,51 @@ using NLPModels
9191
optimize!(model_rect)
9292
_check_consistency(MadDiff.MadDiffSolver(unsafe_backend(model_rect).inner.solver); atol = 1e-8)
9393
end
94+
95+
@testset "ExaModels JVP/VJP vs FiniteDiff" begin
96+
p0 = [1.0, 3.0]
97+
h = 1e-5
98+
atol = sqrt(h)
99+
100+
function _make_exa(p_vals)
101+
c = ExaCore()
102+
p = ExaModels.parameter(c, p_vals)
103+
x = ExaModels.variable(c, 2)
104+
ExaModels.objective(c, x[1]^2 + x[2]^2 + p[1] * x[1])
105+
ExaModels.constraint(c, x[1] + x[2] - p[2])
106+
return ExaModel(c)
107+
end
108+
109+
function _solve_exa(p_vals)
110+
m = _make_exa(p_vals)
111+
solver = MadNLP.MadNLPSolver(m; print_level = MadNLP.ERROR)
112+
MadNLP.solve!(solver)
113+
return solver
114+
end
115+
116+
solver0 = _solve_exa(p0)
117+
sens = MadDiffSolver(solver0)
118+
x0 = Vector(MadNLP.variable(solver0.x))
119+
y0 = Vector(solver0.y)
120+
n_p = sens.n_p
121+
122+
for j in 1:n_p
123+
Δp = zeros(n_p); Δp[j] = 1.0
124+
jvp = MadDiff.jacobian_vector_product!(sens, Δp)
125+
s_plus = _solve_exa(p0 .+ h .* Δp)
126+
@test isapprox(jvp.dx, (Vector(MadNLP.variable(s_plus.x)) .- x0) ./ h; atol)
127+
@test isapprox(jvp.dy, (Vector(s_plus.y) .- y0) ./ h; atol)
128+
end
129+
130+
rng = MersenneTwister(42)
131+
dL_dx = randn(rng, length(x0))
132+
dL_dy = randn(rng, length(y0))
133+
vjp = MadDiff.vector_jacobian_product!(sens; dL_dx, dL_dy)
134+
for j in 1:n_p
135+
Δp = zeros(n_p); Δp[j] = 1.0
136+
s_plus = _solve_exa(p0 .+ h .* Δp)
137+
dx_fd = (Vector(MadNLP.variable(s_plus.x)) .- x0) ./ h
138+
dy_fd = (Vector(s_plus.y) .- y0) ./ h
139+
@test isapprox(vjp.grad_p[j], dot(dL_dx, dx_fd) + dot(dL_dy, dy_fd); atol)
140+
end
141+
end

0 commit comments

Comments
 (0)