Skip to content

Commit b40aa28

Browse files
authored
Merge pull request #407 from karpov-sv/rainbow_ml
Maximum likelihood for Rainbow
2 parents c504d8c + 6250e66 commit b40aa28

File tree

8 files changed

+239
-113
lines changed

8 files changed

+239
-113
lines changed

light-curve/light_curve/light_curve_py/features/rainbow/_base.py

Lines changed: 112 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from abc import abstractmethod
2-
from copy import deepcopy
32
from dataclasses import dataclass
43
from typing import Dict, List, Tuple
54

@@ -11,6 +10,7 @@
1110
from light_curve.light_curve_py.features.rainbow._parameters import create_parameters_class
1211
from light_curve.light_curve_py.features.rainbow._scaler import MultiBandScaler, Scaler
1312
from light_curve.light_curve_py.minuit_lsq import LeastSquares
13+
from light_curve.light_curve_py.minuit_ml import MaximumLikelihood
1414

1515
__all__ = ["BaseRainbowFit"]
1616

@@ -121,6 +121,9 @@ def _check_iminuit():
121121
if LeastSquares is None:
122122
raise ImportError(IMINUIT_IMPORT_ERROR)
123123

124+
if MaximumLikelihood is None:
125+
raise ImportError(IMINUIT_IMPORT_ERROR)
126+
124127
try:
125128
try:
126129
from packaging.version import parse as parse_version
@@ -144,48 +147,65 @@ def temp_func(self, t, params):
144147
"""Temperature evolution function."""
145148
return NotImplementedError
146149

147-
@abstractmethod
148-
def _unscale_parameters(self, params, t_scaler: Scaler, m_scaler: MultiBandScaler) -> None:
149-
"""Unscale parameters from internal units, in-place.
150+
def _parameter_scalings(self) -> Dict[str, str]:
151+
"""Rules for scaling/unscaling the parameters"""
152+
rules = {}
150153

151-
No baseline parameters are needed to be unscaled.
152-
"""
153-
return NotImplementedError
154+
if self.with_baseline:
155+
for band_name in self.bands.names:
156+
baseline_name = self.p.baseline_parameter_name(band_name)
157+
rules[baseline_name] = "baseline"
154158

155-
def _unscale_errors(self, errors, t_scaler: Scaler, m_scaler: MultiBandScaler) -> None:
156-
"""Unscale parameter errors from internal units, in-place.
159+
return rules
157160

158-
No baseline parameters are needed to be unscaled.
159-
"""
161+
def _parameter_scale(self, name: str, t_scaler: Scaler, m_scaler: MultiBandScaler) -> float:
162+
"""Return the scale factor to be applied to the parameter to unscale it"""
163+
scaling = self._parameter_scalings().get(name)
164+
if scaling == "time" or scaling == "timescale":
165+
return t_scaler.scale
166+
elif scaling == "flux":
167+
return m_scaler.scale
160168

161-
# We need to modify original scalers to only apply the scale, not shifts, to the errors
162-
# It should be re-implemented in subclasses for a cleaner way to unscale the errors
163-
t_scaler = deepcopy(t_scaler)
164-
m_scaler = deepcopy(m_scaler)
165-
t_scaler.reset_shift()
166-
m_scaler.reset_shift()
169+
return 1
167170

168-
return self._unscale_parameters(errors, t_scaler, m_scaler)
171+
def _unscale_parameters(self, params, t_scaler: Scaler, m_scaler: MultiBandScaler) -> None:
172+
"""Unscale parameters from internal units, in-place."""
173+
for name, scaling in self._parameter_scalings().items():
174+
if scaling == "time":
175+
params[self.p[name]] = t_scaler.undo_shift_scale(params[self.p[name]])
169176

170-
def _unscale_baseline_parameters(self, params, m_scaler: MultiBandScaler) -> None:
171-
"""Unscale baseline parameters from internal units, in-place.
177+
elif scaling == "timescale":
178+
params[self.p[name]] = t_scaler.undo_scale(params[self.p[name]])
172179

173-
Must be used only if `with_baseline` is True.
174-
"""
175-
for band_name in self.bands.names:
176-
baseline_name = self.p.baseline_parameter_name(band_name)
177-
baseline = params[self.p[baseline_name]]
178-
params[self.p[baseline_name]] = m_scaler.undo_shift_scale_band(baseline, band_name)
180+
elif scaling == "flux":
181+
params[self.p[name]] = m_scaler.undo_scale(params[self.p[name]])
179182

