Skip to content

Commit 9821c93

Browse files
committed
Disable execution counting by default
1 parent d0022c9 commit 9821c93

File tree

7 files changed

+65
-15
lines changed

7 files changed

+65
-15
lines changed

crates/argmin/src/core/executor.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ where
313313
mode: ObserverMode,
314314
) -> Self {
315315
self.observers.push(observer, mode);
316+
let state = self.state.take().unwrap();
317+
self.state = Some(state.set_counting(true));
316318
self.timer = true;
317319
self
318320
}

crates/argmin/src/core/state/iterstate.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ pub struct IterState<P, G, J, H, R, F> {
8282
pub max_iters: u64,
8383
/// Evaluation counts
8484
pub counts: HashMap<String, u64>,
85+
/// Update evaluation counts?
86+
pub counting_enabled: bool,
8587
/// Time required so far
8688
pub time: Option<instant::Duration>,
8789
/// Status of optimization execution
@@ -1040,6 +1042,7 @@ where
10401042
last_best_iter: 0,
10411043
max_iters: std::u64::MAX,
10421044
counts: HashMap::new(),
1045+
counting_enabled: false,
10431046
time: Some(instant::Duration::new(0, 0)),
10441047
termination_status: TerminationStatus::NotTerminated,
10451048
}
@@ -1339,7 +1342,7 @@ where
13391342
/// ```
13401343
/// # use std::collections::HashMap;
13411344
/// # use argmin::core::{Problem, IterState, State, ArgminFloat};
1342-
/// # let mut state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
1345+
/// # let mut state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new().set_counting(true);
13431346
/// # assert_eq!(state.counts, HashMap::new());
13441347
/// # state.counts.insert("test2".to_string(), 10u64);
13451348
/// #
@@ -1356,9 +1359,11 @@ where
13561359
/// # assert_eq!(state.counts, hm);
13571360
/// ```
13581361
fn func_counts<O>(&mut self, problem: &Problem<O>) {
1359-
for (k, &v) in problem.counts.iter() {
1360-
let count = self.counts.entry(k.to_string()).or_insert(0);
1361-
*count = v
1362+
if self.counting_enabled {
1363+
for (k, &v) in problem.counts.iter() {
1364+
let count = self.counts.entry(k.to_string()).or_insert(0);
1365+
*count = v
1366+
}
13621367
}
13631368
}
13641369

@@ -1401,6 +1406,13 @@ where
14011406
fn is_best(&self) -> bool {
14021407
self.last_best_iter == self.iter
14031408
}
1409+
1410+
/// Overrides state of counting function executions, by default - only if needed
1411+
#[must_use]
1412+
fn set_counting(mut self, mode: bool) -> Self {
1413+
self.counting_enabled = mode;
1414+
self
1415+
}
14041416
}
14051417

14061418
#[cfg(test)]

crates/argmin/src/core/state/linearprogramstate.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ pub struct LinearProgramState<P, F> {
5757
pub max_iters: u64,
5858
/// Evaluation counts
5959
pub counts: HashMap<String, u64>,
60+
/// Update evaluation counts?
61+
pub counting_enabled: bool,
6062
/// Time required so far
6163
pub time: Option<instant::Duration>,
6264
/// Status of optimization execution
@@ -206,6 +208,7 @@ where
206208
last_best_iter: 0,
207209
max_iters: std::u64::MAX,
208210
counts: HashMap::new(),
211+
counting_enabled: false,
209212
time: Some(instant::Duration::new(0, 0)),
210213
termination_status: TerminationStatus::NotTerminated,
211214
}
@@ -504,7 +507,7 @@ where
504507
/// ```
505508
/// # use std::collections::HashMap;
506509
/// # use argmin::core::{Problem, LinearProgramState, State, ArgminFloat};
507-
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new();
510+
/// # let mut state: LinearProgramState<Vec<f64>, f64> = LinearProgramState::new().set_counting(true);
508511
/// # assert_eq!(state.counts, HashMap::new());
509512
/// # state.counts.insert("test2".to_string(), 10u64);
510513
/// #
@@ -521,9 +524,11 @@ where
521524
/// # assert_eq!(state.counts, hm);
522525
/// ```
523526
fn func_counts<O>(&mut self, problem: &Problem<O>) {
524-
for (k, &v) in problem.counts.iter() {
525-
let count = self.counts.entry(k.to_string()).or_insert(0);
526-
*count = v
527+
if self.counting_enabled {
528+
for (k, &v) in problem.counts.iter() {
529+
let count = self.counts.entry(k.to_string()).or_insert(0);
530+
*count = v
531+
}
527532
}
528533
}
529534

@@ -566,4 +571,11 @@ where
566571
fn is_best(&self) -> bool {
567572
self.last_best_iter == self.iter
568573
}
574+
575+
/// Overrides state of counting function executions, by default - only if needed
576+
#[must_use]
577+
fn set_counting(mut self, mode: bool) -> Self {
578+
self.counting_enabled = mode;
579+
self
580+
}
569581
}

