Skip to content

Commit 16df1c5

Browse files
committed
disable counting by default
1 parent ac13db9 commit 16df1c5

File tree

5 files changed

+73
-11
lines changed

5 files changed

+73
-11
lines changed

crates/argmin/src/core/state/iterstate.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,20 @@ where
972972
pub fn take_prev_residuals(&mut self) -> Option<R> {
973973
self.prev_residuals.take()
974974
}
975+
976+
/// Overrides state of counting function executions (default: false)
977+
/// ```
978+
/// # use argmin::core::{IterState, State};
979+
/// # let mut state: IterState<(), (), (), (), Vec<f64>, f64> = IterState::new();
980+
/// # assert!(!state.counting_enabled);
981+
/// let state = state.counting(true);
982+
/// # assert!(state.counting_enabled);
983+
/// ```
984+
#[must_use]
985+
pub fn counting(mut self, mode: bool) -> Self {
986+
self.counting_enabled = mode;
987+
self
988+
}
975989
}
976990

977991
impl<P, G, J, H, R, F> State for IterState<P, G, J, H, R, F>

crates/argmin/src/core/state/linearprogramstate.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ pub struct LinearProgramState<P, F> {
5757
pub max_iters: u64,
5858
/// Evaluation counts
5959
pub counts: HashMap<String, u64>,
60+
/// Update evaluation counts?
61+
pub counting_enabled: bool,
6062
/// Time required so far
6163
pub time: Option<instant::Duration>,
6264
/// Status of optimization execution
@@ -151,6 +153,20 @@ impl<P, F> LinearProgramState<P, F> {
151153
self.cost = cost;
152154
self
153155
}
156+
157+
/// Overrides state of counting function executions (default: false)
158+
/// ```
159+
/// # use argmin::core::{State, LinearProgramState};
160+
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
161+
/// # assert!(!state.counting_enabled);
162+
/// let state = state.counting(true);
163+
/// # assert!(state.counting_enabled);
164+
/// ```
165+
#[must_use]
166+
pub fn counting(mut self, mode: bool) -> Self {
167+
self.counting_enabled = mode;
168+
self
169+
}
154170
}
155171

156172
impl<P, F> State for LinearProgramState<P, F>
@@ -206,6 +222,7 @@ where
206222
last_best_iter: 0,
207223
max_iters: std::u64::MAX,
208224
counts: HashMap::new(),
225+
counting_enabled: false,
209226
time: Some(instant::Duration::new(0, 0)),
210227
termination_status: TerminationStatus::NotTerminated,
211228
}
@@ -504,7 +521,7 @@ where
504521
/// ```
505522
/// # use std::collections::HashMap;
506523
/// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat};
507-
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
524+
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new().counting(true);
508525
/// # assert_eq!(state.counts, HashMap::new());
509526
/// # state.counts.insert("test2".to_string(), 10u64);
510527
/// #
@@ -521,9 +538,11 @@ where
521538
/// # assert_eq!(state.counts, hm);
522539
/// ```
523540
fn func_counts<O>(&mut self, problem: &Problem<O>) {
524-
for (k, &v) in problem.counts.iter() {
525-
let count = self.counts.entry(k.to_string()).or_insert(0);
526-
*count = v
541+
if self.counting_enabled {
542+
for (k, &v) in problem.counts.iter() {
543+
let count = self.counts.entry(k.to_string()).or_insert(0);
544+
*count = v
545+
}
527546
}
528547
}
529548

crates/argmin/src/core/state/populationstate.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ pub struct PopulationState<P, F> {
5959
pub max_iters: u64,
6060
/// Evaluation counts
6161
pub counts: HashMap<String, u64>,
62+
/// Update evaluation counts?
63+
pub counting_enabled: bool,
6264
/// Time required so far
6365
pub time: Option<instant::Duration>,
6466
/// Status of optimization execution
@@ -430,6 +432,20 @@ where
430432
pub fn take_population(&mut self) -> Option<Vec<P>> {
431433
self.population.take()
432434
}
435+
436+
/// Overrides state of counting function executions (default: false)
437+
/// ```
438+
/// # use argmin::core::{State, PopulationState};
439+
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
440+
/// # assert!(!state.counting_enabled);
441+
/// let state = state.counting(true);
442+
/// # assert!(state.counting_enabled);
443+
/// ```
444+
#[must_use]
445+
pub fn counting(mut self, mode: bool) -> Self {
446+
self.counting_enabled = mode;
447+
self
448+
}
433449
}
434450

435451
impl<P, F> State for PopulationState<P, F>
@@ -484,6 +500,7 @@ where
484500
last_best_iter: 0,
485501
max_iters: std::u64::MAX,
486502
counts: HashMap::new(),
503+
counting_enabled: false,
487504
time: Some(instant::Duration::new(0, 0)),
488505
termination_status: TerminationStatus::NotTerminated,
489506
}
@@ -783,7 +800,7 @@ where
783800
/// ```
784801
/// # use std::collections::HashMap;
785802
/// # use argmin::core::{Problem, PopulationState, State, ArgminFloat};
786-
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
803+
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new().counting(true);
787804
/// # assert_eq!(state.counts, HashMap::new());
788805
/// # state.counts.insert("test2".to_string(), 10u64);
789806
/// #
@@ -800,9 +817,11 @@ where
800817
/// # assert_eq!(state.counts, hm);
801818
/// ```
802819
fn func_counts<O>(&mut self, problem: &Problem<O>) {
803-
for (k, &v) in problem.counts.iter() {
804-
let count = self.counts.entry(k.to_string()).or_insert(0);
805-
*count = v
820+
if self.counting_enabled {
821+
for (k, &v) in problem.counts.iter() {
822+
let count = self.counts.entry(k.to_string()).or_insert(0);
823+
*count = v
824+
}
806825
}
807826
}
808827

crates/argmin/src/solver/brent/brentopt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ mod tests {
230230
let cost = TestFunc {};
231231
let solver = BrentOpt::new(-10., 10.);
232232
let res = Executor::new(cost, solver)
233-
.configure(|state| state.max_iters(13))
233+
.configure(|state| state.counting(true).max_iters(13))
234234
.run()
235235
.unwrap();
236236
assert_eq!(

crates/argmin/src/solver/linesearch/backtracking.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,12 @@ mod tests {
640640
ls.search_direction(vec![2.0f64, 0.0]);
641641

642642
let data = Executor::new(prob, ls.clone())
643-
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
643+
.configure(|config| {
644+
config
645+
.counting(true)
646+
.param(ls.init_param.clone().unwrap())
647+
.max_iters(10)
648+
})
644649
.run();
645650
assert!(data.is_ok());
646651

@@ -689,7 +694,12 @@ mod tests {
689694
ls.search_direction(vec![2.0f64, 0.0]);
690695

691696
let data = Executor::new(prob, ls.clone())
692-
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
697+
.configure(|config| {
698+
config
699+
.counting(true)
700+
.param(ls.init_param.clone().unwrap())
701+
.max_iters(10)
702+
})
693703
.run();
694704
assert!(data.is_ok());
695705

0 commit comments

Comments
 (0)