Skip to content

Commit 1d4a3fb

Browse files
committed
Added error handling to finitediff
1 parent 3f97b77 commit 1d4a3fb

File tree

15 files changed

+628
-522
lines changed

15 files changed

+628
-522
lines changed

crates/finitediff/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ keywords = ["differentiation", "optimization", "math", "science"]
1313
categories = ["science"]
1414

1515
[dependencies]
16+
anyhow = "1.0"
1617
ndarray = { version = "0.15.0", optional = true }
1718
num = "0.4.1"

crates/finitediff/src/array/diff.rs

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,38 @@
55
// http://opensource.org/licenses/MIT>, at your option. This file may not be
66
// copied, modified, or distributed except according to those terms.
77

8+
use anyhow::Error;
89
use num::Float;
910
use num::FromPrimitive;
1011

1112
use crate::utils::mod_and_calc_const;
1213

13-
pub fn forward_diff_const<const N: usize, F>(x: &[F; N], f: &dyn Fn(&[F; N]) -> F) -> [F; N]
14+
pub fn forward_diff_const<const N: usize, F>(
15+
x: &[F; N],
16+
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
17+
) -> Result<[F; N], Error>
1418
where
1519
F: Float + FromPrimitive,
1620
{
17-
let fx = (f)(x);
21+
let fx = (f)(x)?;
1822
let mut xt = *x;
1923
let eps_sqrt = F::epsilon().sqrt();
2024
let mut out = [F::from_f64(0.0).unwrap(); N];
2125
out.iter_mut()
2226
.enumerate()
23-
.map(|(i, o)| {
24-
let fx1 = mod_and_calc_const(&mut xt, f, i, eps_sqrt);
27+
.map(|(i, o)| -> Result<_, Error> {
28+
let fx1 = mod_and_calc_const(&mut xt, f, i, eps_sqrt)?;
2529
*o = (fx1 - fx) / eps_sqrt;
30+
Ok(())
2631
})
2732
.count();
28-
out
33+
Ok(out)
2934
}
3035

31-
pub fn central_diff_const<const N: usize, F>(x: &[F; N], f: &dyn Fn(&[F; N]) -> F) -> [F; N]
36+
pub fn central_diff_const<const N: usize, F>(
37+
x: &[F; N],
38+
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
39+
) -> Result<[F; N], Error>
3240
where
3341
F: Float + FromPrimitive,
3442
{
@@ -37,13 +45,14 @@ where
3745
let mut out = [F::from_f64(0.0).unwrap(); N];
3846
out.iter_mut()
3947
.enumerate()
40-
.map(|(i, o)| {
41-
let fx1 = mod_and_calc_const(&mut xt, f, i, eps_cbrt);
42-
let fx2 = mod_and_calc_const(&mut xt, f, i, -eps_cbrt);
48+
.map(|(i, o)| -> Result<_, Error> {
49+
let fx1 = mod_and_calc_const(&mut xt, f, i, eps_cbrt)?;
50+
let fx2 = mod_and_calc_const(&mut xt, f, i, -eps_cbrt)?;
4351
*o = (fx1 - fx2) / (F::from_f64(2.0).unwrap() * eps_cbrt);
52+
Ok(())
4453
})
4554
.count();
46-
out
55+
Ok(out)
4756
}
4857

4958
#[cfg(test)]
@@ -52,26 +61,26 @@ mod tests {
5261

5362
const COMP_ACC: f64 = 1e-6;
5463

55-
fn f(x: &[f64; 2]) -> f64 {
56-
x[0] + x[1].powi(2)
64+
fn f(x: &[f64; 2]) -> Result<f64, Error> {
65+
Ok(x[0] + x[1].powi(2))
5766
}
5867

59-
fn f2(x: &[f64; 2]) -> f64 {
60-
x[0] + x[1].powi(2)
68+
fn f2(x: &[f64; 2]) -> Result<f64, Error> {
69+
Ok(x[0] + x[1].powi(2))
6170
}
6271

6372
#[test]
6473
fn test_forward_diff_const_f64() {
6574
let p = [1.0f64, 1.0f64];
66-
let grad = forward_diff_const(&p, &f2);
75+
let grad = forward_diff_const(&p, &f2).unwrap();
6776
let res = [1.0f64, 2.0];
6877

6978
(0..2)
7079
.map(|i| assert!((res[i] - grad[i]).abs() < COMP_ACC))
7180
.count();
7281

7382
let p = [1.0f64, 2.0f64];
74-
let grad = forward_diff_const(&p, &f2);
83+
let grad = forward_diff_const(&p, &f2).unwrap();
7584
let res = [1.0f64, 4.0];
7685

7786
(0..2)
@@ -82,15 +91,15 @@ mod tests {
8291
#[test]
8392
fn test_central_diff_vec_f64() {
8493
let p = [1.0f64, 1.0f64];
85-
let grad = central_diff_const(&p, &f);
94+
let grad = central_diff_const(&p, &f).unwrap();
8695
let res = [1.0f64, 2.0];
8796

8897
(0..2)
8998
.map(|i| assert!((res[i] - grad[i]).abs() < COMP_ACC))
9099
.count();
91100

92101
let p = [1.0f64, 2.0f64];
93-
let grad = central_diff_const(&p, &f);
102+
let grad = central_diff_const(&p, &f).unwrap();
94103
let res = [1.0f64, 4.0];
95104

96105
(0..2)

0 commit comments

Comments
 (0)