|
| 1 | +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""The restarted Halpern primal-dual hybrid gradient method.""" |
| 16 | + |
| 17 | +from jax import lax, numpy as jnp |
| 18 | +from optax import tree_utils as otu |
| 19 | + |
| 20 | + |
| 21 | +def solve_canonical( |
| 22 | + c, A, b, iters, reflect=True, restarts=True, tau=None, sigma=None |
| 23 | +): |
| 24 | + r"""Solves a linear program using the restarted Halpern primal-dual hybrid |
| 25 | + gradient (RHPDHG) method. |
| 26 | +
|
| 27 | + Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`x \geq 0`. |
| 28 | +
|
| 29 | + See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_. |
| 30 | +
|
| 31 | + Args: |
| 32 | + c: Cost vector. |
| 33 | + A: Equality constraint matrix. |
| 34 | + b: Equality constraint vector. |
| 35 | + iters: Number of iterations to run the solver for. |
| 36 | + reflect: Use reflection. See paper for details. |
| 37 | + restarts: Use restarts. See paper for details. |
| 38 | + tau: Primal step size. See paper for details. |
| 39 | + sigma: Dual step size. See paper for details. |
| 40 | +
|
| 41 | + Returns: |
| 42 | + A dictionary whose entries are as follows: |
| 43 | + - primal: The final primal solution. |
| 44 | + - dual: The final dual solution. |
| 45 | + - primal_iterates: The primal iterates. |
| 46 | + - dual_iterates: The dual iterates. |
| 47 | +
|
| 48 | + Examples: |
| 49 | + >>> from jax import numpy as jnp |
| 50 | + >>> import optax |
| 51 | + >>> c = -jnp.array([2, 1]) |
| 52 | + >>> A = jnp.zeros([0, 2]) |
| 53 | + >>> b = jnp.zeros(0) |
| 54 | + >>> G = jnp.array([[3, 1], [1, 1], [1, 4]]) |
| 55 | + >>> h = jnp.array([21, 9, 24]) |
| 56 | + >>> x = optax.linprog.rhpdhg(c, A, b, G, h, 1_000_000)['primal'] |
| 57 | + >>> print(x[0]) |
| 58 | + 5.99... |
| 59 | + >>> print(x[1]) |
| 60 | + 2.99... |
| 61 | +
|
| 62 | + References: |
| 63 | + Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming |
| 64 | + <https://arxiv.org/abs/2407.16144>`_, 2024 |
| 65 | + Haihao Lu, Zedong Peng, Jinwen Yang, `MPAX: Mathematical Programming in JAX |
| 66 | + <https://arxiv.org/abs/2412.09734>`_, 2024 |
| 67 | + """ |
| 68 | + |
| 69 | + if tau is None or sigma is None: |
| 70 | + A_norm = jnp.linalg.norm(A, axis=(0, 1), ord=2) |
| 71 | + if tau is None: |
| 72 | + tau = 1 / (2 * A_norm) |
| 73 | + if sigma is None: |
| 74 | + sigma = 1 / (2 * A_norm) |
| 75 | + |
| 76 | + def T(z): |
| 77 | + # primal dual hybrid gradient (PDHG) |
| 78 | + x, y = z |
| 79 | + xn = x + tau * (y @ A - c) |
| 80 | + xn = xn.clip(min=0) |
| 81 | + yn = y + sigma * (b - A @ (2 * xn - x)) |
| 82 | + return xn, yn |
| 83 | + |
| 84 | + def H(z, k, z0): |
| 85 | + # Halpern PDHG |
| 86 | + Tz = T(z) |
| 87 | + if reflect: |
| 88 | + zc = otu.tree_sub(otu.tree_scalar_mul(2, Tz), z) |
| 89 | + else: |
| 90 | + zc = Tz |
| 91 | + kp2 = k + 2 |
| 92 | + zn = otu.tree_add( |
| 93 | + otu.tree_scalar_mul((k + 1) / kp2, zc), |
| 94 | + otu.tree_scalar_mul(1 / kp2, z0), |
| 95 | + ) |
| 96 | + return zn, Tz |
| 97 | + |
| 98 | + def update(carry, _): |
| 99 | + z, k, z0, d0 = carry |
| 100 | + zn, Tz = H(z, k, z0) |
| 101 | + |
| 102 | + if restarts: |
| 103 | + d = otu.tree_l2_norm(otu.tree_sub(z, Tz), squared=True) |
| 104 | + restart = d <= d0 * jnp.exp(-2) |
| 105 | + new_carry = otu.tree_where( |
| 106 | + restart, |
| 107 | + (zn, 0, zn, d), |
| 108 | + (zn, k + 1, z0, d0), |
| 109 | + ) |
| 110 | + else: |
| 111 | + new_carry = zn, k + 1, z0, d0 |
| 112 | + |
| 113 | + return new_carry, z |
| 114 | + |
| 115 | + def run(): |
| 116 | + m, n = A.shape |
| 117 | + x = jnp.zeros(n) |
| 118 | + y = jnp.zeros(m) |
| 119 | + z0 = x, y |
| 120 | + d0 = otu.tree_l2_norm(otu.tree_sub(z0, T(z0)), squared=True) |
| 121 | + (z, _, _, _), zs = lax.scan(update, (z0, 0, z0, d0), length=iters) |
| 122 | + x, y = z |
| 123 | + xs, ys = zs |
| 124 | + return { |
| 125 | + "primal": x, |
| 126 | + "dual": y, |
| 127 | + "primal_iterates": xs, |
| 128 | + "dual_iterates": ys, |
| 129 | + } |
| 130 | + |
| 131 | + return run() |
| 132 | + |
| 133 | + |
| 134 | +def general_to_canonical(c, A, b, G, h): |
| 135 | + """Converts a linear program from general form to canonical form. |
| 136 | +
|
| 137 | + The solution to the new linear program will consist of the concatenation of |
| 138 | + - the positive part of x |
| 139 | + - the negative part of x |
| 140 | + - slacks |
| 141 | +
|
| 142 | + That is, we go from |
| 143 | +
|
| 144 | + Minimize c · x subject to |
| 145 | + A x = b |
| 146 | + G x ≤ h |
| 147 | +
|
| 148 | + to |
| 149 | +
|
| 150 | + Minimize c · (x⁺ - x⁻) subject to |
| 151 | + A (x⁺ - x⁻) = b |
| 152 | + G (x⁺ - x⁻) + s = h |
| 153 | + x⁺, x⁻, s ≥ 0 |
| 154 | +
|
| 155 | + Args: |
| 156 | + c: Cost vector. |
| 157 | + A: Equality constraint matrix. |
| 158 | + b: Equality constraint vector. |
| 159 | + G: Inequality constraint matrix. |
| 160 | + h: Inequality constraint vector. |
| 161 | +
|
| 162 | + Returns: |
| 163 | + A triple (c', A', b') representing the corresponding canonical form. |
| 164 | + """ |
| 165 | + c_can = jnp.concatenate([c, -c, jnp.zeros(h.size)]) |
| 166 | + G_ = jnp.concatenate([G, -G, jnp.eye(h.size)], 1) |
| 167 | + A_ = jnp.concatenate([A, -A, jnp.zeros([b.size, h.size])], 1) |
| 168 | + A_can = jnp.concatenate([A_, G_], 0) |
| 169 | + b_can = jnp.concatenate([b, h]) |
| 170 | + return c_can, A_can, b_can |
| 171 | + |
| 172 | + |
| 173 | +def solve_general( |
| 174 | + c, A, b, G, h, iters, reflect=True, restarts=True, tau=None, sigma=None |
| 175 | +): |
| 176 | + r"""Solves a linear program using the restarted Halpern primal-dual hybrid |
| 177 | + gradient (RHPDHG) method. |
| 178 | +
|
| 179 | + Minimizes :math:`c \cdot x` subject to :math:`A x = b` and :math:`G x \leq h`. |
| 180 | +
|
| 181 | + See also `MPAX <https://github.com/MIT-Lu-Lab/MPAX>`_. |
| 182 | +
|
| 183 | + Args: |
| 184 | + c: Cost vector. |
| 185 | + A: Equality constraint matrix. |
| 186 | + b: Equality constraint vector. |
| 187 | + G: Inequality constraint matrix. |
| 188 | + h: Inequality constraint vector. |
| 189 | + iters: Number of iterations to run the solver for. |
| 190 | + reflect: Use reflection. See paper for details. |
| 191 | + restarts: Use restarts. See paper for details. |
| 192 | + tau: Primal step size. See paper for details. |
| 193 | + sigma: Dual step size. See paper for details. |
| 194 | +
|
| 195 | + Returns: |
| 196 | + A dictionary whose entries are as follows: |
| 197 | + - primal: The final primal solution. |
| 198 | + - slacks: The final primal slack values. |
| 199 | + - canonical_result: The result for the canonical program that was used |
| 200 | + internally to find this solution. See paper for details. |
| 201 | +
|
| 202 | + References: |
| 203 | + Haihao Lu, Jinwen Yang, `Restarted Halpern PDHG for Linear Programming |
| 204 | + <https://arxiv.org/abs/2407.16144>`_, 2024 |
| 205 | + Haihao Lu, Zedong Peng, Jinwen Yang, `MPAX: Mathematical Programming in JAX |
| 206 | + <https://arxiv.org/abs/2412.09734>`_, 2024 |
| 207 | + """ |
| 208 | + canonical = general_to_canonical(c, A, b, G, h) |
| 209 | + result = solve_canonical(*canonical, iters, reflect, restarts, tau, sigma) |
| 210 | + x_pos, x_neg, slacks = jnp.split(result["primal"], [c.size, c.size * 2]) |
| 211 | + return { |
| 212 | + "primal": x_pos - x_neg, |
| 213 | + "slacks": slacks, |
| 214 | + "canonical_result": result, |
| 215 | + } |
0 commit comments