Skip to content

Commit 35e00e2

Browse files
committed
finitediff: Less complex types
1 parent 1d4a3fb commit 35e00e2

File tree

12 files changed

+236
-250
lines changed

12 files changed

+236
-250
lines changed

crates/finitediff/src/array/diff.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ use num::FromPrimitive;
1111

1212
use crate::utils::mod_and_calc_const;
1313

14+
use super::CostFn;
15+
1416
pub fn forward_diff_const<const N: usize, F>(
1517
x: &[F; N],
16-
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
18+
f: CostFn<'_, N, F>,
1719
) -> Result<[F; N], Error>
1820
where
1921
F: Float + FromPrimitive,
@@ -35,7 +37,7 @@ where
3537

3638
pub fn central_diff_const<const N: usize, F>(
3739
x: &[F; N],
38-
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
40+
f: CostFn<'_, N, F>,
3941
) -> Result<[F; N], Error>
4042
where
4143
F: Float + FromPrimitive,

crates/finitediff/src/array/hessian.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ use num::{Float, FromPrimitive};
1212

1313
use crate::utils::{mod_and_calc, restore_symmetry_const, KV};
1414

15+
use super::{CostFn, GradientFn};
16+
1517
pub fn forward_hessian_const<const N: usize, F>(
1618
x: &[F; N],
17-
grad: &dyn Fn(&[F; N]) -> Result<[F; N], Error>,
19+
grad: GradientFn<'_, N, F>,
1820
) -> Result<[[F; N]; N], Error>
1921
where
2022
F: Float + FromPrimitive,
@@ -36,7 +38,7 @@ where
3638

3739
pub fn central_hessian_const<const N: usize, F>(
3840
x: &[F; N],
39-
grad: &dyn Fn(&[F; N]) -> Result<[F; N], Error>,
41+
grad: GradientFn<'_, N, F>,
4042
) -> Result<[[F; N]; N], Error>
4143
where
4244
F: Float + FromPrimitive,
@@ -59,7 +61,7 @@ where
5961

6062
pub fn forward_hessian_vec_prod_const<const N: usize, F>(
6163
x: &[F; N],
62-
grad: &dyn Fn(&[F; N]) -> Result<[F; N], Error>,
64+
grad: GradientFn<'_, N, F>,
6365
p: &[F; N],
6466
) -> Result<[F; N], Error>
6567
where
@@ -83,7 +85,7 @@ where
8385

8486
pub fn central_hessian_vec_prod_const<const N: usize, F>(
8587
x: &[F; N],
86-
grad: &dyn Fn(&[F; N]) -> Result<[F; N], Error>,
88+
grad: GradientFn<'_, N, F>,
8789
p: &[F; N],
8890
) -> Result<[F; N], Error>
8991
where
@@ -108,7 +110,7 @@ where
108110

109111
pub fn forward_hessian_nograd_const<const N: usize, F>(
110112
x: &[F; N],
111-
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
113+
f: CostFn<'_, N, F>,
112114
) -> Result<[[F; N]; N], Error>
113115
where
114116
F: Float + FromPrimitive + AddAssign,
@@ -149,7 +151,7 @@ where
149151

150152
pub fn forward_hessian_nograd_sparse_const<const N: usize, F>(
151153
x: &[F; N],
152-
f: &dyn Fn(&[F; N]) -> Result<F, Error>,
154+
f: CostFn<'_, N, F>,
153155
indices: Vec<[usize; 2]>,
154156
) -> Result<[[F; N]; N], Error>
155157
where

crates/finitediff/src/array/jacobian.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ use num::{Float, FromPrimitive};
1313
use crate::pert::PerturbationVectors;
1414
use crate::utils::{mod_and_calc, mod_and_calc_const};
1515

16+
use super::OpFn;
17+
1618
pub fn forward_jacobian_const<const N: usize, const M: usize, F>(
1719
x: &[F; N],
18-
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
20+
fs: OpFn<'_, N, M, F>,
1921
) -> Result<[[F; N]; M], Error>
2022
where
2123
F: Float + FromPrimitive,
@@ -37,7 +39,7 @@ where
3739

3840
pub fn central_jacobian_const<const N: usize, const M: usize, F>(
3941
x: &[F; N],
40-
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
42+
fs: OpFn<'_, N, M, F>,
4143
) -> Result<[[F; N]; M], Error>
4244
where
4345
F: Float + FromPrimitive,
@@ -58,7 +60,7 @@ where
5860

