Skip to content

Commit 84365e1

Browse files
kiijeonghooMizux
authored andcommitted
linear_solver: implement logcallback for HiGHS
1 parent 53e0ed8 commit 84365e1

File tree

6 files changed

+93
-14
lines changed

6 files changed

+93
-14
lines changed

ortools/linear_solver/java/ModelBuilderTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,25 @@ public void importFromLpString() {
152152
assertThat(model.varFromIndex(0).getUpperBound()).isEqualTo(42.0);
153153
assertThat(model.varFromIndex(0).getName()).isEqualTo("x");
154154
}
155+
156+
@Test
157+
public void highsLogCallback_receivesLogsWhenEnabled() {
158+
ModelBuilder model = new ModelBuilder();
159+
double infinity = Double.POSITIVE_INFINITY;
160+
Variable x = model.newNumVar(0.0, infinity, "x");
161+
Variable y = model.newNumVar(0.0, infinity, "y");
162+
model.addLessOrEqual(LinearExpr.sum(new Variable[] {x, y}), 10.0);
163+
model.maximize(LinearExpr.newBuilder().addTerm(x, 1.0).addTerm(y, 2.0));
164+
165+
ModelSolver solver = new ModelSolver("highs");
166+
if (!solver.solverIsSupported()) {
167+
return;
168+
}
169+
StringBuilder log = new StringBuilder();
170+
solver.setLogCallback(msg -> log.append(msg));
171+
solver.enableOutput(true);
172+
assertThat(solver.solve(model)).isEqualTo(SolveStatus.OPTIMAL);
173+
assertThat(log.length()).isGreaterThan(0);
174+
assertThat(log.toString()).contains("Model");
175+
}
155176
}

