5
5
// http://opensource.org/licenses/MIT>, at your option. This file may not be
6
6
// copied, modified, or distributed except according to those terms.
7
7
8
+ use anyhow:: Error ;
8
9
use num:: Float ;
9
10
use num:: FromPrimitive ;
10
11
11
12
use crate :: utils:: mod_and_calc_const;
12
13
13
- pub fn forward_diff_const < const N : usize , F > ( x : & [ F ; N ] , f : & dyn Fn ( & [ F ; N ] ) -> F ) -> [ F ; N ]
14
+ pub fn forward_diff_const < const N : usize , F > (
15
+ x : & [ F ; N ] ,
16
+ f : & dyn Fn ( & [ F ; N ] ) -> Result < F , Error > ,
17
+ ) -> Result < [ F ; N ] , Error >
14
18
where
15
19
F : Float + FromPrimitive ,
16
20
{
17
- let fx = ( f) ( x) ;
21
+ let fx = ( f) ( x) ? ;
18
22
let mut xt = * x;
19
23
let eps_sqrt = F :: epsilon ( ) . sqrt ( ) ;
20
24
let mut out = [ F :: from_f64 ( 0.0 ) . unwrap ( ) ; N ] ;
21
25
out. iter_mut ( )
22
26
. enumerate ( )
23
- . map ( |( i, o) | {
24
- let fx1 = mod_and_calc_const ( & mut xt, f, i, eps_sqrt) ;
27
+ . map ( |( i, o) | -> Result < _ , Error > {
28
+ let fx1 = mod_and_calc_const ( & mut xt, f, i, eps_sqrt) ? ;
25
29
* o = ( fx1 - fx) / eps_sqrt;
30
+ Ok ( ( ) )
26
31
} )
27
32
. count ( ) ;
28
- out
33
+ Ok ( out)
29
34
}
30
35
31
- pub fn central_diff_const < const N : usize , F > ( x : & [ F ; N ] , f : & dyn Fn ( & [ F ; N ] ) -> F ) -> [ F ; N ]
36
+ pub fn central_diff_const < const N : usize , F > (
37
+ x : & [ F ; N ] ,
38
+ f : & dyn Fn ( & [ F ; N ] ) -> Result < F , Error > ,
39
+ ) -> Result < [ F ; N ] , Error >
32
40
where
33
41
F : Float + FromPrimitive ,
34
42
{
@@ -37,13 +45,14 @@ where
37
45
let mut out = [ F :: from_f64 ( 0.0 ) . unwrap ( ) ; N ] ;
38
46
out. iter_mut ( )
39
47
. enumerate ( )
40
- . map ( |( i, o) | {
41
- let fx1 = mod_and_calc_const ( & mut xt, f, i, eps_cbrt) ;
42
- let fx2 = mod_and_calc_const ( & mut xt, f, i, -eps_cbrt) ;
48
+ . map ( |( i, o) | -> Result < _ , Error > {
49
+ let fx1 = mod_and_calc_const ( & mut xt, f, i, eps_cbrt) ? ;
50
+ let fx2 = mod_and_calc_const ( & mut xt, f, i, -eps_cbrt) ? ;
43
51
* o = ( fx1 - fx2) / ( F :: from_f64 ( 2.0 ) . unwrap ( ) * eps_cbrt) ;
52
+ Ok ( ( ) )
44
53
} )
45
54
. count ( ) ;
46
- out
55
+ Ok ( out)
47
56
}
48
57
49
58
#[ cfg( test) ]
@@ -52,26 +61,26 @@ mod tests {
52
61
53
62
const COMP_ACC : f64 = 1e-6 ;
54
63
55
- fn f ( x : & [ f64 ; 2 ] ) -> f64 {
56
- x[ 0 ] + x[ 1 ] . powi ( 2 )
64
+ fn f ( x : & [ f64 ; 2 ] ) -> Result < f64 , Error > {
65
+ Ok ( x[ 0 ] + x[ 1 ] . powi ( 2 ) )
57
66
}
58
67
59
- fn f2 ( x : & [ f64 ; 2 ] ) -> f64 {
60
- x[ 0 ] + x[ 1 ] . powi ( 2 )
68
+ fn f2 ( x : & [ f64 ; 2 ] ) -> Result < f64 , Error > {
69
+ Ok ( x[ 0 ] + x[ 1 ] . powi ( 2 ) )
61
70
}
62
71
63
72
#[ test]
64
73
fn test_forward_diff_const_f64 ( ) {
65
74
let p = [ 1.0f64 , 1.0f64 ] ;
66
- let grad = forward_diff_const ( & p, & f2) ;
75
+ let grad = forward_diff_const ( & p, & f2) . unwrap ( ) ;
67
76
let res = [ 1.0f64 , 2.0 ] ;
68
77
69
78
( 0 ..2 )
70
79
. map ( |i| assert ! ( ( res[ i] - grad[ i] ) . abs( ) < COMP_ACC ) )
71
80
. count ( ) ;
72
81
73
82
let p = [ 1.0f64 , 2.0f64 ] ;
74
- let grad = forward_diff_const ( & p, & f2) ;
83
+ let grad = forward_diff_const ( & p, & f2) . unwrap ( ) ;
75
84
let res = [ 1.0f64 , 4.0 ] ;
76
85
77
86
( 0 ..2 )
@@ -82,15 +91,15 @@ mod tests {
82
91
#[ test]
83
92
fn test_central_diff_vec_f64 ( ) {
84
93
let p = [ 1.0f64 , 1.0f64 ] ;
85
- let grad = central_diff_const ( & p, & f) ;
94
+ let grad = central_diff_const ( & p, & f) . unwrap ( ) ;
86
95
let res = [ 1.0f64 , 2.0 ] ;
87
96
88
97
( 0 ..2 )
89
98
. map ( |i| assert ! ( ( res[ i] - grad[ i] ) . abs( ) < COMP_ACC ) )
90
99
. count ( ) ;
91
100
92
101
let p = [ 1.0f64 , 2.0f64 ] ;
93
- let grad = central_diff_const ( & p, & f) ;
102
+ let grad = central_diff_const ( & p, & f) . unwrap ( ) ;
94
103
let res = [ 1.0f64 , 4.0 ] ;
95
104
96
105
( 0 ..2 )
0 commit comments