Skip to content

Commit 72183f5

Browse files
committed
Add linear program solver based on the restarted Halpern primal-dual hybrid gradient (rHPDHG) algorithm.
1 parent 3d8c391 commit 72183f5

File tree

12 files changed

+595
-3
lines changed

12 files changed

+595
-3
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ disable=R,
129129
wrong-import-order,
130130
xrange-builtin,
131131
zip-builtin-not-iterating,
132+
invalid-name,
132133

133134

134135
[REPORTS]

docs/api/linprog.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
Linear programming
2+
==================
3+
4+
.. currentmodule:: optax.linprog
5+
6+
.. autosummary::
7+
rhpdhg
8+
9+
10+
Restarted Halpern primal-dual hybrid gradient method
11+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
12+
.. autofunction:: rhpdhg

docs/gallery.rst

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@
209209
.. only:: html
210210

211211
.. image:: /images/examples/linear_assignment_problem.png
212-
:alt:
212+
:alt: Linear assignment problem.
213213

214214
:doc:`_collections/examples/linear_assignment_problem`
215215

@@ -219,6 +219,23 @@
219219
</div>
220220

221221

222+
.. raw:: html
223+
224+
<div class="sphx-glr-thumbcontainer" tooltip="Linear programming.">
225+
226+
.. only:: html
227+
228+
.. image:: /images/examples/linear_programming.png
229+
:alt: Linear programming.
230+
231+
:doc:`_collections/examples/linear_programming`
232+
233+
.. raw:: html
234+
235+
<div class="sphx-glr-thumbnail-title">Linear programming.</div>
236+
</div>
237+
238+
222239
.. raw:: html
223240

224241
</div>
76.8 KB
Loading

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ for instructions on installing JAX.
5454
:caption: 📖 Reference
5555
:maxdepth: 2
5656

57+
api/linprog
5758
api/assignment
5859
api/optimizers
5960
api/transformations

examples/linear_programming.ipynb

Lines changed: 229 additions & 0 deletions
Large diffs are not rendered by default.

optax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from optax import assignment
2121
from optax import contrib
22+
from optax import linprog
2223
from optax import losses
2324
from optax import monte_carlo
2425
from optax import perturbations
@@ -364,6 +365,7 @@
364365
"lion",
365366
"linear_onecycle_schedule",
366367
"linear_schedule",
368+
"linprog",
367369
"log_cosh",
368370
"lookahead",
369371
"LookaheadParams",

optax/_src/alias.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2482,7 +2482,7 @@ def lbfgs(
24822482
... )
24832483
... params = optax.apply_updates(params, updates)
24842484
... print('Objective function: ', f(params))
2485-
Objective function: 7.5166864
2485+
Objective function: 7.516686...
24862486
Objective function: 7.460699e-14
24872487
Objective function: 2.6505726e-28
24882488
Objective function: 0.0

