diff --git a/lib/OptimizationAuglag/Project.toml b/lib/OptimizationAuglag/Project.toml new file mode 100644 index 000000000..3f01a0e7d --- /dev/null +++ b/lib/OptimizationAuglag/Project.toml @@ -0,0 +1,20 @@ +name = "OptimizationAuglag" +uuid = "2ea93f80-9333-43a1-a68d-1f53b957a421" +authors = ["paramthakkar123 "] +version = "0.1.0" + +[deps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" +OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +ForwardDiff = "1.0.1" +MLUtils = "0.4.8" +Optimization = "4.4.0" +OptimizationBase = "2.10.0" +OptimizationOptimisers = "0.3.8" +Test = "1.11.0" diff --git a/src/auglag.jl b/lib/OptimizationAuglag/src/OptimizationAuglag.jl similarity index 94% rename from src/auglag.jl rename to lib/OptimizationAuglag/src/OptimizationAuglag.jl index a9fd3981f..43ae86ad2 100644 --- a/src/auglag.jl +++ b/lib/OptimizationAuglag/src/OptimizationAuglag.jl @@ -1,3 +1,9 @@ +module OptimizationAuglag + +using Optimization +using OptimizationBase.SciMLBase: OptimizationProblem, OptimizationFunction, OptimizationStats +using OptimizationBase.LinearAlgebra: norm + @kwdef struct AugLag inner::Any τ = 0.5 @@ -15,7 +21,7 @@ SciMLBase.requiresgradient(::AugLag) = true SciMLBase.allowsconstraints(::AugLag) = true SciMLBase.requiresconsjac(::AugLag) = true -function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag; +function __map_optimizer_args(cache::OptimizationBase.OptimizationCache, opt::AugLag; callback = nothing, maxiters::Union{Number, Nothing} = nothing, maxtime::Union{Number, Nothing} = nothing, @@ -105,7 +111,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache.f.cons(cons_tmp, θ) cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds] - opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = p) + opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) if cache.callback(opt_state, x...) error("Optimization halted by callback.") end @@ -171,10 +177,12 @@ function SciMLBase.__solve(cache::OptimizationCache{ break end end - stats = Optimization.OptimizationStats(; iterations = maxiters, + stats = OptimizationStats(; iterations = maxiters, time = 0.0, fevals = maxiters, gevals = maxiters) return SciMLBase.build_solution( cache, cache.opt, θ, x, stats = stats, retcode = opt_ret) end end + +end diff --git a/lib/OptimizationAuglag/test/runtests.jl b/lib/OptimizationAuglag/test/runtests.jl new file mode 100644 index 000000000..60f994265 --- /dev/null +++ b/lib/OptimizationAuglag/test/runtests.jl @@ -0,0 +1,36 @@ +using OptimizationBase +using MLUtils +using OptimizationOptimisers +using OptimizationAuglag +using ForwardDiff +using OptimizationBase: OptimizationCache +using OptimizationBase.SciMLBase: OptimizationFunction +using Test + +@testset "OptimizationAuglag.jl" begin + x0 = (-pi):0.001:pi + y0 = sin.(x0) + data = MLUtils.DataLoader((x0, y0), batchsize = 126) + + function loss(coeffs, data) + ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])] + return sum(abs2, ypred .- data[2]) + end + + function cons1(res, coeffs, p = nothing) + res[1] = coeffs[1] * coeffs[5] - 1 + return nothing + end + + optf = OptimizationFunction(loss, OptimizationBase.AutoSparseForwardDiff(), cons = cons1) + callback = (st, l) -> (@show l; return false) + + initpars = rand(5) + l0 = optf(initpars, (x0, y0)) + + prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1], + lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0]) + opt = solve( + prob, OptimizationAuglag.AugLag(; inner = Adam()), maxiters = 10000, callback = callback) + @test opt.objective < l0 +end \ No newline at end of file diff --git a/src/Optimization.jl b/src/Optimization.jl index 8d0257dd1..4cfeead6e 100644 --- a/src/Optimization.jl +++ b/src/Optimization.jl @@ -24,7 +24,6 @@ include("utils.jl") include("state.jl") include("lbfgsb.jl") include("sophia.jl") -include("auglag.jl") export solve diff --git a/test/native.jl b/test/native.jl index 0c6c0f6e5..f7385fd0d 100644 --- a/test/native.jl +++ b/test/native.jl @@ -51,12 +51,6 @@ prob = OptimizationProblem(optf, initpars, (x0, y0), lcons = [-Inf], ucons = [0. opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback) @test opt1.objective < l0 -prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1], - lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0]) -opt = solve( - prob, Optimization.AugLag(; inner = Adam()), maxiters = 10000, callback = callback) -@test opt.objective < l0 - optf1 = OptimizationFunction(loss, AutoSparseForwardDiff()) prob1 = OptimizationProblem(optf1, rand(5), data) sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback)