Skip to content

Commit 5c5212f

Browse files
authored
Improving Spectral (#53)
* spitballing * my interval was way wrong * updated spectral-derivatives such that some of the checks I had in this PR are no longer necessary * added some road signs * updated to spectral-derivatives 0.7 * this test code encourages not-thorough testing * fixed caching test * added tests for chebyshev case and removed caching check for spectral because evidently unnecessary * added Chebyshev examples * removed print statement * bumped version number * added axis param to tests so they pass now * improved docstring to account for the fact filtering in chebyshev basis is a little more subtle and bumped spectral-derivatives version number to get latest features * added back a newline
1 parent 7aa93ce commit 5c5212f

File tree

6 files changed

+147
-121
lines changed

6 files changed

+147
-121
lines changed

derivative/dglobal.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,50 @@
77
from scipy import interpolate, sparse
88
from scipy.special import legendre
99
from sklearn.linear_model import Lasso
10+
from specderiv import cheb_deriv, fourier_deriv
1011

1112

1213
@register("spectral")
1314
class Spectral(Derivative):
14-
def __init__(self, **kwargs):
15+
def __init__(self, order=1, axis=0, basis='fourier', filter=None):
1516
"""
16-
Compute the numerical derivative by first computing the FFT. In Fourier space, derivatives are multiplication
17-
by i*phase; compute the IFFT after.
18-
19-
Args:
20-
**kwargs: Optional keyword arguments.
21-
17+
Compute the numerical derivative by spectral methods. In Fourier space, derivatives are multiplication
18+
by i*phase; compute the inverse transform after. Use either Fourier modes of Chebyshev polynomials as
19+
the basis.
20+
2221
Keyword Args:
23-
filter: Optional. A function that takes in frequencies and outputs weights to scale the coefficient at
24-
the input frequency in Fourier space. Input frequencies are the discrete fourier transform sample
25-
frequencies associated with the domain variable. Look into python signal processing resources in
26-
scipy.signal for common filters.
27-
22+
order (int): order of the derivative, defaults to 1st order
23+
axis (int): the dimension of the data along which to differentiate, defaults to first dimension
24+
basis (str): 'fourier' or 'chebyshev', the set of basis functions to use for differentiation
25+
Note `basis='fourier'` assumes your function is periodic and sampled over a period of its domain,
26+
[a, b), and `basis='chebyshev'` assumes your function is sampled at cosine-spaced points on the
27+
domain [a, b].
28+
filter: Optional. A function that takes in basis function indices and outputs weights, which scale
29+
the corresponding modes in the basis-space interpolation before derivatives are taken, e.g.
30+
`lambda k: k<10` will keep only the first ten modes. With the Fourier basis, k corresponds to
31+
wavenumbers, so common filters from scipy.signal can be used. In the Chebyshev basis, modes do
32+
not directly correspond to frequencies, so high frequency noise can not be separated quite as
33+
cleanly, however it still may be helpful to dampen higher modes.
2834
"""
29-
# Filter function. Default: Identity filter
30-
self.filter = kwargs.get('filter', np.vectorize(lambda f: 1))
31-
self._x_hat = None
32-
self._freq = None
33-
34-
def _dglobal(self, t, x):
35-
self._x_hat = np.fft.fft(x)
36-
self._freq = np.fft.fftfreq(t.size, d=(t[1] - t[0]))
35+
self.order = order
36+
self.axis = axis
37+
if basis not in ['chebyshev', 'fourier']:
38+
raise ValueError("Only chebyshev and fourier bases are allowed.")
39+
self.basis = basis
40+
self.filter = filter
41+
42+
@_memoize_arrays(1) # the memoization is 1 deep, as in only remembers the most recent args
43+
def _global(self, t, x):
44+
if self.basis == 'chebyshev':
45+
return cheb_deriv(x, t, self.order, self.axis, self.filter)
46+
else: # self.basis == 'fourier'
47+
return fourier_deriv(x, t, self.order, self.axis, self.filter)
3748

3849
def compute(self, t, x, i):
3950
return next(self.compute_for(t, x, [i]))
4051

4152
def compute_for(self, t, x, indices):
42-
self._dglobal(t, x)
43-
res = np.fft.ifft(1j * 2 * np.pi * self._freq * self.filter(self._freq) * self._x_hat).real
53+
res = self._global(t, x) # cached
4454
for i in indices:
4555
yield res[i]
4656

@@ -212,7 +222,6 @@ def __init__(self, alpha=None):
212222
"""
213223
self.alpha = alpha
214224

215-
216225
@_memoize_arrays(1)
217226
def _global(self, t, z, alpha):
218227
if alpha is None:

derivative/differentiation.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _gen_method(x, t, kind, axis, **kwargs):
2121
return methods.get(kind)(**kwargs)
2222

2323

24-
def dxdt(x, t, kind=None, axis=1, **kwargs):
24+
def dxdt(x, t, kind=None, axis=0, **kwargs):
2525
"""
2626
Compute the derivative of x with respect to t along axis using the numerical derivative specified by "kind". This is
2727
the functional interface of the Derivative class.
@@ -35,7 +35,7 @@ def dxdt(x, t, kind=None, axis=1, **kwargs):
3535
x (:obj:`ndarray` of float): Ordered measurement values.
3636
t (:obj:`ndarray` of float): Ordered measurement times.
3737
kind (string): Derivative method name (see available kinds).
38-
axis ({0,1}): Axis of x along which to differentiate. Default 1.
38+
axis ({0,1}): Axis of x along which to differentiate. Default 0.
3939
**kwargs: Keyword arguments for the derivative method "kind".
4040
4141
Available kinds
@@ -56,7 +56,7 @@ def dxdt(x, t, kind=None, axis=1, **kwargs):
5656
return method.d(x, t, axis=axis)
5757

5858

59-
def smooth_x(x, t, kind=None, axis=1, **kwargs):
59+
def smooth_x(x, t, kind=None, axis=0, **kwargs):
6060
"""
6161
Compute the smoothed version of x given t along axis using the numerical
6262
derivative specified by "kind". This is the functional interface of
@@ -71,7 +71,7 @@ def smooth_x(x, t, kind=None, axis=1, **kwargs):
7171
x (:obj:`ndarray` of float): Ordered measurement values.
7272
t (:obj:`ndarray` of float): Ordered measurement times.
7373
kind (string): Derivative method name (see available kinds).
74-
axis ({0,1}): Axis of x along which to differentiate. Default 1.
74+
axis ({0,1}): Axis of x along which to differentiate. Default 0.
7575
**kwargs: Keyword arguments for the derivative method "kind".
7676
7777
Available kinds
@@ -100,7 +100,7 @@ def compute(self, t, x, i):
100100
"""
101101
Compute the derivative of one-dimensional data x with respect to t at the index i of x, (dx/dt)[i].
102102
103-
Computation of a derivative should fail explicitely if the implementation is unable to compute a derivative at
103+
Computation of a derivative should fail explicitly if the implementation is unable to compute a derivative at
104104
the desired index. Used for global differentiation methods, for example.
105105
106106
This requires that x and t have equal lengths >= 2, and that the index i is a valid index.
@@ -174,7 +174,7 @@ def compute_x_for(self, t, x, indices):
174174
for i in indices:
175175
yield self.compute_x(t, x, i)
176176

177-
def d(self, X, t, axis=1):
177+
def d(self, X, t, axis=0):
178178
"""
179179
Compute the derivative of measurements X taken at times t.
180180
@@ -184,7 +184,7 @@ def d(self, X, t, axis=1):
184184
Args:
185185
X (:obj:`ndarray` of float): Ordered measurements values. Multiple measurements allowed.
186186
t (:obj:`ndarray` of float): Ordered measurement times.
187-
axis ({0,1}). axis of X along which to differentiate. default 1.
187+
axis ({0,1}). axis of X along which to differentiate. default 0.
188188
189189
Returns:
190190
:obj:`ndarray` of float: Returns dX/dt along axis.
@@ -202,7 +202,7 @@ def d(self, X, t, axis=1):
202202
return _restore_axes(dX, axis, flat)
203203

204204

205-
def x(self, X, t, axis=1):
205+
def x(self, X, t, axis=0):
206206
"""
207207
Compute the smoothed X values from measurements X taken at times t.
208208
@@ -212,7 +212,7 @@ def x(self, X, t, axis=1):
212212
Args:
213213
X (:obj:`ndarray` of float): Ordered measurements values. Multiple measurements allowed.
214214
t (:obj:`ndarray` of float): Ordered measurement times.
215-
axis ({0,1}). axis of X along which to smooth. default 1.
215+
axis ({0,1}). axis of X along which to smooth. default 0.
216216
217217
Returns:
218218
:obj:`ndarray` of float: Returns dX/dt along axis.
@@ -228,6 +228,8 @@ def x(self, X, t, axis=1):
228228

229229

230230
def _align_axes(X, t, axis) -> tuple[NDArray, tuple[int, ...]]:
231+
"""Reshapes the data so the derivative always happens along axis 1.
232+
"""
231233
X = np.array(X)
232234
orig_shape = X.shape
233235
# By convention, differentiate axis 1
@@ -244,6 +246,8 @@ def _align_axes(X, t, axis) -> tuple[NDArray, tuple[int, ...]]:
244246

245247

246248
def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArray:
249+
"""Undo the operation of _align_axes, so data can be returned in its original shape
250+
"""
247251
if len(orig_shape) == 1:
248252
return dX.flatten()
249253
else:

docs/notebooks/Examples.ipynb

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,5 @@
11
{
22
"cells": [
3-
{
4-
"cell_type": "code",
5-
"execution_count": null,
6-
"metadata": {
7-
"ExecuteTime": {
8-
"end_time": "2020-05-25T22:56:27.818221Z",
9-
"start_time": "2020-05-25T22:56:27.413152Z"
10-
},
11-
"tags": []
12-
},
13-
"outputs": [],
14-
"source": [
15-
"%matplotlib inline"
16-
]
17-
},
183
{
194
"cell_type": "code",
205
"execution_count": null,
@@ -65,17 +50,16 @@
6550
"outputs": [],
6651
"source": [
6752
"def plot_example(diff_method, t, data_f, res_f, sigmas, y_label=None):\n",
68-
" '''\n",
69-
" Utility function for concise plotting of examples.\n",
70-
" '''\n",
53+
" '''Utility function for concise plotting of examples.'''\n",
7154
" fig, axes = plt.subplots(1, len(sigmas), figsize=[len(sigmas)*4, 3])\n",
7255
" \n",
7356
" # Compute the derivative\n",
74-
" res = diff_method.d(np.vstack([data_f(t, s) for s in sigmas]), t)\n",
57+
" res = diff_method.d(np.vstack([data_f(t, s) for s in sigmas]), t, axis=1)\n",
7558
" for i, s in enumerate(sigmas):\n",
76-
" axes[i].plot(t, res[i])\n",
7759
" axes[i].plot(t, res_f(t))\n",
78-
" axes[i].set_title(\"Noise: $\\sigma$={}\".format(s))\n",
60+
" axes[i].plot(t, res[i])\n",
61+
" axes[i].set_title(r\"Noise: $\\sigma$={}\".format(s))\n",
62+
" axes[i].set_ylim([-1.25,1.3])\n",
7963
" if y_label:\n",
8064
" axes[0].set_ylabel(y_label, fontsize=12)"
8165
]
@@ -133,7 +117,7 @@
133117
"\n",
134118
"fig,ax = plt.subplots(1, figsize=[5,3])\n",
135119
"kind = FiniteDifference(k=1)\n",
136-
"ax.plot(t, kind.d(x,t))"
120+
"ax.plot(t, kind.d(x,t));"
137121
]
138122
},
139123
{
@@ -159,7 +143,7 @@
159143
"from derivative import dxdt\n",
160144
"\n",
161145
"fig,ax = plt.subplots(1, figsize=[5,3])\n",
162-
"ax.plot(t, dxdt(x, t, \"finite_difference\", k=1))"
146+
"ax.plot(t, dxdt(x, t, \"finite_difference\", k=1));"
163147
]
164148
},
165149
{
@@ -203,10 +187,10 @@
203187
"sigmas = [0, 0.01, 0.1]\n",
204188
"fig, ax = plt.subplots(1, len(sigmas), figsize=[len(sigmas)*4, 3])\n",
205189
"\n",
206-
"t = np.linspace(0, 2*np.pi, 50)\n",
190+
"t = np.linspace(0, 2*np.pi, 50, endpoint=False)\n",
207191
"for axs, s in zip(ax, sigmas): \n",
208192
" axs.scatter(t, noisy_sin(t, s))\n",
209-
" axs.set_title(\"Noise: $\\sigma$={}\".format(s))"
193+
" axs.set_title(r\"Noise: $\\sigma$={}\".format(s))"
210194
]
211195
},
212196
{
@@ -295,7 +279,7 @@
295279
"cell_type": "markdown",
296280
"metadata": {},
297281
"source": [
298-
"### Spectral method\n",
282+
"### Spectral method - Fourier basis\n",
299283
"Add your own filter!"
300284
]
301285
},
@@ -312,12 +296,35 @@
312296
"outputs": [],
313297
"source": [
314298
"no_filter = derivative.Spectral()\n",
315-
"yes_filter = derivative.Spectral(filter=np.vectorize(lambda f: 1 if abs(f) < 0.5 else 0))\n",
299+
"yes_filter = derivative.Spectral(filter=np.vectorize(lambda k: 1 if abs(k) < 3 else 0))\n",
316300
"\n",
317301
"plot_example(no_filter, t, noisy_sin, np.cos, sigmas, 'No filter')\n",
318302
"plot_example(yes_filter, t, noisy_sin, np.cos, sigmas, 'Low-pass filter')"
319303
]
320304
},
305+
{
306+
"cell_type": "markdown",
307+
"metadata": {},
308+
"source": [
309+
"### Spectral method - Chebyshev basis\n",
310+
"\n",
311+
"Now let's do with the Chebyshev basis, which requires cosine-spaced points on [a, b] rather than equispaced points on [a, b)"
312+
]
313+
},
314+
{
315+
"cell_type": "code",
316+
"execution_count": null,
317+
"metadata": {},
318+
"outputs": [],
319+
"source": [
320+
"t_cos = np.cos(np.pi * np.arange(50) / 49) * np.pi + np.pi # choose a = 0, b = 2*pi\n",
321+
"no_filter = derivative.Spectral(basis='chebyshev')\n",
322+
"yes_filter = derivative.Spectral(basis='chebyshev', filter=np.vectorize(lambda k: 1 if abs(k) < 6 else 0))\n",
323+
"\n",
324+
"plot_example(no_filter, t_cos, noisy_sin, np.cos, sigmas, 'No filter')\n",
325+
"plot_example(yes_filter, t_cos, noisy_sin, np.cos, sigmas, 'Low-pass filter')"
326+
]
327+
},
321328
{
322329
"cell_type": "markdown",
323330
"metadata": {},
@@ -394,7 +401,7 @@
394401
"outputs": [],
395402
"source": [
396403
"def noisy_abs(t, sigma):\n",
397-
" '''Sine with gaussian noise.'''\n",
404+
" '''Abs with gaussian noise.'''\n",
398405
" np.random.seed(17)\n",
399406
" return np.abs(t) + np.random.normal(loc=0, scale=sigma, size=x.shape)\n",
400407
"\n",
@@ -406,7 +413,7 @@
406413
"t = np.linspace(-1, 1, 50)\n",
407414
"for axs, s in zip(ax, sigmas): \n",
408415
" axs.scatter(t, noisy_abs(t, s))\n",
409-
" axs.set_title(\"Noise: $\\sigma$={}\".format(s))"
416+
" axs.set_title(r\"Noise: $\\sigma$={}\".format(s))"
410417
]
411418
},
412419
{
@@ -482,7 +489,7 @@
482489
"cell_type": "markdown",
483490
"metadata": {},
484491
"source": [
485-
"### Spectral Method"
492+
"### Spectral method - Fourier basis"
486493
]
487494
},
488495
{
@@ -497,12 +504,33 @@
497504
"outputs": [],
498505
"source": [
499506
"no_filter = derivative.Spectral()\n",
500-
"yes_filter = derivative.Spectral(filter=np.vectorize(lambda f: 1 if abs(f) < 1 else 0))\n",
507+
"yes_filter = derivative.Spectral(filter=np.vectorize(lambda k: 1 if abs(k) < 6 else 0))\n",
501508
"\n",
502509
"plot_example(no_filter, t, noisy_abs, d_abs, sigmas, 'No filter')\n",
503510
"plot_example(yes_filter, t, noisy_abs, d_abs, sigmas, 'Low-pass filter')"
504511
]
505512
},
513+
{
514+
"cell_type": "markdown",
515+
"metadata": {},
516+
"source": [
517+
"### Spectral method - Chebyshev basis"
518+
]
519+
},
520+
{
521+
"cell_type": "code",
522+
"execution_count": null,
523+
"metadata": {},
524+
"outputs": [],
525+
"source": [
526+
"t_cos = np.cos(np.pi * np.arange(50)/49)\n",
527+
"no_filter = derivative.Spectral(basis='chebyshev')\n",
528+
"yes_filter = derivative.Spectral(basis='chebyshev', filter=np.vectorize(lambda k: 1 if abs(k) < 15 else 0))\n",
529+
"\n",
530+
"plot_example(no_filter, t_cos, noisy_abs, d_abs, sigmas, 'No filter')\n",
531+
"plot_example(yes_filter, t_cos, noisy_abs, d_abs, sigmas, 'Low-pass filter')"
532+
]
533+
},
506534
{
507535
"cell_type": "markdown",
508536
"metadata": {},
@@ -549,18 +577,11 @@
549577
"kal = derivative.Kalman(alpha=1.)\n",
550578
"plot_example(kal, t, noisy_abs, d_abs, sigmas, 'alpha: 1.')"
551579
]
552-
},
553-
{
554-
"cell_type": "code",
555-
"execution_count": null,
556-
"metadata": {},
557-
"outputs": [],
558-
"source": []
559580
}
560581
],
561582
"metadata": {
562583
"kernelspec": {
563-
"display_name": "Python 3",
584+
"display_name": "Python 3 (ipykernel)",
564585
"language": "python",
565586
"name": "python3"
566587
},
@@ -574,7 +595,7 @@
574595
"name": "python",
575596
"nbconvert_exporter": "python",
576597
"pygments_lexer": "ipython3",
577-
"version": "3.7.10"
598+
"version": "3.13.1"
578599
},
579600
"latex_envs": {
580601
"LaTeX_envs_menu_present": true,

0 commit comments

Comments
 (0)