optax/linprog/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""The linear programming sub-package."""
16+
17+
# pylint:disable=g-importing-member
18+
19+
from optax.linprog._rhpdhg import solve_general as rhpdhg

optax/linprog/_rhpdhg.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""The restarted Halpern primal-dual hybrid gradient method."""
16+
17+
from jax import lax, numpy as jnp
18+
from optax import tree_utils as otu
19+
20+
21+
def solve_canonical(
22+
c, A, b, iters, reflect=True, restarts=True, tau=None, sigma=None
23+
):
24+
r"""Solves a linear program using the restarted Halpern primal-dual hybrid
25+
gradient (RHPDHG) method.
26+
27+
Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`x \geq 0`.
28+
29+
See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_.
30+
31+
Args:
32+
c: Cost vector.
33+
A: Equality constraint matrix.
34+
b: Equality constraint vector.
35+
iters: Number of iterations to run the solver for.
36+
reflect: Use reflection. See paper for details.
37+
restarts: Use restarts. See paper for details.
38+
tau: Primal step size. See paper for details.
39+
sigma: Dual step size. See paper for details.
40+
41+
Returns:
42+
A dictionary whose entries are as follows:
43+
- primal: The final primal solution.
44+
- dual: The final dual solution.
45+
- primal_iterates: The primal iterates.
46+
- dual_iterates: The dual iterates.
47+
48+
Examples:
49+
>>> from jax import numpy as jnp
50+
>>> import optax
51+
>>> c = -jnp.array([2, 1])
52+
>>> A = jnp.zeros([0, 2])
53+
>>> b = jnp.zeros(0)
54+
>>> G = jnp.array([[3, 1], [1, 1], [1, 4]])
55+
>>> h = jnp.array([21, 9, 24])
56+
>>> x = optax.linprog.rhpdhg(c, A, b, G, h, 1_000_000)['primal']
57+
>>> print(x[0])
58+
5.99...
59+
>>> print(x[1])
60+
2.99...
61+
62+
References:
63+
Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming
64+
<https://arxiv.org/abs/2407.16144>`_, 2024
65+
Haihao Lu, Zedong Peng, Jinwen Yang, `MPAX: Mathematical Programming in JAX
66+
<https://arxiv.org/abs/2412.09734>`_, 2024
67+
"""
68+
69+
if tau is None or sigma is None:
70+
A_norm = jnp.linalg.norm(A, axis=(0, 1), ord=2)
71+
if tau is None:
72+
tau = 1 / (2 * A_norm)
73+
if sigma is None:
74+
sigma = 1 / (2 * A_norm)
75+
76+
def T(z):
77+
# primal dual hybrid gradient (PDHG)
78+
x, y = z
79+
xn = x + tau * (y @ A - c)
80+
xn = xn.clip(min=0)
81+
yn = y + sigma * (b - A @ (2 * xn - x))
82+
return xn, yn
83+
84+
def H(z, k, z0):
85+
# Halpern PDHG
86+
Tz = T(z)
87+
if reflect:
88+
zc = otu.tree_sub(otu.tree_scalar_mul(2, Tz), z)
89+
else:
90+
zc = Tz
91+
kp2 = k + 2
92+
zn = otu.tree_add(
93+
otu.tree_scalar_mul((k + 1) / kp2, zc),
94+
otu.tree_scalar_mul(1 / kp2, z0),
95+
)
96+
return zn, Tz
97+
98+
def update(carry, _):
99+
z, k, z0, d0 = carry
100+
zn, Tz = H(z, k, z0)
101+
102+
if restarts:
103+
d = otu.tree_l2_norm(otu.tree_sub(z, Tz), squared=True)
104+
restart = d <= d0 * jnp.exp(-2)
105+
new_carry = otu.tree_where(
106+
restart,
107+
(zn, 0, zn, d),
108+
(zn, k + 1, z0, d0),
109+
)
110+
else:
111+
new_carry = zn, k + 1, z0, d0
112+
113+
return new_carry, z
114+
115+
def run():
116+
m, n = A.shape
117+
x = jnp.zeros(n)
118+
y = jnp.zeros(m)
119+
z0 = x, y
120+
d0 = otu.tree_l2_norm(otu.tree_sub(z0, T(z0)), squared=True)
121+
(z, _, _, _), zs = lax.scan(update, (z0, 0, z0, d0), length=iters)
122+
x, y = z
123+
xs, ys = zs
124+
return {
125+
"primal": x,
126+
"dual": y,
127+
"primal_iterates": xs,
128+
"dual_iterates": ys,
129+
}
130+
131+
return run()
132+
133+
134+
def general_to_canonical(c, A, b, G, h):
135+
"""Converts a linear program from general form to canonical form.
136+
137+
The solution to the new linear program will consist of the concatenation of
138+
- the positive part of x
139+
- the negative part of x
140+
- slacks
141+
142+
That is, we go from
143+
144+
Minimize c · x subject to
145+
A x = b
146+
G x ≤ h
147+
148+
to
149+
150+
Minimize c · (x⁺ - x⁻) subject to
151+
A (x⁺ - x⁻) = b
152+
G (x⁺ - x⁻) + s = h
153+
x⁺, x⁻, s ≥ 0
154+
155+
Args:
156+
c: Cost vector.
157+
A: Equality constraint matrix.
158+
b: Equality constraint vector.
159+
G: Inequality constraint matrix.
160+
h: Inequality constraint vector.
161+
162+
Returns:
163+
A triple (c', A', b') representing the corresponding canonical form.
164+
"""
165+
c_can = jnp.concatenate([c, -c, jnp.zeros(h.size)])
166+
G_ = jnp.concatenate([G, -G, jnp.eye(h.size)], 1)
167+
A_ = jnp.concatenate([A, -A, jnp.zeros([b.size, h.size])], 1)
168+
A_can = jnp.concatenate([A_, G_], 0)
169+
b_can = jnp.concatenate([b, h])
170+
return c_can, A_can, b_can
171+
172+
173+
def solve_general(
174+
c, A, b, G, h, iters, reflect=True, restarts=True, tau=None, sigma=None
175+
):
176+
r"""Solves a linear program using the restarted Halpern primal-dual hybrid
177+
gradient (RHPDHG) method.
178+
179+
Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`G x \leq h`.
180+
181+
See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_.
182+
183+
Args:
184+
c: Cost vector.
185+
A: Equality constraint matrix.
186+
b: Equality constraint vector.
187+
G: Inequality constraint matrix.
188+
h: Inequality constraint vector.
189+
iters: Number of iterations to run the solver for.
190+
reflect: Use reflection. See paper for details.
191+
restarts: Use restarts. See paper for details.
192+
tau: Primal step size. See paper for details.
193+
sigma: Dual step size. See paper for details.
194+
195+
Returns:
196+
A dictionary whose entries are as follows:
197+
- primal: The final primal solution.
198+
- slacks: The final primal slack values.
199+
- canonical_result: The result for the canonical program that was used
200+
internally to find this solution. See paper for details.
201+
202+
References:
203+
Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming
204+
<https://arxiv.org/abs/2407.16144>`_, 2024
205+
Haihao Lu, Zedong Peng, Jinwen Yang, `MPAX: Mathematical Programming in JAX
206+
<https://arxiv.org/abs/2412.09734>`_, 2024
207+
"""
208+
canonical = general_to_canonical(c, A, b, G, h)
209+
result = solve_canonical(*canonical, iters, reflect, restarts, tau, sigma)
210+
x_pos, x_neg, slacks = jnp.split(result["primal"], [c.size, c.size * 2])
211+
return {
212+
"primal": x_pos - x_neg,
213+
"slacks": slacks,
214+
"canonical_result": result,
215+
}

0 commit comments

Comments
 (0)