Skip to content

Commit 5841216

Browse files
committed
Periodogram(freqs)
1 parent d005f94 commit 5841216

File tree

3 files changed

+168
-13
lines changed

3 files changed

+168
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12-
--
12+
- Periodogram(freqs: ArrayLike | None = None) is added to set fixed user-defined frequency grids
1313

1414
### Changed
1515

light-curve/src/features.rs

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ use crate::transform::{StockTransformer, parse_transform};
88
use const_format::formatcp;
99
use conv::ConvUtil;
1010
use itertools::Itertools;
11-
use light_curve_feature::{self as lcf, DataSample, prelude::*};
11+
use light_curve_feature::{self as lcf, DataSample, periodogram::FreqGrid, prelude::*};
1212
use macro_const::macro_const;
1313
use ndarray::IntoNdProducer;
14+
use num_traits::Zero;
1415
use numpy::prelude::*;
15-
use numpy::{PyArray1, PyUntypedArray};
16+
use numpy::{AllowTypeChange, PyArray1, PyArrayLike1, PyUntypedArray};
1617
use once_cell::sync::OnceCell;
1718
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
1819
use pyo3::prelude::*;
@@ -21,7 +22,6 @@ use rayon::prelude::*;
2122
use serde::{Deserialize, Serialize};
2223
use std::collections::HashMap;
2324
use std::convert::TryInto;
24-
2525
// Details of pickle support implementation
2626
// ----------------------------------------
2727
// [PyFeatureEvaluator] implements __getstate__ and __setstate__ required for pickle serialisation,
@@ -1600,6 +1600,7 @@ impl Periodogram {
16001600
resolution: Option<f32>,
16011601
max_freq_factor: Option<f32>,
16021602
nyquist: Option<NyquistArgumentOfPeriodogram>,
1603+
freqs: Option<Bound<PyAny>>,
16031604
fast: Option<bool>,
16041605
features: Option<Bound<PyAny>>,
16051606
) -> PyResult<(LcfPeriodogram<f32>, LcfPeriodogram<f64>)> {
@@ -1638,22 +1639,90 @@ impl Periodogram {
16381639
eval_f32.set_nyquist(nyquist_freq);
16391640
eval_f64.set_nyquist(nyquist_freq);
16401641
}
1641-
if let Some(fast) = fast {
1642-
if fast {
1643-
eval_f32.set_periodogram_algorithm(lcf::PeriodogramPowerFft::new().into());
1644-
eval_f64.set_periodogram_algorithm(lcf::PeriodogramPowerFft::new().into());
1645-
} else {
1646-
eval_f32.set_periodogram_algorithm(lcf::PeriodogramPowerDirect {}.into());
1647-
eval_f64.set_periodogram_algorithm(lcf::PeriodogramPowerDirect {}.into());
1642+
1643+
let fast = fast.unwrap_or(false);
1644+
if fast {
1645+
eval_f32.set_periodogram_algorithm(lcf::PeriodogramPowerFft::new().into());
1646+
eval_f64.set_periodogram_algorithm(lcf::PeriodogramPowerFft::new().into());
1647+
} else {
1648+
eval_f32.set_periodogram_algorithm(lcf::PeriodogramPowerDirect {}.into());
1649+
eval_f64.set_periodogram_algorithm(lcf::PeriodogramPowerDirect {}.into());
1650+
}
1651+
1652+
if let Some(freqs) = freqs {
1653+
const STEP_SIZE_TOLLERANCE: f64 = 10.0 * f32::EPSILON as f64;
1654+
1655+
// It is more likely for users to give f64 array
1656+
let freqs_f64 = PyArrayLike1::<f64, AllowTypeChange>::extract_bound(&freqs)?;
1657+
let freqs_f64 = freqs_f64.readonly();
1658+
let freqs_f64 = freqs_f64.as_array();
1659+
let size = freqs_f64.len();
1660+
if size < 2 {
1661+
return Err(PyValueError::new_err("freqs must have at least two values"));
16481662
}
1663+
let first_zero = freqs_f64[0].is_zero();
1664+
if fast && !first_zero {
1665+
return Err(PyValueError::new_err(
1666+
"When Periodogram(freqs=[...], fast=True), freqs[0] must equal 0",
1667+
));
1668+
}
1669+
let len_is_pow2_p1 = (size - 1).is_power_of_two();
1670+
if fast && !len_is_pow2_p1 {
1671+
return Err(PyValueError::new_err(
1672+
"When Periodogram(freqs=[...], fast=True), len(freqs) must be a power of two plus one, e.g. 2**k + 1",
1673+
));
1674+
}
1675+
let step_candidate = freqs_f64[1] - freqs_f64[0];
1676+
// Check if representable as a linear grid
1677+
let freq_grid_f64 = if freqs_f64.iter().tuple_windows().all(|(x1, x2)| {
1678+
let dx = x2 - x1;
1679+
let rel_diff = f64::abs(dx / step_candidate - 1.0);
1680+
rel_diff < STEP_SIZE_TOLLERANCE
1681+
}) {
1682+
if first_zero && len_is_pow2_p1 {
1683+
let log2_size_m1 = (size - 1).ilog2();
1684+
FreqGrid::zero_based_pow2(step_candidate, log2_size_m1)
1685+
} else {
1686+
FreqGrid::linear(freqs_f64[0], step_candidate, size)
1687+
}
1688+
} else if fast {
1689+
return Err(PyValueError::new_err(
1690+
"When Periodogram(freqs=[...], fast=True), freqs must be a linear grid, like np.linspace(0, max_freq, 2**k + 1)",
1691+
));
1692+
} else {
1693+
FreqGrid::from_array(freqs_f64)
1694+
};
1695+
1696+
let freq_grid_f32 = match &freq_grid_f64 {
1697+
FreqGrid::Arbitrary(_) => {
1698+
let freqs_f32 = PyArrayLike1::<f32, AllowTypeChange>::extract_bound(&freqs)?;
1699+
let freqs_f32 = freqs_f32.readonly();
1700+
let freqs_f32 = freqs_f32.as_array();
1701+
FreqGrid::from_array(freqs_f32)
1702+
}
1703+
FreqGrid::Linear(_) => {
1704+
FreqGrid::linear(freqs_f64[0] as f32, step_candidate as f32, size)
1705+
}
1706+
FreqGrid::ZeroBasedPow2(_) => {
1707+
FreqGrid::zero_based_pow2(step_candidate as f32, (size - 1).ilog2())
1708+
}
1709+
_ => {
1710+
panic!("This FreqGrid is not implemented yet")
1711+
}
1712+
};
1713+
1714+
eval_f32.set_freq_grid(freq_grid_f32);
1715+
eval_f64.set_freq_grid(freq_grid_f64);
16491716
}
1717+
16501718
if let Some(features) = features {
16511719
for x in features.try_iter()? {
16521720
let py_feature = x?.downcast::<PyFeatureEvaluator>()?.borrow();
16531721
eval_f32.add_feature(py_feature.feature_evaluator_f32.clone());
16541722
eval_f64.add_feature(py_feature.feature_evaluator_f64.clone());
16551723
}
16561724
}
1725+
16571726
Ok((eval_f32, eval_f64))
16581727
}
16591728

@@ -1688,6 +1757,7 @@ impl Periodogram {
16881757
resolution = LcfPeriodogram::<f64>::default_resolution(),
16891758
max_freq_factor = LcfPeriodogram::<f64>::default_max_freq_factor(),
16901759
nyquist = NyquistArgumentOfPeriodogram::String(String::from("average")),
1760+
freqs = None,
16911761
fast = true,
16921762
features = None,
16931763
transform = None,
@@ -1697,6 +1767,7 @@ impl Periodogram {
16971767
resolution: Option<f32>,
16981768
max_freq_factor: Option<f32>,
16991769
nyquist: Option<NyquistArgumentOfPeriodogram>,
1770+
freqs: Option<Bound<PyAny>>,
17001771
fast: Option<bool>,
17011772
features: Option<Bound<PyAny>>,
17021773
transform: Option<Bound<PyAny>>,
@@ -1706,8 +1777,15 @@ impl Periodogram {
17061777
"transform is not supported by Periodogram, peak-related features are not transformed, but you still may apply transformation for the underlying features",
17071778
));
17081779
}
1709-
let (eval_f32, eval_f64) =
1710-
Self::create_evals(peaks, resolution, max_freq_factor, nyquist, fast, features)?;
1780+
let (eval_f32, eval_f64) = Self::create_evals(
1781+
peaks,
1782+
resolution,
1783+
max_freq_factor,
1784+
nyquist,
1785+
freqs,
1786+
fast,
1787+
features,
1788+
)?;
17111789
Ok((
17121790
Self {
17131791
eval_f32: eval_f32.clone(),
@@ -1758,6 +1836,15 @@ nyquist : str or float or None, optional
17581836
- float: Nyquist frequency is defined by given quantile of time
17591837
intervals between observations
17601838
Default is '{default_nyquist}'
1839+
freqs : array-like or None, optional
1840+
Explicid and fixed frequency grid (angular frequency, radians/time unit).
1841+
If given, `resolution`, `max_freq_factor` and `nyquist` are being
1842+
ignored.
1843+
For `fast=True` the only supported type of the grid is
1844+
np.linspace(0.0, max_freq, 2**k+1), where k is an integer.
1845+
For `fast=False` any grid is accepted, but linear grids, like
1846+
np.linspace(min_freq, max_freq, n), apply some computational
1847+
optimisations.
17611848
fast : bool or None, optional
17621849
Use "Fast" (approximate and FFT-based) or direct periodogram algorithm,
17631850
default is {default_fast}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
import pytest
3+
from numpy.testing import assert_allclose
4+
from scipy.signal import lombscargle
5+
6+
from light_curve.light_curve_ext import Periodogram
7+
8+
9+
def test_vs_lombscargle():
10+
rng = np.random.default_rng(None)
11+
n = 100
12+
13+
t = np.sort(rng.normal(0, 1, n))
14+
m = np.sin(12.3 * t) + 0.2 * rng.normal(0, 1, n)
15+
scipy_y = (m - m.mean()) / m.std(ddof=1)
16+
17+
freq_grids = [
18+
# This one fails with scipy for ZeroDivisionError
19+
# np.linspace(0.0, 100.0, 257), # zero-based, step=2**k+1
20+
np.linspace(1.0, 100.0, 100), # linear
21+
np.geomspace(1.0, 100.0, 100), # arbitrary
22+
]
23+
for freqs in freq_grids:
24+
licu_freqs, licu_power = Periodogram(freqs=freqs, fast=False).freq_power(t, m)
25+
assert_allclose(licu_freqs, freqs)
26+
scipy_power = lombscargle(t, scipy_y, freqs=freqs, precenter=True, normalize=False)
27+
assert_allclose(scipy_power, licu_power)
28+
29+
30+
def test_different_freq_grids():
31+
rng = np.random.default_rng(None)
32+
33+
rng = np.random.default_rng(None)
34+
n = 100
35+
36+
t = np.sort(rng.normal(0, 1, n))
37+
m = np.sin(12.3 * t) + 0.2 * rng.normal(0, 1, n)
38+
39+
base_grid = np.r_[0:100:257j]
40+
base_power = None
41+
42+
freq_grids = [
43+
base_grid, # zero-based, step=2**k+1
44+
np.r_[base_grid, base_grid[-1] + base_grid[1]], # linear
45+
np.r_[base_grid, 200.0], # arbitrary
46+
]
47+
for freqs in freq_grids:
48+
licu_freqs, licu_power = Periodogram(freqs=freqs, fast=False).freq_power(t, m)
49+
assert_allclose(licu_freqs, freqs)
50+
if base_power is None:
51+
base_power = licu_power
52+
else:
53+
assert_allclose(licu_power[:-1], base_power)
54+
55+
56+
def test_failure_for_wrong_freq_grids():
57+
with pytest.raises(ValueError):
58+
# Too short
59+
Periodogram(freqs=[1.0], fast=False)
60+
with pytest.raises(ValueError):
61+
# Too short
62+
Periodogram(freqs=[1.0], fast=True)
63+
with pytest.raises(ValueError):
64+
# size is not 2**k + 1
65+
Periodogram(freqs=np.linspace(0.0, 100.0, 100), fast=True)
66+
with pytest.raises(ValueError):
67+
# Doesn't start with 0.0
68+
Periodogram(freqs=np.linspace(1.0, 100.0, 257), fast=True)

0 commit comments

Comments
 (0)