5961
pub fn forward_jacobian_vec_prod_const<const N: usize, const M: usize, F>(
6062
x: &[F; N],
61-
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
63+
fs: OpFn<'_, N, M, F>,
6264
p: &[F; N],
6365
) -> Result<[F; M], Error>
6466
where
@@ -85,7 +87,7 @@ where
8587

8688
pub fn central_jacobian_vec_prod_const<const N: usize, const M: usize, F>(
8789
x: &[F; N],
88-
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
90+
fs: OpFn<'_, N, M, F>,
8991
p: &[F; N],
9092
) -> Result<[F; M], Error>
9193
where
@@ -117,7 +119,7 @@ where
117119

118120
pub fn forward_jacobian_pert_const<const N: usize, const M: usize, F>(
119121
x: &[F; N],
120-
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
122+
fs: OpFn<'_, N, M, F>,
121123
pert: &PerturbationVectors,
122124
) -> Result<[[F; N]; M], Error>
123125
where
@@ -149,7 +151,7 @@ where
149151

150152
pub fn central_jacobian_pert_const<const N: usize, const M: usize, F>(
151153
x: &[F; N],
152-
fs: &dyn Fn(&[F; N]) -> Result<[F; M], Error>,
154+
fs: OpFn<'_, N, M, F>,
153155
pert: &PerturbationVectors,
154156
) -> Result<[[F; N]; M], Error>
155157
where

crates/finitediff/src/array/mod.rs

Lines changed: 60 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -26,154 +26,149 @@ use jacobian::{
2626
forward_jacobian_const, forward_jacobian_pert_const, forward_jacobian_vec_prod_const,
2727
};
2828

29+
pub(crate) type CostFn<'a, const N: usize, F> = &'a dyn Fn(&[F; N]) -> Result<F, Error>;
30+
pub(crate) type GradientFn<'a, const N: usize, F> = &'a dyn Fn(&[F; N]) -> Result<[F; N], Error>;
31+
pub(crate) type OpFn<'a, const N: usize, const M: usize, F> =
32+
&'a dyn Fn(&[F; N]) -> Result<[F; M], Error>;
33+
2934
#[inline(always)]
30-
pub fn forward_diff<const N: usize, Func, F>(f: Func) -> impl Fn(&[F; N]) -> Result<[F; N], Error>
35+
pub fn forward_diff<const N: usize, F>(
36+
f: CostFn<'_, N, F>,
37+
) -> impl Fn(&[F; N]) -> Result<[F; N], Error> + '_
3138
where
32-
Func: Fn(&[F; N]) -> Result<F, Error>,
3339
F: Float + FromPrimitive,
3440
{
3541
move |p: &[F; N]| forward_diff_const(p, &f)
3642
}
3743

3844
#[inline(always)]
39-
pub fn central_diff<const N: usize, Func, F>(f: Func) -> impl Fn(&[F; N]) -> Result<[F; N], Error>
45+
pub fn central_diff<const N: usize, F>(
46+
f: CostFn<'_, N, F>,
47+
) -> impl Fn(&[F; N]) -> Result<[F; N], Error> + '_
4048
where
41-
Func: Fn(&[F; N]) -> Result<F, Error>,
4249
F: Float + FromPrimitive,
4350
{
4451
move |p: &[F; N]| central_diff_const(p, &f)
4552
}
4653

4754
#[inline(always)]
48-
pub fn forward_jacobian<const N: usize, const M: usize, Func, F>(
49-
f: Func,
50-
) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error>
55+
pub fn forward_jacobian<const N: usize, const M: usize, F>(
56+
f: OpFn<'_, N, M, F>,
57+
) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error> + '_
5158
where
52-
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
5359
F: Float + FromPrimitive,
5460
{
5561
move |p: &[F; N]| forward_jacobian_const(p, &f)
5662
}
5763

5864
#[inline(always)]
59-
pub fn central_jacobian<const N: usize, const M: usize, Func, F>(
60-
f: Func,
61-
) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error>
65+
pub fn central_jacobian<const N: usize, const M: usize, F>(
66+
f: OpFn<'_, N, M, F>,
67+
) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error> + '_
6268
where
63-
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
6469
F: Float + FromPrimitive,
6570
{
6671
move |p: &[F; N]| central_jacobian_const(p, &f)
6772
}
6873

