Skip to content

Commit a9e3eb7

Browse files
pacakstefan-k
authored andcommitted
Disable counting by default
1 parent 6492020 commit a9e3eb7

File tree

6 files changed

+79
-12
lines changed

6 files changed

+79
-12
lines changed

crates/argmin/src/core/state/iterstate.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,20 @@ where
971971
pub fn take_prev_residuals(&mut self) -> Option<R> {
972972
self.prev_residuals.take()
973973
}
974+
975+
/// Overrides state of counting function executions (default: false)
976+
/// ```
977+
/// # use argmin::core::{IterState, State};
978+
/// # let mut state: IterState<(), (), (), (), Vec<f64>, f64> = IterState::new();
979+
/// # assert!(!state.counting_enabled);
980+
/// let state = state.counting(true);
981+
/// # assert!(state.counting_enabled);
982+
/// ```
983+
#[must_use]
984+
pub fn counting(mut self, mode: bool) -> Self {
985+
self.counting_enabled = mode;
986+
self
987+
}
974988
}
975989

976990
impl<P, G, J, H, R, F> State for IterState<P, G, J, H, R, F>

crates/argmin/src/core/state/linearprogramstate.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ pub struct LinearProgramState<P, F> {
5656
pub max_iters: u64,
5757
/// Evaluation counts
5858
pub counts: HashMap<String, u64>,
59+
/// Update evaluation counts?
60+
pub counting_enabled: bool,
5961
/// Time required so far
6062
pub time: Option<instant::Duration>,
6163
/// Status of optimization execution
@@ -150,6 +152,20 @@ impl<P, F> LinearProgramState<P, F> {
150152
self.cost = cost;
151153
self
152154
}
155+
156+
/// Overrides state of counting function executions (default: false)
157+
/// ```
158+
/// # use argmin::core::{State, LinearProgramState};
159+
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
160+
/// # assert!(!state.counting_enabled);
161+
/// let state = state.counting(true);
162+
/// # assert!(state.counting_enabled);
163+
/// ```
164+
#[must_use]
165+
pub fn counting(mut self, mode: bool) -> Self {
166+
self.counting_enabled = mode;
167+
self
168+
}
153169
}
154170