180-
def _unscale_baseline_errors(self, errors, m_scaler: MultiBandScaler) -> None:
181-
"""Unscale baseline parameters from internal units, in-place.
183+
elif scaling == "baseline":
184+
band_name = self.p.baseline_band_name(name)
185+
baseline = params[self.p[name]]
186+
params[self.p[name]] = m_scaler.undo_shift_scale_band(baseline, band_name)
182187

183-
Must be used only if `with_baseline` is True.
184-
"""
185-
for band_name in self.bands.names:
186-
baseline_name = self.p.baseline_parameter_name(band_name)
187-
baseline = errors[self.p[baseline_name]]
188-
errors[self.p[baseline_name]] = m_scaler.undo_scale_band(baseline, band_name)
188+
pass
189+
190+
elif scaling is None or scaling.lower() == "none":
191+
pass
192+
193+
else:
194+
raise ValueError("Unsupported parameter scaling: " + scaling)
195+
196+
def _unscale_errors(self, errors, t_scaler: Scaler, m_scaler: MultiBandScaler) -> None:
197+
"""Unscale parameter errors from internal units, in-place."""
198+
for name in self.names:
199+
scale = self._parameter_scale(name, t_scaler, m_scaler)
200+
errors[self.p[name]] *= scale
201+
202+
def _unscale_covariance(self, cov, t_scaler: Scaler, m_scaler: MultiBandScaler) -> None:
203+
"""Unscale parameter covariance from internal units, in-place."""
204+
for name in self.names:
205+
scale = self._parameter_scale(name, t_scaler, m_scaler)
206+
i = self.p[name]
207+
cov[:, i] *= scale
208+
cov[i, :] *= scale
189209

190210
@staticmethod
191211
def planck_nu(wave_cm, T):
@@ -283,7 +303,19 @@ def _eval(self, *, t, m, sigma, band):
283303
def _eval_and_fill(self, *, t, m, sigma, band, fill_value):
284304
return super()._eval_and_fill(t=t, m=m, sigma=sigma, band=band, fill_value=fill_value)
285305

286-
def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None, get_initial=False):
306+
def _eval_and_get_errors(
307+
self,
308+
*,
309+
t,
310+
m,
311+
sigma,
312+
band,
313+
upper_mask=None,
314+
get_initial=False,
315+
return_covariance=False,
316+
print_level=None,
317+
debug=False,
318+
):
287319
# Initialize data scalers
288320
t_scaler = Scaler.from_time(t)
289321
m_scaler = MultiBandScaler.from_flux(m, band, with_baseline=self.with_baseline)
@@ -311,19 +343,51 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None, get_initi
311343
initial_guesses = self._initial_guesses(t, m, sigma, band)
312344
limits = self._limits(t, m, sigma, band)
313345

