Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.79
1.82
159 changes: 127 additions & 32 deletions src/lp_extract.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use good_lp::{
default_solver, variable, variables, Expression, Solution, SolutionStatus, Solver, SolverModel,
Variable,
default_solver, solvers::WithTimeLimit, variable, variables, Expression, Solution,
SolutionStatus, Solver, SolverModel, Variable,
};
use std::time::Instant;

Expand Down Expand Up @@ -71,7 +71,7 @@ impl<L: Language, N: Analysis<L>> LpCostFunction<L, N> for AstSize {
///
/// See the (`good_lp` documentation)[https://docs.rs/good_lp/1/good_lp/solvers/index.html]
///
/// At run time, select the solver by calling [`Self::solve_with`] or [`Self::solve_multiple_with`]
/// At run time, select the solver by calling [`Self::solve_with`], [`Self::solve_multiple_with`], [`Self::solve_with_timeout`], or [`Self::solve_multiple_with_timeout`]
/// and passing one of the enabled `good_lp` solver implementations.
///
/// - Example (CBC):
Expand Down Expand Up @@ -154,21 +154,24 @@ where
self.solve_multiple_with(&[root], solver).0
}

/// Extract a single rooted term with an explicit solver backend and time limit.
pub fn solve_with_timeout<S: Solver>(&mut self, root: Id, solver: S, timeout: f64) -> RecExpr<L>
where
<S as Solver>::Model: WithTimeLimit,
{
self.solve_multiple_with_timeout(&[root], solver, timeout).0
}

/// Extract (potentially multiple) roots
pub fn solve_multiple(&mut self, roots: &[Id]) -> (RecExpr<L>, Vec<Id>) {
self.solve_multiple_with(roots, default_solver)
}

/// Like [`solve_multiple`], but lets the caller provide a `good_lp` solver backend.
/// Example: `solve_multiple_with(roots, good_lp::highs)`.
pub fn solve_multiple_with<S: Solver>(
&mut self,
roots: &[Id],
solver: S,
) -> (RecExpr<L>, Vec<Id>) {
/// Builds the ILP model with variables and objective function.
/// Returns the model builder (before timeout) and the variables map.
fn build_ilp_model<S: Solver>(&mut self, solver: S) -> (S::Model, HashMap<Id, ClassVars>) {
let egraph = self.egraph;
let mut num_vars: usize = 0;
let mut num_cons: usize = 0;

// Build variables per class
let mut builder = variables!();
Expand Down Expand Up @@ -205,11 +208,25 @@ where
}

// Build model using the provided solver
let mut model = builder.minimise(objective).using(solver);
let model = builder.minimise(objective).using(solver);

log::info!("Model using {num_vars} variables");
(model, vars)
}

