@@ -26,154 +26,149 @@ use jacobian::{
26
26
forward_jacobian_const, forward_jacobian_pert_const, forward_jacobian_vec_prod_const,
27
27
} ;
28
28
29
+ pub ( crate ) type CostFn < ' a , const N : usize , F > = & ' a dyn Fn ( & [ F ; N ] ) -> Result < F , Error > ;
30
+ pub ( crate ) type GradientFn < ' a , const N : usize , F > = & ' a dyn Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error > ;
31
+ pub ( crate ) type OpFn < ' a , const N : usize , const M : usize , F > =
32
+ & ' a dyn Fn ( & [ F ; N ] ) -> Result < [ F ; M ] , Error > ;
33
+
29
34
#[ inline( always) ]
30
- pub fn forward_diff < const N : usize , Func , F > ( f : Func ) -> impl Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error >
35
+ pub fn forward_diff < const N : usize , F > (
36
+ f : CostFn < ' _ , N , F > ,
37
+ ) -> impl Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error > + ' _
31
38
where
32
- Func : Fn ( & [ F ; N ] ) -> Result < F , Error > ,
33
39
F : Float + FromPrimitive ,
34
40
{
35
41
move |p : & [ F ; N ] | forward_diff_const ( p, & f)
36
42
}
37
43
38
44
#[ inline( always) ]
39
- pub fn central_diff < const N : usize , Func , F > ( f : Func ) -> impl Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error >
45
+ pub fn central_diff < const N : usize , F > (
46
+ f : CostFn < ' _ , N , F > ,
47
+ ) -> impl Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error > + ' _
40
48
where
41
- Func : Fn ( & [ F ; N ] ) -> Result < F , Error > ,
42
49
F : Float + FromPrimitive ,
43
50
{
44
51
move |p : & [ F ; N ] | central_diff_const ( p, & f)
45
52
}
46
53
47
54
#[ inline( always) ]
48
- pub fn forward_jacobian < const N : usize , const M : usize , Func , F > (
49
- f : Func ,
50
- ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; M ] , Error >
55
+ pub fn forward_jacobian < const N : usize , const M : usize , F > (
56
+ f : OpFn < ' _ , N , M , F > ,
57
+ ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; M ] , Error > + ' _
51
58
where
52
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; M ] , Error > ,
53
59
F : Float + FromPrimitive ,
54
60
{
55
61
move |p : & [ F ; N ] | forward_jacobian_const ( p, & f)
56
62
}
57
63
58
64
#[ inline( always) ]
59
- pub fn central_jacobian < const N : usize , const M : usize , Func , F > (
60
- f : Func ,
61
- ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; M ] , Error >
65
+ pub fn central_jacobian < const N : usize , const M : usize , F > (
66
+ f : OpFn < ' _ , N , M , F > ,
67
+ ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; M ] , Error > + ' _
62
68
where
63
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; M ] , Error > ,
64
69
F : Float + FromPrimitive ,
65
70
{
66
71
move |p : & [ F ; N ] | central_jacobian_const ( p, & f)
67
72
}
68
73
69
74
#[ inline( always) ]
70
- pub fn forward_jacobian_vec_prod < const N : usize , const M : usize , Func , F > (
71
- f : Func ,
72
- ) -> impl Fn ( & [ F ; N ] , & [ F ; N ] ) -> Result < [ F ; M ] , Error >
75
+ pub fn forward_jacobian_vec_prod < const N : usize , const M : usize , F > (
76
+ f : OpFn < ' _ , N , M , F > ,
77
+ ) -> impl Fn ( & [ F ; N ] , & [ F ; N ] ) -> Result < [ F ; M ] , Error > + ' _
73
78
where
74
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; M ] , Error > ,
75
79
F : Float + FromPrimitive ,
76
80
{
77
- move |p : & [ F ; N ] , v : & [ F ; N ] | forward_jacobian_vec_prod_const ( p, & f, v)
81
+ move |p : & [ F ; N ] , v : & [ F ; N ] | forward_jacobian_vec_prod_const ( p, f, v)
78
82
}
79
83
80
84
#[ inline( always) ]
81
- pub fn central_jacobian_vec_prod < const N : usize , const M : usize , Func , F > (
82
- f : Func ,
83
- ) -> impl Fn ( & [ F ; N ] , & [ F ; N ] ) -> Result < [ F ; M ] , Error >
85
+ pub fn central_jacobian_vec_prod < const N : usize , const M : usize , F > (
86
+ f : OpFn < ' _ , N , M , F > ,
87
+ ) -> impl Fn ( & [ F ; N ] , & [ F ; N ] ) -> Result < [ F ; M ] , Error > + ' _
84
88
where
85
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; M ] , Error > ,
86
89
F : Float + FromPrimitive ,
87
90
{
88
- move |p : & [ F ; N ] , v : & [ F ; N ] | central_jacobian_vec_prod_const ( p, & f, v)
91
+ move |p : & [ F ; N ] , v : & [ F ; N ] | central_jacobian_vec_prod_const ( p, f, v)
89
92
}
90
93
91
94
#[ inline( always) ]
92
- pub fn forward_jacobian_pert < const N : usize , const M : usize , Func , F > (
93
- f : Func ,
94
- ) -> impl Fn ( & [ F ; N ] , & PerturbationVectors ) -> Result < [ [ F ; N ] ; M ] , Error >
95
+ pub fn forward_jacobian_pert < const N : usize , const M : usize , F > (
96
+ f : OpFn < ' _ , N , M , F > ,
97
+ ) -> impl Fn ( & [ F ; N ] , & PerturbationVectors ) -> Result < [ [ F ; N ] ; M ] , Error > + ' _
95
98
where
96
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; M ] , Error > ,
97
99
F : Float + FromPrimitive + AddAssign ,
98
100
{
99
- move |p : & [ F ; N ] , pert : & PerturbationVectors | forward_jacobian_pert_const ( p, & f, pert)
101
+ move |p : & [ F ; N ] , pert : & PerturbationVectors | forward_jacobian_pert_const ( p, f, pert)
100
102
}
101
103
102
104
#[ inline( always) ]
103
- pub fn central_jacobian_pert < const N : usize , const M : usize , Func , F > (
104
- f : Func ,
105
- ) -> impl Fn ( & [ F ; N ] , & PerturbationVectors ) -> Result < [ [ F ; N ] ; M ] , Error >
105
+ pub fn central_jacobian_pert < const N : usize , const M : usize , F > (
106
+ f : OpFn < ' _ , N , M , F > ,
107
+ ) -> impl Fn ( & [ F ; N ] , & PerturbationVectors ) -> Result < [ [ F ; N ] ; M ] , Error > + ' _
106
108
where
107
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; M ] , Error > ,
108
109
F : Float + FromPrimitive + AddAssign ,
109
110
{
110
- move |p : & [ F ; N ] , pert : & PerturbationVectors | central_jacobian_pert_const ( p, & f, pert)
111
+ move |p : & [ F ; N ] , pert : & PerturbationVectors | central_jacobian_pert_const ( p, f, pert)
111
112
}
112
113
113
114
#[ inline( always) ]
114
- pub fn forward_hessian < const N : usize , Func , F > (
115
- f : Func ,
116
- ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; N ] , Error >
115
+ pub fn forward_hessian < const N : usize , F > (
116
+ f : GradientFn < ' _ , N , F > ,
117
+ ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; N ] , Error > + ' _
117
118
where
118
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error > ,
119
119
F : Float + FromPrimitive ,
120
120
{
121
- move |p : & [ F ; N ] | forward_hessian_const ( p, & f)
121
+ move |p : & [ F ; N ] | forward_hessian_const ( p, f)
122
122
}
123
123
124
124
#[ inline( always) ]
125
- pub fn central_hessian < const N : usize , Func , F > (
126
- f : Func ,
127
- ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; N ] , Error >
125
+ pub fn central_hessian < const N : usize , F > (
126
+ f : GradientFn < ' _ , N , F > ,
127
+ ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; N ] , Error > + ' _
128
128
where
129
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error > ,
130
129
F : Float + FromPrimitive ,
131
130
{
132
- move |p : & [ F ; N ] | central_hessian_const ( p, & f)
131
+ move |p : & [ F ; N ] | central_hessian_const ( p, f)
133
132
}
134
133
135
134
#[ inline( always) ]
136
- pub fn forward_hessian_vec_prod < const N : usize , Func , F > (
137
- f : Func ,
138
- ) -> impl Fn ( & [ F ; N ] , & [ F ; N ] ) -> Result < [ F ; N ] , Error >
135
+ pub fn forward_hessian_vec_prod < const N : usize , F > (
136
+ f : GradientFn < ' _ , N , F > ,
137
+ ) -> impl Fn ( & [ F ; N ] , & [ F ; N ] ) -> Result < [ F ; N ] , Error > + ' _
139
138
where
140
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error > ,
141
139
F : Float + FromPrimitive ,
142
140
{
143
- move |p : & [ F ; N ] , v : & [ F ; N ] | forward_hessian_vec_prod_const ( p, & f, v)
141
+ move |p : & [ F ; N ] , v : & [ F ; N ] | forward_hessian_vec_prod_const ( p, f, v)
144
142
}
145
143
146
144
#[ inline( always) ]
147
- pub fn central_hessian_vec_prod < const N : usize , Func , F > (
148
- f : Func ,
149
- ) -> impl Fn ( & [ F ; N ] , & [ F ; N ] ) -> Result < [ F ; N ] , Error >
145
+ pub fn central_hessian_vec_prod < const N : usize , F > (
146
+ f : GradientFn < ' _ , N , F > ,
147
+ ) -> impl Fn ( & [ F ; N ] , & [ F ; N ] ) -> Result < [ F ; N ] , Error > + ' _
150
148
where
151
- Func : Fn ( & [ F ; N ] ) -> Result < [ F ; N ] , Error > ,
152
149
F : Float + FromPrimitive ,
153
150
{
154
- move |p : & [ F ; N ] , v : & [ F ; N ] | central_hessian_vec_prod_const ( p, & f, v)
151
+ move |p : & [ F ; N ] , v : & [ F ; N ] | central_hessian_vec_prod_const ( p, f, v)
155
152
}
156
153
157
154
#[ inline( always) ]
158
- pub fn forward_hessian_nograd < const N : usize , Func , F > (
159
- f : Func ,
160
- ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; N ] , Error >
155
+ pub fn forward_hessian_nograd < const N : usize , F > (
156
+ f : CostFn < ' _ , N , F > ,
157
+ ) -> impl Fn ( & [ F ; N ] ) -> Result < [ [ F ; N ] ; N ] , Error > + ' _
161
158
where
162
- Func : Fn ( & [ F ; N ] ) -> Result < F , Error > ,
163
159
F : Float + FromPrimitive + AddAssign ,
164
160
{
165
- move |p : & [ F ; N ] | forward_hessian_nograd_const ( p, & f)
161
+ move |p : & [ F ; N ] | forward_hessian_nograd_const ( p, f)
166
162
}
167
163
168
164
#[ inline( always) ]
169
- pub fn forward_hessian_nograd_sparse < const N : usize , Func , F > (
170
- f : Func ,
171
- ) -> impl Fn ( & [ F ; N ] , Vec < [ usize ; 2 ] > ) -> Result < [ [ F ; N ] ; N ] , Error >
165
+ pub fn forward_hessian_nograd_sparse < const N : usize , F > (
166
+ f : CostFn < ' _ , N , F > ,
167
+ ) -> impl Fn ( & [ F ; N ] , Vec < [ usize ; 2 ] > ) -> Result < [ [ F ; N ] ; N ] , Error > + ' _
172
168
where
173
- Func : Fn ( & [ F ; N ] ) -> Result < F , Error > ,
174
169
F : Float + FromPrimitive + AddAssign ,
175
170
{
176
- move |p : & [ F ; N ] , indices : Vec < [ usize ; 2 ] > | forward_hessian_nograd_sparse_const ( p, & f, indices)
171
+ move |p : & [ F ; N ] , indices : Vec < [ usize ; 2 ] > | forward_hessian_nograd_sparse_const ( p, f, indices)
177
172
}
178
173
179
174
#[ cfg( test) ]
@@ -267,7 +262,7 @@ mod tests {
267
262
268
263
#[ test]
269
264
fn test_forward_diff_func ( ) {
270
- let grad = forward_diff ( f1) ;
265
+ let grad = forward_diff ( & f1) ;
271
266
let out = grad ( & x1 ( ) ) . unwrap ( ) ;
272
267
let res = [ 1.0 , 2.0 ] ;
273
268
@@ -287,7 +282,7 @@ mod tests {
287
282
288
283
#[ test]
289
284
fn test_central_diff_func ( ) {
290
- let grad = central_diff ( f1) ;
285
+ let grad = central_diff ( & f1) ;
291
286
let out = grad ( & x1 ( ) ) . unwrap ( ) ;
292
287
let res = [ 1.0f64 , 2.0 ] ;
293
288
@@ -296,7 +291,7 @@ mod tests {
296
291
}
297
292
298
293
let p = [ 1.0f64 , 2.0f64 ] ;
299
- let grad = central_diff ( f1) ;
294
+ let grad = central_diff ( & f1) ;
300
295
let out = grad ( & p) . unwrap ( ) ;
301
296
let res = [ 1.0f64 , 4.0 ] ;
302
297
0 commit comments