Skip to content

Commit 2d699ac

Browse files
erusseilEtienne Russeilpre-commit-ci[bot]
authored
Improved initial guesses/limits of Rainbow functions (#494)
* Improved initial guesses/limits of Rainbow functions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changed the tests from parameter comparison to flux comparison * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Reformat file * reformat again :) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Etienne Russeil <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 054fbb6 commit 2d699ac

File tree

3 files changed

+68
-48
lines changed

3 files changed

+68
-48
lines changed

light-curve/light_curve/light_curve_py/features/rainbow/bolometric.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,12 @@ def limits(t, m, sigma, band):
102102
t_amplitude = np.ptp(t)
103103
m_amplitude = np.ptp(m)
104104

105-
mean_dt = np.median(t[1:] - t[:-1])
105+
_, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
106106

107107
limits = {}
108108
limits["reference_time"] = (np.min(t) - 10 * t_amplitude, np.max(t) + 10 * t_amplitude)
109109
limits["amplitude"] = (0.0, 20 * m_amplitude)
110-
limits["rise_time"] = (0.1 * mean_dt, 10 * t_amplitude)
110+
limits["rise_time"] = (dt / 100, 10 * t_amplitude)
111111

112112
return limits
113113

@@ -147,24 +147,13 @@ def value(t, t0, amplitude, rise_time, fall_time):
147147

148148
@staticmethod
149149
def initial_guesses(t, m, sigma, band):
150-
A = np.ptp(m)
151-
152-
mc = m - np.min(m) # To avoid crashing on all-negative data
150+
A = 1.5 * max(np.max(m), np.ptp(m))
153151

154-
# Naive peak position from the highest point
155-
t0 = t[np.argmax(m)]
156-
# Peak position as weighted centroid of everything above median
157-
idx = m > np.median(m)
158-
# t0 = np.sum(t[idx] * m[idx] / sigma[idx]) / np.sum(m[idx] / sigma[idx])
159-
# Weighted centroid sigma
160-
dt = np.sqrt(np.sum((t[idx] - t0) ** 2 * (mc[idx]) / sigma[idx]) / np.sum(mc[idx] / sigma[idx]))
152+
t0, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
161153

162154
# Empirical conversion of sigma to rise/fall times
163-
rise_time = dt / 2
164-
fall_time = dt / 2
165-
166-
# Compensate for the difference between reference_time and peak position
167-
t0 -= np.log(fall_time / rise_time) * rise_time * fall_time / (rise_time + fall_time)
155+
rise_time = dt
156+
fall_time = dt
168157

169158
initial = {}
170159
initial["reference_time"] = t0
@@ -178,14 +167,13 @@ def initial_guesses(t, m, sigma, band):
178167
def limits(t, m, sigma, band):
179168
t_amplitude = np.ptp(t)
180169
m_amplitude = np.ptp(m)
181-
182-
mean_dt = np.median(t[1:] - t[:-1])
170+
_, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
183171

184172
limits = {}
185173
limits["reference_time"] = (np.min(t) - 10 * t_amplitude, np.max(t) + 10 * t_amplitude)
186174
limits["amplitude"] = (0.0, 20 * m_amplitude)
187-
limits["rise_time"] = (0.1 * mean_dt, 10 * t_amplitude)
188-
limits["fall_time"] = (0.1 * mean_dt, 10 * t_amplitude)
175+
limits["rise_time"] = (dt / 100, 10 * t_amplitude)
176+
limits["fall_time"] = (dt / 100, 10 * t_amplitude)
189177

190178
return limits
191179