/// Adds all constraints to the model.
fn add_constraints<S: Solver>(
&self,
model: &mut S::Model,
vars: &HashMap<Id, ClassVars>,
roots: &[Id],
) {
let egraph = self.egraph;
let mut num_cons: usize = 0;

// Constraints:
// - Exactly one chosen node per active class: sum(nodes) == active
for (&id, class) in &vars {
for (&id, class) in vars {
let sum_nodes: Expression = class
.nodes
.iter()
Expand Down Expand Up @@ -240,26 +257,16 @@ where
model.add_constraint(Expression::from(vars[root].active).geq(1));
}

log::info!("Model using {num_vars} variables and {num_cons} constraints");
log::info!("Solving using {}", <S as Solver>::name(),);
let start = Instant::now();
let solution = model
.solve()
.expect("good_lp failed to solve the ILP problem");
let duration = start.elapsed().as_secs_f64();
log::info!("Solution found in {:.2}s", duration);
match solution.status() {
SolutionStatus::Optimal => {
log::info!("Solution is optimal");
}
SolutionStatus::TimeLimit => {
log::warn!("Solver timed out, solution may not be optimal.");
}
SolutionStatus::GapLimit => {
log::info!("Solver reached gap limit, solution may not be optimal.");
}
};
log::info!("Model using {num_cons} constraints");
}

/// Extracts the solution from the solved model.
fn extract_solution<S: Solver>(
&self,
solution: <S::Model as SolverModel>::Solution,
vars: &HashMap<Id, ClassVars>,
roots: &[Id],
) -> (RecExpr<L>, Vec<Id>) {
let mut todo: Vec<Id> = roots.iter().map(|id| self.egraph.find(*id)).collect();
let mut expr = RecExpr::default();
// converts e-class ids to e-node ids
Expand Down Expand Up @@ -296,6 +303,78 @@ where
assert!(expr.is_dag(), "LpExtract found a cyclic term!: {:?}", expr);
(expr, root_idxs)
}

/// Like [`solve_multiple`], but lets the caller provide a `good_lp` solver backend.
/// Example: `solve_multiple_with(roots, good_lp::highs)`.
pub fn solve_multiple_with<S: Solver>(
&mut self,
roots: &[Id],
solver: S,
) -> (RecExpr<L>, Vec<Id>) {
let (mut model, vars) = self.build_ilp_model(solver);
self.add_constraints::<S>(&mut model, &vars, roots);

log::info!("Solving using {}", <S as Solver>::name());
let start = Instant::now();
let solution = model
.solve()
.expect("good_lp failed to solve the ILP problem");
let duration = start.elapsed().as_secs_f64();
log::info!("Solution found in {:.2}s", duration);
match solution.status() {
SolutionStatus::Optimal => {
log::info!("Solution is optimal");
}
SolutionStatus::TimeLimit => {
log::warn!("Solver timed out, solution may not be optimal.");
}
SolutionStatus::GapLimit => {
log::info!("Solver reached gap limit, solution may not be optimal.");
}
};

self.extract_solution::<S>(solution, &vars, roots)
}

/// Like [`solve_multiple_with`], but lets the caller provide a time limit for the 'good_lp' solver in seconds.
/// Example: `solve_multiple_with_timeout(roots, good_lp::highs, 600.0)`.
pub fn solve_multiple_with_timeout<S: Solver>(
&mut self,
roots: &[Id],
solver: S,
timeout: f64,
) -> (RecExpr<L>, Vec<Id>)
where
<S as Solver>::Model: WithTimeLimit,
{
let (model_build, vars) = self.build_ilp_model(solver);

// Set timeout
let mut model = model_build.with_time_limit(timeout);

self.add_constraints::<S>(&mut model, &vars, roots);

log::info!("Solving using {}", <S as Solver>::name());
let start = Instant::now();
let solution = model
.solve()
.expect("good_lp failed to solve the ILP problem");
let duration = start.elapsed().as_secs_f64();
log::info!("Solution found in {:.2}s", duration);
match solution.status() {
SolutionStatus::Optimal => {
log::info!("Solution is optimal");
}
SolutionStatus::TimeLimit => {
log::warn!("Solver timed out, solution may not be optimal.");
}
SolutionStatus::GapLimit => {
log::info!("Solver reached gap limit, solution may not be optimal.");
}
};

self.extract_solution::<S>(solution, &vars, roots)
}
}

fn find_cycles<L, N>(egraph: &EGraph<L, N>, mut f: impl FnMut(Id, usize))
Expand Down Expand Up @@ -352,6 +431,22 @@ mod tests {
assert_eq!(ids.len(), 2);
}

#[test]
fn simple_lp_extract_two_timeout() {
let mut egraph = EGraph::<S, ()>::default();
let a = egraph.add(S::leaf("a"));
let plus = egraph.add(S::new("+", vec![a, a]));
let f = egraph.add(S::new("f", vec![plus]));
let g = egraph.add(S::new("g", vec![plus]));

let mut ext = LpExtractor::new(&egraph, AstSize);
let (exp, ids) = ext.solve_multiple_with_timeout(&[f, g], good_lp::coin_cbc, 10.0);
println!("{:?}", exp);
println!("{}", exp);
assert_eq!(exp.len(), 4);
assert_eq!(ids.len(), 2);
}

#[test]
fn extract_root_mismatch() {
let mut egraph = EGraph::<S, ()>::default();
Expand Down