Skip to content

Commit 4814b13

Browse files
committed
Allow constraints to be slack
1 parent 7374277 commit 4814b13

File tree

1 file changed

+44
-19
lines changed

1 file changed

+44
-19
lines changed

src/pyvmcon/vmcon.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def solve_qsp(
264264
lbs: VectorType | None,
265265
ubs: VectorType | None,
266266
options: dict[str, Any],
267+
*,
268+
allow_slack_constraints: bool = False,
267269
) -> tuple[VectorType, VectorType, VectorType]:
268270
"""Solves the quadratic programming problem.
269271
@@ -313,17 +315,41 @@ def solve_qsp(
313315
ubs - x if ubs is not None else None,
314316
],
315317
)
316-
problem_statement = cp.Minimize(
317-
result.f + (0.5 * cp.quad_form(delta, B)) + (delta.T @ result.df),
318+
319+
num_constraints = problem.num_equality + problem.num_inequality
320+
ksi = cp.Variable(num_constraints or 1, bounds=[0.0, 1.0])
321+
322+
minimise_expression = (
323+
result.f + (0.5 * cp.quad_form(delta, B)) + (delta.T @ result.df)
318324
)
325+
if allow_slack_constraints:
326+
minimise_expression -= cp.sum(ksi)
319327

320328
constraints = []
321-
if problem.has_inequality:
322-
constraints.append((result.die @ delta) + result.ie >= 0)
323-
if problem.has_equality:
324-
constraints.append((result.deq @ delta) + result.eq == 0)
329+
for con_idx in range(problem.num_equality):
330+
if allow_slack_constraints:
331+
constraints.append(
332+
(result.deq[con_idx, :] @ delta) + result.eq[con_idx] * ksi[con_idx]
333+
== 0
334+
)
335+
else:
336+
constraints.append(
337+
(result.deq[con_idx, :] @ delta) + result.eq[con_idx] == 0
338+
)
325339

326-
qsp = cp.Problem(problem_statement, constraints or None)
340+
for con_idx in range(problem.num_inequality):
341+
if result.ie[con_idx] <= 0 and allow_slack_constraints:
342+
constraints.append(
343+
(result.die[con_idx, :] @ delta)
344+
+ result.ie[con_idx] * ksi[problem.num_equality + con_idx]
345+
>= 0
346+
)
347+
else:
348+
constraints.append(
349+
(result.die[con_idx, :] @ delta) + result.ie[con_idx] >= 0
350+
)
351+
352+
qsp = cp.Problem(cp.Minimize(minimise_expression), constraints or None)
327353

328354
try:
329355
qsp.solve(**{"solver": cp.OSQP, **options})
@@ -335,18 +361,17 @@ def solve_qsp(
335361
error_msg = f"QSP failed to solve: {qsp.status}"
336362
raise _QspSolveException(error_msg)
337363

338-
lamda_equality = np.array([])
339-
lamda_inequality = np.array([])
340-
341-
if problem.has_inequality and problem.has_equality:
342-
lamda_inequality = qsp.constraints[0].dual_value
343-
lamda_equality = -qsp.constraints[1].dual_value
344-
345-
elif problem.has_inequality and not problem.has_equality:
346-
lamda_inequality = qsp.constraints[0].dual_value
347-
348-
elif not problem.has_inequality and problem.has_equality:
349-
lamda_equality = -qsp.constraints[0].dual_value
364+
lamda_equality = np.array(
365+
[-i.dual_value for i in qsp.constraints[: problem.num_equality]]
366+
)
367+
lamda_inequality = np.array(
368+
[
369+
i.dual_value
370+
for i in qsp.constraints[
371+
problem.num_equality : problem.num_equality + problem.num_inequality
372+
]
373+
]
374+
)
350375

351376
return delta.value, lamda_equality, lamda_inequality
352377

0 commit comments

Comments
 (0)