ortools/linear_solver/proto_solver/highs_proto_solver.cc

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,25 @@ namespace operations_research {
4646
absl::Status SetSolverSpecificParameters(const std::string& parameters,
4747
Highs& highs);
4848

49+
namespace {
50+
51+
// Adapter from HiGHS unified callback (setCallback) to OR-Tools
52+
// TODO: Only forwards kCallbackLogging now
53+
void HighsCallbackAdapter(int callback_type, const std::string& message,
54+
const HighsCallbackOutput*, HighsCallbackInput*,
55+
void* user_data) {
56+
if (callback_type == kCallbackLogging && user_data != nullptr) {
57+
const auto* cb =
58+
static_cast<const std::function<void(const std::string&)>*>(user_data);
59+
if (*cb) (*cb)(message);
60+
}
61+
}
62+
63+
}
64+
4965
absl::StatusOr<MPSolutionResponse> HighsSolveProto(
50-
LazyMutableCopy<MPModelRequest> request, HighsSolveInfo* solve_info) {
66+
LazyMutableCopy<MPModelRequest> request, HighsSolveInfo* solve_info,
67+
const std::function<void(const std::string&)>* logging_callback) {
5168
MPSolutionResponse response;
5269
const std::optional<LazyMutableCopy<MPModelProto>> optional_model =
5370
GetMPModelOrPopulateResponse(request, &response);
@@ -81,6 +98,29 @@ absl::StatusOr<MPSolutionResponse> HighsSolveProto(
8198
}
8299
}
83100

101+
// Logging.
102+
if (request->enable_internal_solver_output()) {
103+
highs.setOptionValue("log_to_console", true);
104+
highs.setOptionValue("output_flag", true);
105+
if (logging_callback != nullptr && *logging_callback) {
106+
if (highs.setCallback(HighsCallbackAdapter,
107+
const_cast<std::function<void(const std::string&)>*>(
108+
logging_callback)) != HighsStatus::kOk) {
109+
response.set_status(MPSOLVER_ABNORMAL);
110+
response.set_status_str("HiGHS setCallback failed");
111+
return response;
112+
}
113+
if (highs.startCallback(kCallbackLogging) != HighsStatus::kOk) {
114+
response.set_status(MPSOLVER_ABNORMAL);
115+
response.set_status_str("HiGHS startCallback(kCallbackLogging) failed");
116+
return response;
117+
}
118+
}
119+
} else {
120+
highs.setOptionValue("log_to_console", false);
121+
highs.setOptionValue("output_flag", false);
122+
}
123+
84124
const int variable_size = model.variable_size();
85125
bool has_integer_variables = false;
86126
{
@@ -211,19 +251,16 @@ absl::StatusOr<MPSolutionResponse> HighsSolveProto(
211251
highs.changeObjectiveOffset(offset);
212252
}
213253

214-
// Logging.
215-
if (request->enable_internal_solver_output()) {
216-
highs.setOptionValue("log_to_console", true);
217-
highs.setOptionValue("output_flag", true);
218-
} else {
219-
highs.setOptionValue("log_to_console", false);
220-
highs.setOptionValue("output_flag", false);
221-
}
222-
223254
const absl::Time time_before = absl::Now();
224255
UserTimer user_timer;
225256
user_timer.Start();
226257
const HighsStatus run_status = highs.run();
258+
259+
// Unregister log callback.
260+
if (logging_callback != nullptr && *logging_callback) {
261+
highs.stopCallback(kCallbackLogging);
262+
highs.setCallback(nullptr, nullptr);
263+
}
227264
VLOG(2) << "run_status: " << highsStatusToString(run_status);
228265
if (run_status == HighsStatus::kError) {
229266
response.set_status(MPSOLVER_NOT_SOLVED);

ortools/linear_solver/proto_solver/highs_proto_solver.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#define ORTOOLS_LINEAR_SOLVER_PROTO_SOLVER_HIGHS_PROTO_SOLVER_H_
1616

1717
#include <cstdint>
18+
#include <functional>
19+
#include <string>
1820

1921
#include "absl/status/statusor.h"
2022
#include "ortools/linear_solver/linear_solver.pb.h"
@@ -28,10 +30,12 @@ struct HighsSolveInfo {
2830
};
2931

3032
// Solve the input MIP model with the HIGHS solver and fills `solve_info` if
31-
// provided.
33+
// provided. When `logging_callback` is non-null and points to a callable
34+
// function, solver log messages are sent to it instead of the console.
3235
absl::StatusOr<MPSolutionResponse> HighsSolveProto(
3336
LazyMutableCopy<MPModelRequest> request,
34-
HighsSolveInfo* solve_info = nullptr);
37+
HighsSolveInfo* solve_info = nullptr,
38+
const std::function<void(const std::string&)>* logging_callback = nullptr);
3539

3640
} // namespace operations_research
3741

ortools/linear_solver/python/model_builder_helper.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "ortools/linear_solver/model_exporter.h"
4848
#include "pybind11/cast.h"
4949
#include "pybind11/eigen.h"
50+
#include "pybind11/functional.h"
5051
#include "pybind11/numpy.h"
5152
#include "pybind11/pybind11.h"
5253
#include "pybind11/pytypes.h"

ortools/linear_solver/python/model_builder_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,23 @@ def test_import_from_lp_file(self):
205205
self.assertEqual(42, model.var_from_index(0).upper_bound)
206206
self.assertEqual("x", model.var_from_index(0).name)
207207

208+
def test_highs_log_callback_receives_logs_when_enabled(self):
209+
model = mb.Model()
210+
x = model.new_num_var(0.0, math.inf, "x")
211+
y = model.new_num_var(0.0, math.inf, "y")
212+
model.add(x + y <= 10.0)
213+
model.maximize(1.0 * x + 2.0 * y)
214+
215+
solver = mb.Solver("highs")
216+
if not solver.solver_is_supported():
217+
return
218+
log_lines = []
219+
solver.log_callback = log_lines.append
220+
solver.enable_output(True)
221+
self.assertEqual(solver.solve(model), mb.SolveStatus.OPTIMAL)
222+
self.assertGreater(len(log_lines), 0, "Log callback should receive output")
223+
self.assertIn("Model", "".join(log_lines))
224+
208225
def test_class_api(self):
209226
model = mb.Model()
210227
x = model.new_int_var(0, 10, "x")

ortools/linear_solver/wrappers/model_builder_helper.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,9 +619,8 @@ void ModelSolverHelper::Solve(const ModelBuilderHelper& model) {
619619
#if defined(USE_HIGHS)
620620
case MPModelRequest::HIGHS_LINEAR_PROGRAMMING: // ABSL_FALLTHROUGH_INTENDED
621621
case MPModelRequest::HIGHS_MIXED_INTEGER_PROGRAMMING: {
622-
// TODO(user): Enable log_callback support.
623622
// TODO(user): Enable interrupt_solve.
624-
const auto temp = HighsSolveProto(std::move(request));
623+
const auto temp = HighsSolveProto(std::move(request), nullptr, &log_callback_);
625624
if (temp.ok()) {
626625
response_ = std::move(temp.value());
627626
}

0 commit comments

Comments
 (0)