314-
least_squares = LeastSquares(
346+
# least_squares = LeastSquares(
347+
cost_function = MaximumLikelihood(
315348
model=self._lsq_model,
316349
parameters=limits,
317350
x=(t, band_idx, wave_cm),
318351
y=m,
319352
yerror=sigma,
353+
upper_mask=upper_mask,
320354
)
321-
minuit = self.Minuit(least_squares, name=self.names, **initial_guesses)
355+
minuit = self.Minuit(cost_function, name=self.names, **initial_guesses)
322356
# TODO: expose these parameters through function arguments
323357
if print_level is not None:
324358
minuit.print_level = print_level
325-
minuit.strategy = 2
326-
minuit.migrad(ncall=10000, iterate=10)
359+
minuit.strategy = 0 # We will need to manually call .hesse() on convergence anyway
360+
361+
# Supposedly it is not the same as just setting iterate=10?..
362+
for i in range(10):
363+
minuit.migrad()
364+
365+
if minuit.valid:
366+
minuit.hesse()
367+
# hesse() may drive it invalid
368+
if minuit.valid:
369+
break
370+
else:
371+
# That's what iterate is supposed to do?..
372+
minuit.simplex()
373+
# FIXME: it may drive the fit valid, but we will not have Hesse run on last iteration
374+
375+
if debug:
376+
# Expose everything we have to outside, unscaled, for easier debugging
377+
self.minuit = minuit
378+
self.mparams = {
379+
"t": t,
380+
"band_idx": band_idx,
381+
"wave_cm": wave_cm,
382+
"m": m,
383+
"sigma": sigma,
384+
"limits": limits,
385+
"upper_mask": upper_mask,
386+
"initial_guesses": initial_guesses,
387+
"values": minuit.values,
388+
"errors": minuit.errors,
389+
"covariance": minuit.covariance,
390+
}
327391

328392
if not minuit.valid and self.fail_on_divergence and not get_initial:
329393
raise RuntimeError("Fitting failed")
@@ -338,15 +402,19 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None, get_initi
338402
errors = np.array(minuit.errors)
339403

340404
self._unscale_parameters(params, t_scaler, m_scaler)
341-
if self.with_baseline:
342-
self._unscale_baseline_parameters(params, m_scaler)
343405

344406
# Unscale errors
345407
self._unscale_errors(errors, t_scaler, m_scaler)
346-
if self.with_baseline:
347-
self._unscale_baseline_errors(errors, m_scaler)
348408

349-
return np.r_[params, reduced_chi2], errors
409+
return_values = np.r_[params, reduced_chi2], errors
410+
411+
if return_covariance:
412+
# Unscale covariance
413+
cov = np.array(minuit.covariance)
414+
self._unscale_covariance(cov, t_scaler, m_scaler)
415+
return_values += (cov,)
416+
417+
return return_values
350418

351419
def fit_and_get_errors(self, t, m, sigma, band, *, sorted=None, check=True, **kwargs):
352420
t, m, sigma, band = self._normalize_input(t=t, m=m, sigma=sigma, band=band, sorted=sorted, check=check)

light-curve/light_curve/light_curve_py/features/rainbow/_parameters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ def baseline_parameter_name(band: str) -> str:
1212
return f"baseline_{band}"
1313

1414

15+
def baseline_band_name(name: str) -> str:
16+
if name.startswith("baseline_"):
17+
return name[len("baseline_") :]
18+
19+
return None
20+
21+
1522
def create_int_enum(cls_name: str, attributes: Iterable[str]):
1623
return IntEnum(cls_name, {attr: i for i, attr in enumerate(attributes)})
1724

@@ -68,6 +75,7 @@ def create_parameters_class(
6875
enum.all_baseline = baseline
6976
enum.baseline_idx = np.array([enum[attr] for attr in enum.all_baseline])
7077
enum.baseline_parameter_name = staticmethod(baseline_parameter_name)
78+
enum.baseline_band_name = staticmethod(baseline_band_name)
7179

7280
band_idx_to_baseline_idx = {
7381
band_idx: enum[baseline_parameter_name(band_name)] for band_idx, band_name in zip(bands.index, bands.names)

light-curve/light_curve/light_curve_py/features/rainbow/_scaler.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,6 @@ def do_scale(self, x):
4747
def undo_scale(self, x):
4848
return x * self.scale
4949

50-
def reset_shift(self):
51-
"""Resets scaler shift to zero, keeping only the scale"""
52-
self.shift *= 0
53-
5450

5551
@dataclass()
5652
class MultiBandScaler(Scaler):
@@ -59,9 +55,6 @@ class MultiBandScaler(Scaler):
5955
per_band_shift: Dict[str, float]
6056
"""Shift to apply to each band"""
6157

62-
per_band_scale: Dict[str, float]
63-
"""Scale to apply to each band"""
64-
6558
@classmethod
6659
def from_flux(cls, flux, band, *, with_baseline: bool) -> "MultiBandScaler":
6760
"""Create a Scaler from a flux array.
@@ -71,7 +64,7 @@ def from_flux(cls, flux, band, *, with_baseline: bool) -> "MultiBandScaler":
7164
"""
7265
uniq_bands = np.unique(band)
7366
per_band_shift = dict.fromkeys(uniq_bands, 0.0)
74-
shift_array = np.zeros_like(flux)
67+
shift_array = np.zeros(len(flux))
7568

7669
if with_baseline:
7770
for b in uniq_bands:
@@ -81,19 +74,8 @@ def from_flux(cls, flux, band, *, with_baseline: bool) -> "MultiBandScaler":
8174
scale = np.std(flux)
8275
if scale == 0.0:
8376
scale = 1.0
84-
per_band_scale = dict.fromkeys(uniq_bands, scale)
8577

86-
return cls(shift=shift_array, scale=scale, per_band_shift=per_band_shift, per_band_scale=per_band_scale)
78+
return cls(shift=shift_array, scale=scale, per_band_shift=per_band_shift)
8779

8880
def undo_shift_scale_band(self, x, band):
89-
return x * self.per_band_scale.get(band, 1) + self.per_band_shift.get(band, 0)
90-
91-
def undo_scale_band(self, x, band):
92-
return x * self.per_band_scale.get(band, 1)
93-
94-
def reset_shift(self):
95-
"""Resets scaler shift to zero, keeping only the scale"""
96-
for band in self.per_band_shift:
97-
self.per_band_shift[band] = 0
98-
99-
super().reset_shift()
81+
return x * self.scale + self.per_band_shift.get(band, 0)

light-curve/light_curve/light_curve_py/features/rainbow/bolometric.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def parameter_scalings():
7171
def value(t, t0, amplitude, rise_time):
7272
dt = t - t0
7373

74-
result = np.zeros_like(dt)
74+
result = np.zeros(len(dt))
7575
# To avoid numerical overflows, let's only compute the exponents not too far from t0
7676
idx = dt > -100 * rise_time
7777
result[idx] = amplitude / (np.exp(-dt[idx] / rise_time) + 1)
@@ -80,7 +80,7 @@ def value(t, t0, amplitude, rise_time):
8080

8181
@staticmethod
8282
def initial_guesses(t, m, sigma, band):
83-
A = np.max(m)
83+
A = np.ptp(m)
8484

8585
initial = {}
8686
initial["reference_time"] = t[np.argmax(m)]
@@ -92,12 +92,14 @@ def initial_guesses(t, m, sigma, band):
9292
@staticmethod
9393
def limits(t, m, sigma, band):
9494
t_amplitude = np.ptp(t)
95-
m_amplitude = np.max(m)
95+
m_amplitude = np.ptp(m)
96+
97+
mean_dt = np.median(t[1:] - t[:-1])
9698

9799
limits = {}
98100
limits["reference_time"] = (np.min(t) - 10 * t_amplitude, np.max(t) + 10 * t_amplitude)
99-
limits["amplitude"] = (0.0, 10 * m_amplitude)
100-
limits["rise_time"] = (1e-4, 10 * t_amplitude)
101+
limits["amplitude"] = (0.0, 20 * m_amplitude)
102+
limits["rise_time"] = (0.1 * mean_dt, 10 * t_amplitude)
101103

102104
return limits
103105

@@ -128,7 +130,7 @@ def value(t, t0, amplitude, rise_time, fall_time):
128130
-fall_time / (fall_time + rise_time)
129131
)
130132

131-
result = np.zeros_like(dt)
133+
result = np.zeros(len(dt))
132134
# To avoid numerical overflows, let's only compute the exponents not too far from t0
133135
idx = (dt > -100 * rise_time) & (dt < 100 * fall_time)
134136
result[idx] = amplitude * scale / (np.exp(-dt[idx] / rise_time) + np.exp(dt[idx] / fall_time))
@@ -137,15 +139,17 @@ def value(t, t0, amplitude, rise_time, fall_time):
137139

138140
@staticmethod
139141
def initial_guesses(t, m, sigma, band):
140-
A = np.max(m)
142+
A = np.ptp(m)
143+
144+
mc = m - np.min(m) # To avoid crashing on all-negative data
141145

142146
# Naive peak position from the highest point
143147
t0 = t[np.argmax(m)]
144-
# Peak position as weighted centroid of everything above zero
145-
idx = m > 0
148+
# Peak position as weighted centroid of everything above median
149+
idx = m > np.median(m)
146150
# t0 = np.sum(t[idx] * m[idx] / sigma[idx]) / np.sum(m[idx] / sigma[idx])
147151
# Weighted centroid sigma
148-
dt = np.sqrt(np.sum((t[idx] - t0) ** 2 * m[idx] / sigma[idx]) / np.sum(m[idx] / sigma[idx]))
152+
dt = np.sqrt(np.sum((t[idx] - t0) ** 2 * (mc[idx]) / sigma[idx]) / np.sum(mc[idx] / sigma[idx]))
149153

150154
# Empirical conversion of sigma to rise/fall times
151155
rise_time = dt / 2
@@ -165,13 +169,15 @@ def initial_guesses(t, m, sigma, band):
165169
@staticmethod
166170
def limits(t, m, sigma, band):
167171
t_amplitude = np.ptp(t)
168-
m_amplitude = np.max(m)
172+
m_amplitude = np.ptp(m)
173+
174+
mean_dt = np.median(t[1:] - t[:-1])
169175

170176
limits = {}
171177
limits["reference_time"] = (np.min(t) - 10 * t_amplitude, np.max(t) + 10 * t_amplitude)
172-
limits["amplitude"] = (0.0, 10 * m_amplitude)
173-
limits["rise_time"] = (1e-4, 10 * t_amplitude)
174-
limits["fall_time"] = (1e-4, 10 * t_amplitude)
178+
limits["amplitude"] = (0.0, 20 * m_amplitude)
179+
limits["rise_time"] = (0.1 * mean_dt, 10 * t_amplitude)
180+
limits["fall_time"] = (0.1 * mean_dt, 10 * t_amplitude)
175181

176182
return limits
177183

0 commit comments

Comments
 (0)