6974
#[inline(always)]
70-
pub fn forward_jacobian_vec_prod<const N: usize, const M: usize, Func, F>(
71-
f: Func,
72-
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error>
75+
pub fn forward_jacobian_vec_prod<const N: usize, const M: usize, F>(
76+
f: OpFn<'_, N, M, F>,
77+
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error> + '_
7378
where
74-
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
7579
F: Float + FromPrimitive,
7680
{
77-
move |p: &[F; N], v: &[F; N]| forward_jacobian_vec_prod_const(p, &f, v)
81+
move |p: &[F; N], v: &[F; N]| forward_jacobian_vec_prod_const(p, f, v)
7882
}
7983

8084
#[inline(always)]
81-
pub fn central_jacobian_vec_prod<const N: usize, const M: usize, Func, F>(
82-
f: Func,
83-
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error>
85+
pub fn central_jacobian_vec_prod<const N: usize, const M: usize, F>(
86+
f: OpFn<'_, N, M, F>,
87+
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error> + '_
8488
where
85-
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
8689
F: Float + FromPrimitive,
8790
{
88-
move |p: &[F; N], v: &[F; N]| central_jacobian_vec_prod_const(p, &f, v)
91+
move |p: &[F; N], v: &[F; N]| central_jacobian_vec_prod_const(p, f, v)
8992
}
9093

9194
#[inline(always)]
92-
pub fn forward_jacobian_pert<const N: usize, const M: usize, Func, F>(
93-
f: Func,
94-
) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error>
95+
pub fn forward_jacobian_pert<const N: usize, const M: usize, F>(
96+
f: OpFn<'_, N, M, F>,
97+
) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error> + '_
9598
where
96-
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
9799
F: Float + FromPrimitive + AddAssign,
98100
{
99-
move |p: &[F; N], pert: &PerturbationVectors| forward_jacobian_pert_const(p, &f, pert)
101+
move |p: &[F; N], pert: &PerturbationVectors| forward_jacobian_pert_const(p, f, pert)
100102
}
101103

102104
#[inline(always)]
103-
pub fn central_jacobian_pert<const N: usize, const M: usize, Func, F>(
104-
f: Func,
105-
) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error>
105+
pub fn central_jacobian_pert<const N: usize, const M: usize, F>(
106+
f: OpFn<'_, N, M, F>,
107+
) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error> + '_
106108
where
107-
Func: Fn(&[F; N]) -> Result<[F; M], Error>,
108109
F: Float + FromPrimitive + AddAssign,
109110
{
110-
move |p: &[F; N], pert: &PerturbationVectors| central_jacobian_pert_const(p, &f, pert)
111+
move |p: &[F; N], pert: &PerturbationVectors| central_jacobian_pert_const(p, f, pert)
111112
}
112113

113114
#[inline(always)]
114-
pub fn forward_hessian<const N: usize, Func, F>(
115-
f: Func,
116-
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error>
115+
pub fn forward_hessian<const N: usize, F>(
116+
f: GradientFn<'_, N, F>,
117+
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
117118
where
118-
Func: Fn(&[F; N]) -> Result<[F; N], Error>,
119119
F: Float + FromPrimitive,
120120
{
121-
move |p: &[F; N]| forward_hessian_const(p, &f)
121+
move |p: &[F; N]| forward_hessian_const(p, f)
122122
}
123123

124124
#[inline(always)]
125-
pub fn central_hessian<const N: usize, Func, F>(
126-
f: Func,
127-
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error>
125+
pub fn central_hessian<const N: usize, F>(
126+
f: GradientFn<'_, N, F>,
127+
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
128128
where
129-
Func: Fn(&[F; N]) -> Result<[F; N], Error>,
130129
F: Float + FromPrimitive,
131130
{
132-
move |p: &[F; N]| central_hessian_const(p, &f)
131+
move |p: &[F; N]| central_hessian_const(p, f)
133132
}
134133

135134
#[inline(always)]
136-
pub fn forward_hessian_vec_prod<const N: usize, Func, F>(
137-
f: Func,
138-
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error>
135+
pub fn forward_hessian_vec_prod<const N: usize, F>(
136+
f: GradientFn<'_, N, F>,
137+
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error> + '_
139138
where
140-
Func: Fn(&[F; N]) -> Result<[F; N], Error>,
141139
F: Float + FromPrimitive,
142140
{
143-
move |p: &[F; N], v: &[F; N]| forward_hessian_vec_prod_const(p, &f, v)
141+
move |p: &[F; N], v: &[F; N]| forward_hessian_vec_prod_const(p, f, v)
144142
}
145143

146144
#[inline(always)]
147-
pub fn central_hessian_vec_prod<const N: usize, Func, F>(
148-
f: Func,
149-
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error>
145+
pub fn central_hessian_vec_prod<const N: usize, F>(
146+
f: GradientFn<'_, N, F>,
147+
) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error> + '_
150148
where
151-
Func: Fn(&[F; N]) -> Result<[F; N], Error>,
152149
F: Float + FromPrimitive,
153150
{
154-
move |p: &[F; N], v: &[F; N]| central_hessian_vec_prod_const(p, &f, v)
151+
move |p: &[F; N], v: &[F; N]| central_hessian_vec_prod_const(p, f, v)
155152
}
156153

157154
#[inline(always)]
158-
pub fn forward_hessian_nograd<const N: usize, Func, F>(
159-
f: Func,
160-
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error>
155+
pub fn forward_hessian_nograd<const N: usize, F>(
156+
f: CostFn<'_, N, F>,
157+
) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
161158
where
162-
Func: Fn(&[F; N]) -> Result<F, Error>,
163159
F: Float + FromPrimitive + AddAssign,
164160
{
165-
move |p: &[F; N]| forward_hessian_nograd_const(p, &f)
161+
move |p: &[F; N]| forward_hessian_nograd_const(p, f)
166162
}
167163

168164
#[inline(always)]
169-
pub fn forward_hessian_nograd_sparse<const N: usize, Func, F>(
170-
f: Func,
171-
) -> impl Fn(&[F; N], Vec<[usize; 2]>) -> Result<[[F; N]; N], Error>
165+
pub fn forward_hessian_nograd_sparse<const N: usize, F>(
166+
f: CostFn<'_, N, F>,
167+
) -> impl Fn(&[F; N], Vec<[usize; 2]>) -> Result<[[F; N]; N], Error> + '_
172168
where
173-
Func: Fn(&[F; N]) -> Result<F, Error>,
174169
F: Float + FromPrimitive + AddAssign,
175170
{
176-
move |p: &[F; N], indices: Vec<[usize; 2]>| forward_hessian_nograd_sparse_const(p, &f, indices)
171+
move |p: &[F; N], indices: Vec<[usize; 2]>| forward_hessian_nograd_sparse_const(p, f, indices)
177172
}
178173

179174
#[cfg(test)]
@@ -267,7 +262,7 @@ mod tests {
267262

268263
#[test]
269264
fn test_forward_diff_func() {
270-
let grad = forward_diff(f1);
265+
let grad = forward_diff(&f1);
271266
let out = grad(&x1()).unwrap();
272267
let res = [1.0, 2.0];
273268

@@ -287,7 +282,7 @@ mod tests {
287282

288283
#[test]
289284
fn test_central_diff_func() {
290-
let grad = central_diff(f1);
285+
let grad = central_diff(&f1);
291286
let out = grad(&x1()).unwrap();
292287
let res = [1.0f64, 2.0];
293288

@@ -296,7 +291,7 @@ mod tests {
296291
}
297292

298293
let p = [1.0f64, 2.0f64];
299-
let grad = central_diff(f1);
294+
let grad = central_diff(&f1);
300295
let out = grad(&p).unwrap();
301296
let res = [1.0f64, 4.0];
302297

crates/finitediff/src/ndarr/diff.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ use num::{Float, FromPrimitive};
1010

1111
use crate::utils::*;
1212

13+
use super::CostFn;
14+
1315
pub fn forward_diff_ndarray<F>(
1416
x: &ndarray::Array1<F>,
15-
f: &dyn Fn(&ndarray::Array1<F>) -> Result<F, Error>,
17+
f: CostFn<'_, F>,
1618
) -> Result<ndarray::Array1<F>, Error>
1719
where
1820
F: Float,
@@ -31,7 +33,7 @@ where
3133

3234
pub fn central_diff_ndarray<F>(
3335
x: &ndarray::Array1<F>,
34-
f: &dyn Fn(&ndarray::Array1<F>) -> Result<F, Error>,
36+
f: CostFn<'_, F>,
3537
) -> Result<ndarray::Array1<F>, Error>
3638
where
3739
F: Float + FromPrimitive,

0 commit comments

Comments
 (0)