155171
impl<P, F> State for LinearProgramState<P, F>
@@ -205,6 +221,7 @@ where
205221
last_best_iter: 0,
206222
max_iters: std::u64::MAX,
207223
counts: HashMap::new(),
224+
counting_enabled: false,
208225
time: Some(instant::Duration::new(0, 0)),
209226
termination_status: TerminationStatus::NotTerminated,
210227
}
@@ -503,7 +520,7 @@ where
503520
/// ```
504521
/// # use std::collections::HashMap;
505522
/// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat};
506-
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
523+
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new().counting(true);
507524
/// # assert_eq!(state.counts, HashMap::new());
508525
/// # state.counts.insert("test2".to_string(), 10u64);
509526
/// #
@@ -520,9 +537,11 @@ where
520537
/// # assert_eq!(state.counts, hm);
521538
/// ```
522539
fn func_counts<O>(&mut self, problem: &Problem<O>) {
523-
for (k, &v) in problem.counts.iter() {
524-
let count = self.counts.entry(k.to_string()).or_insert(0);
525-
*count = v
540+
if self.counting_enabled {
541+
for (k, &v) in problem.counts.iter() {
542+
let count = self.counts.entry(k.to_string()).or_insert(0);
543+
*count = v
544+
}
526545
}
527546
}
528547

crates/argmin/src/core/state/populationstate.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ pub struct PopulationState<P, F> {
5858
pub max_iters: u64,
5959
/// Evaluation counts
6060
pub counts: HashMap<String, u64>,
61+
/// Update evaluation counts?
62+
pub counting_enabled: bool,
6163
/// Time required so far
6264
pub time: Option<instant::Duration>,
6365
/// Status of optimization execution
@@ -429,6 +431,20 @@ where
429431
pub fn take_population(&mut self) -> Option<Vec<P>> {
430432
self.population.take()
431433
}
434+
435+
/// Overrides state of counting function executions (default: false)
436+
/// ```
437+
/// # use argmin::core::{State, PopulationState};
438+
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
439+
/// # assert!(!state.counting_enabled);
440+
/// let state = state.counting(true);
441+
/// # assert!(state.counting_enabled);
442+
/// ```
443+
#[must_use]
444+
pub fn counting(mut self, mode: bool) -> Self {
445+
self.counting_enabled = mode;
446+
self
447+
}
432448
}
433449

434450
impl<P, F> State for PopulationState<P, F>
@@ -483,6 +499,7 @@ where
483499
last_best_iter: 0,
484500
max_iters: std::u64::MAX,
485501
counts: HashMap::new(),
502+
counting_enabled: false,
486503
time: Some(instant::Duration::new(0, 0)),
487504
termination_status: TerminationStatus::NotTerminated,
488505
}
@@ -782,7 +799,7 @@ where
782799
/// ```
783800
/// # use std::collections::HashMap;
784801
/// # use argmin::core::{Problem, PopulationState, State, ArgminFloat};
785-
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
802+
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new().counting(true);
786803
/// # assert_eq!(state.counts, HashMap::new());
787804
/// # state.counts.insert("test2".to_string(), 10u64);
788805
/// #
@@ -799,9 +816,11 @@ where
799816
/// # assert_eq!(state.counts, hm);
800817
/// ```
801818
fn func_counts<O>(&mut self, problem: &Problem<O>) {
802-
for (k, &v) in problem.counts.iter() {
803-
let count = self.counts.entry(k.to_string()).or_insert(0);
804-
*count = v
819+
if self.counting_enabled {
820+
for (k, &v) in problem.counts.iter() {
821+
let count = self.counts.entry(k.to_string()).or_insert(0);
822+
*count = v
823+
}
805824
}
806825
}
807826

crates/argmin/src/solver/brent/brentopt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ mod tests {
231231
let cost = TestFunc {};
232232
let solver = BrentOpt::new(-10., 10.);
233233
let res = Executor::new(cost, solver)
234-
.configure(|state| state.max_iters(13))
234+
.configure(|state| state.counting(true).max_iters(13))
235235
.run()
236236
.unwrap();
237237
assert_eq!(

crates/argmin/src/solver/linesearch/backtracking.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,12 @@ mod tests {
640640
ls.search_direction(vec![2.0f64, 0.0]);
641641

642642
let data = Executor::new(prob, ls.clone())
643-
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
643+
.configure(|config| {
644+
config
645+
.counting(true)
646+
.param(ls.init_param.clone().unwrap())
647+
.max_iters(10)
648+
})
644649
.run();
645650
assert!(data.is_ok());
646651

@@ -689,7 +694,12 @@ mod tests {
689694
ls.search_direction(vec![2.0f64, 0.0]);
690695

691696
let data = Executor::new(prob, ls.clone())
692-
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
697+
.configure(|config| {
698+
config
699+
.counting(true)
700+
.param(ls.init_param.clone().unwrap())
701+
.max_iters(10)
702+
})
693703
.run();
694704
assert!(data.is_ok());
695705

crates/argmin/src/tests.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,12 @@ fn test_lbfgs_func_count() {
161161
let linesearch = MoreThuenteLineSearch::new();
162162
let solver = LBFGS::new(linesearch, 10);
163163
let res = Executor::new(cost.clone(), solver)
164-
.configure(|config| config.param(cost.param_init.clone()).max_iters(100))
164+
.configure(|config| {
165+
config
166+
.param(cost.param_init.clone())
167+
.max_iters(100)
168+
.counting(true)
169+
})
165170
.run()
166171
.unwrap();
167172

0 commit comments

Comments
 (0)