Skip to content

Commit 70fcddd

Browse files
Merge pull request #936 from yebai/patch-1
Adding Mooncake to the AD list.
2 parents a52b97d + cd0da39 commit 70fcddd

File tree

3 files changed

+33
-17
lines changed

3 files changed

+33
-17
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ LoggingExtras = "0.4, 1"
4242
Lux = "1.12.4"
4343
MLUtils = "0.4"
4444
ModelingToolkit = "10"
45+
Mooncake = "0.4.138"
4546
Optim = ">= 1.4.1"
4647
OptimizationBase = "2"
4748
OptimizationMOI = "0.5"
@@ -99,9 +100,10 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
99100
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
100101
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
101102
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
103+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
102104

103105
[targets]
104106
test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff",
105107
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
106108
"OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays", "SparseDiffTools",
107-
"Symbolics", "Test", "Tracker", "Zygote"]
109+
"Symbolics", "Test", "Tracker", "Zygote", "Mooncake"]

docs/src/API/ad.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The choices for the auto-AD fill-ins with quick descriptions are:
99
- `AutoFiniteDiff()`: Finite differencing, not optimal but always applicable
1010
- `AutoModelingToolkit()`: The fastest choice for large scalar optimizations
1111
- `AutoEnzyme()`: Highly performant AD choice for type stable and optimized code
12+
- `AutoMooncake()`: Like Zygote and ReverseDiff, but supports GPU and mutating code
1213

1314
## Automatic Differentiation Choice API
1415

@@ -22,4 +23,5 @@ OptimizationBase.AutoZygote
2223
OptimizationBase.AutoTracker
2324
OptimizationBase.AutoModelingToolkit
2425
OptimizationBase.AutoEnzyme
26+
OptimizationBase.AutoMooncake
2527
```

test/ADtests.jl

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Optimization, OptimizationOptimJL, OptimizationMOI, Ipopt, Test
2-
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
2+
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker, Mooncake
33
using Enzyme, Random
44

55
x0 = zeros(2)
@@ -35,7 +35,7 @@ end
3535
@testset "No constraint" begin
3636
for adtype in [AutoEnzyme(), AutoForwardDiff(), AutoZygote(), AutoReverseDiff(),
3737
AutoFiniteDiff(), AutoModelingToolkit(), AutoSparseForwardDiff(),
38-
AutoSparseReverseDiff(), AutoSparse(AutoZygote()), AutoModelingToolkit(true, true)]
38+
AutoSparseReverseDiff(), AutoSparse(AutoZygote()), AutoModelingToolkit(true, true), AutoMooncake()]
3939
optf = OptimizationFunction(rosenbrock, adtype)
4040

4141
prob = OptimizationProblem(optf, x0)
@@ -46,16 +46,22 @@ end
4646
@test sol.retcode == ReturnCode.Success
4747
end
4848

49-
sol = solve(prob, Optim.Newton())
50-
@test 10 * sol.objective < l1
51-
if adtype != AutoFiniteDiff()
52-
@test sol.retcode == ReturnCode.Success
49+
# `Newton` requires Hession, which Mooncake doesn't support at the moment.
50+
if adtype != AutoMooncake()
51+
sol = solve(prob, Optim.Newton())
52+
@test 10 * sol.objective < l1
53+
if adtype != AutoFiniteDiff()
54+
@test sol.retcode == ReturnCode.Success
55+
end
5356
end
5457

55-
sol = solve(prob, Optim.KrylovTrustRegion())
56-
@test 10 * sol.objective < l1
57-
if adtype != AutoFiniteDiff()
58-
@test sol.retcode == ReturnCode.Success
58+
# Requires Hession, which Mooncake doesn't support at the moment.
59+
if adtype != AutoMooncake()
60+
sol = solve(prob, Optim.KrylovTrustRegion())
61+
@test 10 * sol.objective < l1
62+
if adtype != AutoFiniteDiff()
63+
@test sol.retcode == ReturnCode.Success
64+
end
5965
end
6066

6167
sol = solve(prob, Optimization.LBFGS(), maxiters = 1000)
@@ -67,7 +73,7 @@ end
6773
@testset "One constraint" begin
6874
for adtype in [AutoEnzyme(), AutoForwardDiff(), AutoZygote(), AutoReverseDiff(),
6975
AutoFiniteDiff(), AutoModelingToolkit(), AutoSparseForwardDiff(),
70-
AutoSparseReverseDiff(), AutoSparse(AutoZygote()), AutoModelingToolkit(true, true)]
76+
AutoSparseReverseDiff(), AutoSparse(AutoZygote()), AutoModelingToolkit(true, true), AutoMooncake()]
7177
cons = (res, x, p) -> (res[1] = x[1]^2 + x[2]^2 - 1.0; return nothing)
7278
optf = OptimizationFunction(rosenbrock, adtype, cons = cons)
7379

@@ -77,15 +83,18 @@ end
7783
sol = solve(prob, Optimization.LBFGS(), maxiters = 1000)
7884
@test 10 * sol.objective < l1
7985

80-
sol = solve(prob, Ipopt.Optimizer(), max_iter = 1000; print_level = 0)
81-
@test 10 * sol.objective < l1
86+
# Requires Hession, which Mooncake doesn't support at the moment.
87+
if adtype != AutoMooncake()
88+
sol = solve(prob, Ipopt.Optimizer(), max_iter = 1000; print_level = 0)
89+
@test 10 * sol.objective < l1
90+
end
8291
end
8392
end
8493

8594
@testset "Two constraints" begin
8695
for adtype in [AutoForwardDiff(), AutoZygote(), AutoReverseDiff(),
8796
AutoFiniteDiff(), AutoModelingToolkit(), AutoSparseForwardDiff(),
88-
AutoSparseReverseDiff(), AutoSparse(AutoZygote()), AutoModelingToolkit(true, true)]
97+
AutoSparseReverseDiff(), AutoSparse(AutoZygote()), AutoModelingToolkit(true, true), AutoMooncake()]
8998
function con2_c(res, x, p)
9099
res[1] = x[1]^2 + x[2]^2
91100
res[2] = x[2] * sin(x[1]) - x[1]
@@ -99,7 +108,10 @@ end
99108
sol = solve(prob, Optimization.LBFGS(), maxiters = 1000)
100109
@test 10 * sol.objective < l1
101110

102-
sol = solve(prob, Ipopt.Optimizer(), max_iter = 1000; print_level = 0)
103-
@test 10 * sol.objective < l1
111+
# Requires Hession, which Mooncake doesn't support at the moment.
112+
if adtype != AutoMooncake()
113+
sol = solve(prob, Ipopt.Optimizer(), max_iter = 1000; print_level = 0)
114+
@test 10 * sol.objective < l1
115+
end
104116
end
105117
end

0 commit comments

Comments
 (0)