crates/argmin/src/core/state/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ pub trait State {
9090
/// Returns current cost function evaluation count
9191
fn get_func_counts(&self) -> &HashMap<String, u64>;
9292

93+
/// Set function counting status
94+
fn set_counting(self, mode: bool) -> Self;
95+
9396
/// Set time required since the beginning of the optimization until the current iteration
9497
fn time(&mut self, time: Option<instant::Duration>) -> &mut Self;
9598

crates/argmin/src/core/state/populationstate.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ pub struct PopulationState<P, F> {
5959
pub max_iters: u64,
6060
/// Evaluation counts
6161
pub counts: HashMap<String, u64>,
62+
/// Update evaluation counts?
63+
pub counting_enabled: bool,
6264
/// Time required so far
6365
pub time: Option<instant::Duration>,
6466
/// Status of optimization execution
@@ -484,6 +486,7 @@ where
484486
last_best_iter: 0,
485487
max_iters: std::u64::MAX,
486488
counts: HashMap::new(),
489+
counting_enabled: false,
487490
time: Some(instant::Duration::new(0, 0)),
488491
termination_status: TerminationStatus::NotTerminated,
489492
}
@@ -783,7 +786,7 @@ where
783786
/// ```
784787
/// # use std::collections::HashMap;
785788
/// # use argmin::core::{Problem, PopulationState, State, ArgminFloat};
786-
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new();
789+
/// # let mut state: PopulationState<Vec<f64>, f64> = PopulationState::new().set_counting(true);
787790
/// # assert_eq!(state.counts, HashMap::new());
788791
/// # state.counts.insert("test2".to_string(), 10u64);
789792
/// #
@@ -800,9 +803,11 @@ where
800803
/// # assert_eq!(state.counts, hm);
801804
/// ```
802805
fn func_counts<O>(&mut self, problem: &Problem<O>) {
803-
for (k, &v) in problem.counts.iter() {
804-
let count = self.counts.entry(k.to_string()).or_insert(0);
805-
*count = v
806+
if self.counting_enabled {
807+
for (k, &v) in problem.counts.iter() {
808+
let count = self.counts.entry(k.to_string()).or_insert(0);
809+
*count = v
810+
}
806811
}
807812
}
808813

@@ -845,6 +850,12 @@ where
845850
fn is_best(&self) -> bool {
846851
self.last_best_iter == self.iter
847852
}
853+
854+
/// Overrides state of counting function executions, by default - only if needed
855+
fn set_counting(mut self, mode: bool) -> Self {
856+
self.counting_enabled = mode;
857+
self
858+
}
848859
}
849860

850861
// TODO: Tests? Actually doc tests should already cover everything.

crates/argmin/src/solver/brent/brentopt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ mod tests {
230230
let cost = TestFunc {};
231231
let solver = BrentOpt::new(-10., 10.);
232232
let res = Executor::new(cost, solver)
233-
.configure(|state| state.max_iters(13))
233+
.configure(|state| state.set_counting(true).max_iters(13))
234234
.run()
235235
.unwrap();
236236
assert_eq!(

crates/argmin/src/solver/linesearch/backtracking.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,12 @@ mod tests {
640640
ls.search_direction(vec![2.0f64, 0.0]);
641641

642642
let data = Executor::new(prob, ls.clone())
643-
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
643+
.configure(|config| {
644+
config
645+
.set_counting(true)
646+
.param(ls.init_param.clone().unwrap())
647+
.max_iters(10)
648+
})
644649
.run();
645650
assert!(data.is_ok());
646651

@@ -689,7 +694,12 @@ mod tests {
689694
ls.search_direction(vec![2.0f64, 0.0]);
690695

691696
let data = Executor::new(prob, ls.clone())
692-
.configure(|config| config.param(ls.init_param.clone().unwrap()).max_iters(10))
697+
.configure(|config| {
698+
config
699+
.set_counting(true)
700+
.param(ls.init_param.clone().unwrap())
701+
.max_iters(10)
702+
})
693703
.run();
694704
assert!(data.is_ok());
695705

0 commit comments

Comments
 (0)