@@ -198,7 +186,7 @@ def peak_time(t0, amplitude, rise_time, fall_time):
198186
class LinexpBolometricTerm(BaseBolometricTerm):
199187
"""Linexp function, symmetric form. Generated using a prototype version of Multi-view
200188
Symbolic Regression (Russeil et al. 2024, https://arxiv.org/abs/2402.04298) on
201-
a SLSN ZTF light curve (https://ztf.snad.space/dr17/view/821207100004043)"""
189+
a SLSN ZTF light curve (https://ztf.snad.space/dr17/view/821207100004043). Be careful: the guesses/limits are not very stable"""
202190

203191
@staticmethod
204192
def parameter_names():
@@ -226,6 +214,7 @@ def value(t, t0, amplitude, rise_time):
226214
def initial_guesses(t, m, sigma, band):
227215
A = np.ptp(m)
228216
med_dt = median_dt(t, band)
217+
t0, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
229218

230219
# Compute points after or before maximum
231220
peak_time = t[np.argmax(m)]
@@ -276,7 +265,6 @@ def parameter_scalings():
276265
@staticmethod
277266
def value(t, t0, amplitude, time1, time2, p):
278267
dt = t - t0
279-
280268
result = np.zeros_like(dt)
281269

282270
# To avoid numerical overflows
@@ -290,37 +278,34 @@ def value(t, t0, amplitude, time1, time2, p):
290278

291279
@staticmethod
292280
def initial_guesses(t, m, sigma, band):
293-
A = np.ptp(m)
294-
med_dt = median_dt(t, band)
281+
A = max(np.max(m), np.ptp(m))
282+
t0, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
295283

296-
# Naive peak position from the highest point
297-
t0 = t[np.argmax(m)]
298-
299-
# Empirical conversion of sigma to rise/fall times
300-
time1 = 50 * med_dt
301-
time2 = 50 * med_dt
284+
# Empirical conversion of sigma to times
285+
time1 = 2 * dt
286+
time2 = 2 * dt
302287

303288
initial = {}
304289
initial["reference_time"] = t0
305290
initial["amplitude"] = A
306291
initial["time1"] = time1
307292
initial["time2"] = time2
308-
initial["p"] = 0.1
293+
initial["p"] = 1
309294

310295
return initial
311296

312297
@staticmethod
313298
def limits(t, m, sigma, band):
314299
t_amplitude = np.ptp(t)
315300
m_amplitude = np.ptp(m)
316-
med_dt = median_dt(t, band)
301+
_, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
317302

318303
limits = {}
319304
limits["reference_time"] = (np.min(t) - 10 * t_amplitude, np.max(t) + 10 * t_amplitude)
320305
limits["amplitude"] = (0.0, 10 * m_amplitude)
321-
limits["time1"] = (med_dt, 2 * t_amplitude)
322-
limits["time2"] = (med_dt, 2 * t_amplitude)
323-
limits["p"] = (1e-4, 10)
306+
limits["time1"] = (dt / 10, 2 * t_amplitude)
307+
limits["time2"] = (dt / 10, 2 * t_amplitude)
308+
limits["p"] = (1e-2, 100)
324309

325310
return limits
326311

@@ -336,13 +321,27 @@ def peak_time(t0, p):
336321

337322
def median_dt(t, band):
338323
# Compute the median distance between points in each band
324+
# Caution when using this method as it might be strongly biased because of ZTF's high cadence on a given day.
339325
dt = []
340326
for b in np.unique(band):
341327
dt += list(t[band == b][1:] - t[band == b][:-1])
342328
med_dt = np.median(dt)
343329
return med_dt
344330

345331

332+
def t0_and_weighted_centroid_sigma(t, m, sigma):
333+
# To avoid crashing on all-negative data
334+
mc = m - np.min(m)
335+
336+
# Peak position as weighted centroid of everything above median
337+
idx = m > np.median(m)
338+
t0 = np.sum(t[idx] * m[idx] / sigma[idx]) / np.sum(m[idx] / sigma[idx])
339+
340+
# Weighted centroid sigma
341+
dt = np.sqrt(np.sum((t[idx] - t0) ** 2 * (mc[idx]) / sigma[idx]) / np.sum(mc[idx] / sigma[idx]))
342+
return t0, dt
343+
344+
346345
bolometric_terms = {
347346
"sigmoid": SigmoidBolometricTerm,
348347
"bazin": BazinBolometricTerm,

light-curve/light_curve/light_curve_py/features/rainbow/temperature.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def initial_guesses(t, m, sigma, band):
7676
@staticmethod
7777
def limits(t, m, sigma, band):
7878
limits = {}
79-
limits["T"] = (1e2, 2e6) # K
79+
limits["T"] = (1e3, 2e6) # K
8080

8181
return limits
8282

@@ -111,24 +111,24 @@ def value(t, t0, temp_min, temp_max, t_color):
111111

112112
@staticmethod
113113
def initial_guesses(t, m, sigma, band):
114-
med_dt = median_dt(t, band)
114+
_, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
115115

116116
initial = {}
117117
initial["Tmin"] = 7000.0
118118
initial["Tmax"] = 10000.0
119-
initial["t_color"] = 10 * med_dt
119+
initial["t_color"] = 2 * dt
120120

121121
return initial
122122

123123
@staticmethod
124124
def limits(t, m, sigma, band):
125125
t_amplitude = np.ptp(t)
126-
med_dt = median_dt(t, band)
126+
_, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
127127

128128
limits = {}
129129
limits["Tmin"] = (1e3, 2e6) # K
130130
limits["Tmax"] = (1e3, 2e6) # K
131-
limits["t_color"] = (2 * med_dt, 10 * t_amplitude)
131+
limits["t_color"] = (dt / 3, 10 * t_amplitude)
132132

133133
return limits
134134

@@ -163,25 +163,25 @@ def value(t, t0, Tmin, Tmax, t_color, t_delay):
163163

164164
@staticmethod
165165
def initial_guesses(t, m, sigma, band):
166-
med_dt = median_dt(t, band)
166+
_, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
167167

168168
initial = {}
169169
initial["Tmin"] = 7000.0
170170
initial["Tmax"] = 10000.0
171-
initial["t_color"] = 10 * med_dt
171+
initial["t_color"] = 2 * dt
172172
initial["t_delay"] = 0.0
173173

174174
return initial
175175

176176
@staticmethod
177177
def limits(t, m, sigma, band):
178178
t_amplitude = np.ptp(t)
179-
med_dt = median_dt(t, band)
179+
_, dt = t0_and_weighted_centroid_sigma(t, m, sigma)
180180

181181
limits = {}
182182
limits["Tmin"] = (1e3, 2e6) # K
183183
limits["Tmax"] = (1e3, 2e6) # K
184-
limits["t_color"] = (2 * med_dt, 10 * t_amplitude)
184+
limits["t_color"] = (dt / 3, 10 * t_amplitude)
185185
limits["t_delay"] = (-t_amplitude, t_amplitude)
186186

187187
return limits
@@ -196,6 +196,19 @@ def median_dt(t, band):
196196
return med_dt
197197

198198

199+
def t0_and_weighted_centroid_sigma(t, m, sigma):
200+
# To avoid crashing on all-negative data
201+
mc = m - np.min(m)
202+
203+
# Peak position as weighted centroid of everything above median
204+
idx = m > np.median(m)
205+
t0 = np.sum(t[idx] * m[idx] / sigma[idx]) / np.sum(m[idx] / sigma[idx])
206+
207+
# Weighted centroid sigma
208+
dt = np.sqrt(np.sum((t[idx] - t0) ** 2 * (mc[idx]) / sigma[idx]) / np.sum(mc[idx] / sigma[idx]))
209+
return t0, dt
210+
211+
199212
temperature_terms = {
200213
"constant": ConstantTemperatureTerm,
201214
"sigmoid": SigmoidTemperatureTerm,

light-curve/tests/light_curve_py/features/test_rainbow.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ def test_noisy_with_baseline():
1515
fall_time = 30.0
1616
Tmin = 5e3
1717
Tmax = 15e3
18-
k_sig = 4.0
18+
t_color = 10
1919
baselines = {b: 0.3 * amplitude + rng.exponential(scale=0.3 * amplitude) for b in band_wave_aa}
2020

21-
expected = [reference_time, amplitude, rise_time, fall_time, Tmin, Tmax, k_sig, *baselines.values(), 1.0]
21+
expected = [reference_time, amplitude, rise_time, fall_time, Tmin, Tmax, t_color, *baselines.values(), 1.0]
2222

2323
feature = RainbowFit.from_angstrom(band_wave_aa, with_baseline=True, temperature="sigmoid", bolometric="bazin")
2424

@@ -40,7 +40,7 @@ def test_noisy_with_baseline():
4040
# plt.legend()
4141
# plt.show()
4242

43-
np.testing.assert_allclose(actual[:-1], expected[:-1], rtol=0.1)
43+
np.testing.assert_allclose(feature.model(t, band, *expected), feature.model(t, band, *actual), rtol=0.1)
4444

4545

4646
def test_noisy_all_functions_combination():
@@ -113,8 +113,16 @@ def test_noisy_all_functions_combination():
113113
# plt.legend()
114114
# plt.show()
115115

116+
# The first test might be too rigid. The second test allows good local minima to be accepted
116117
np.testing.assert_allclose(actual[:-1], expected[:-1], rtol=0.1)
117118

119+
# If either the absolute or the relative test passes, it is accepted.
120+
# It prevents linexp, which includes a flat, exactly-zero baseline, from failing the test because
121+
# of very minor parameter differences that lead to a major relative difference.
122+
np.testing.assert_allclose(
123+
feature.model(t, band, *expected), feature.model(t, band, *actual), rtol=0.1, atol=0.1, strict=False
124+
)
125+
118126

119127
def test_scaler_from_flux_list_input():
120128
"https://github.com/light-curve/light-curve-python/issues/492"

0 commit comments

Comments
 (0)