1
1
from abc import abstractmethod
2
- from copy import deepcopy
3
2
from dataclasses import dataclass
4
3
from typing import Dict , List , Tuple
5
4
11
10
from light_curve .light_curve_py .features .rainbow ._parameters import create_parameters_class
12
11
from light_curve .light_curve_py .features .rainbow ._scaler import MultiBandScaler , Scaler
13
12
from light_curve .light_curve_py .minuit_lsq import LeastSquares
13
+ from light_curve .light_curve_py .minuit_ml import MaximumLikelihood
14
14
15
15
__all__ = ["BaseRainbowFit" ]
16
16
@@ -121,6 +121,9 @@ def _check_iminuit():
121
121
if LeastSquares is None :
122
122
raise ImportError (IMINUIT_IMPORT_ERROR )
123
123
124
+ if MaximumLikelihood is None :
125
+ raise ImportError (IMINUIT_IMPORT_ERROR )
126
+
124
127
try :
125
128
try :
126
129
from packaging .version import parse as parse_version
@@ -144,48 +147,65 @@ def temp_func(self, t, params):
144
147
"""Temperature evolution function."""
145
148
return NotImplementedError
146
149
147
- @ abstractmethod
148
- def _unscale_parameters ( self , params , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
149
- """Unscale parameters from internal units, in-place.
150
+ def _parameter_scalings ( self ) -> Dict [ str , str ]:
151
+ """Rules for scaling/unscaling the parameters"""
152
+ rules = {}
150
153
151
- No baseline parameters are needed to be unscaled.
152
- """
153
- return NotImplementedError
154
+ if self .with_baseline :
155
+ for band_name in self .bands .names :
156
+ baseline_name = self .p .baseline_parameter_name (band_name )
157
+ rules [baseline_name ] = "baseline"
154
158
155
- def _unscale_errors (self , errors , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
156
- """Unscale parameter errors from internal units, in-place.
159
+ return rules
157
160
158
- No baseline parameters are needed to be unscaled.
159
- """
161
+ def _parameter_scale (self , name : str , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> float :
162
+ """Return the scale factor to be applied to the parameter to unscale it"""
163
+ scaling = self ._parameter_scalings ().get (name )
164
+ if scaling == "time" or scaling == "timescale" :
165
+ return t_scaler .scale
166
+ elif scaling == "flux" :
167
+ return m_scaler .scale
160
168
161
- # We need to modify original scalers to only apply the scale, not shifts, to the errors
162
- # It should be re-implemented in subclasses for a cleaner way to unscale the errors
163
- t_scaler = deepcopy (t_scaler )
164
- m_scaler = deepcopy (m_scaler )
165
- t_scaler .reset_shift ()
166
- m_scaler .reset_shift ()
169
+ return 1
167
170
168
- return self ._unscale_parameters (errors , t_scaler , m_scaler )
171
+ def _unscale_parameters (self , params , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
172
+ """Unscale parameters from internal units, in-place."""
173
+ for name , scaling in self ._parameter_scalings ().items ():
174
+ if scaling == "time" :
175
+ params [self .p [name ]] = t_scaler .undo_shift_scale (params [self .p [name ]])
169
176
170
- def _unscale_baseline_parameters ( self , params , m_scaler : MultiBandScaler ) -> None :
171
- """Unscale baseline parameters from internal units, in-place.
177
+ elif scaling == "timescale" :
178
+ params [ self . p [ name ]] = t_scaler . undo_scale ( params [ self . p [ name ]])
172
179
173
- Must be used only if `with_baseline` is True.
174
- """
175
- for band_name in self .bands .names :
176
- baseline_name = self .p .baseline_parameter_name (band_name )
177
- baseline = params [self .p [baseline_name ]]
178
- params [self .p [baseline_name ]] = m_scaler .undo_shift_scale_band (baseline , band_name )
180
+ elif scaling == "flux" :
181
+ params [self .p [name ]] = m_scaler .undo_scale (params [self .p [name ]])
179
182
180
- def _unscale_baseline_errors (self , errors , m_scaler : MultiBandScaler ) -> None :
181
- """Unscale baseline parameters from internal units, in-place.
183
+ elif scaling == "baseline" :
184
+ band_name = self .p .baseline_band_name (name )
185
+ baseline = params [self .p [name ]]
186
+ params [self .p [name ]] = m_scaler .undo_shift_scale_band (baseline , band_name )
182
187
183
- Must be used only if `with_baseline` is True.
184
- """
185
- for band_name in self .bands .names :
186
- baseline_name = self .p .baseline_parameter_name (band_name )
187
- baseline = errors [self .p [baseline_name ]]
188
- errors [self .p [baseline_name ]] = m_scaler .undo_scale_band (baseline , band_name )
188
+ pass
189
+
190
+ elif scaling is None or scaling .lower () == "none" :
191
+ pass
192
+
193
+ else :
194
+ raise ValueError ("Unsupported parameter scaling: " + scaling )
195
+
196
+ def _unscale_errors (self , errors , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
197
+ """Unscale parameter errors from internal units, in-place."""
198
+ for name in self .names :
199
+ scale = self ._parameter_scale (name , t_scaler , m_scaler )
200
+ errors [self .p [name ]] *= scale
201
+
202
+ def _unscale_covariance (self , cov , t_scaler : Scaler , m_scaler : MultiBandScaler ) -> None :
203
+ """Unscale parameter covariance from internal units, in-place."""
204
+ for name in self .names :
205
+ scale = self ._parameter_scale (name , t_scaler , m_scaler )
206
+ i = self .p [name ]
207
+ cov [:, i ] *= scale
208
+ cov [i , :] *= scale
189
209
190
210
@staticmethod
191
211
def planck_nu (wave_cm , T ):
@@ -283,7 +303,19 @@ def _eval(self, *, t, m, sigma, band):
283
303
def _eval_and_fill (self , * , t , m , sigma , band , fill_value ):
284
304
return super ()._eval_and_fill (t = t , m = m , sigma = sigma , band = band , fill_value = fill_value )
285
305
286
- def _eval_and_get_errors (self , * , t , m , sigma , band , print_level = None , get_initial = False ):
306
+ def _eval_and_get_errors (
307
+ self ,
308
+ * ,
309
+ t ,
310
+ m ,
311
+ sigma ,
312
+ band ,
313
+ upper_mask = None ,
314
+ get_initial = False ,
315
+ return_covariance = False ,
316
+ print_level = None ,
317
+ debug = False ,
318
+ ):
287
319
# Initialize data scalers
288
320
t_scaler = Scaler .from_time (t )
289
321
m_scaler = MultiBandScaler .from_flux (m , band , with_baseline = self .with_baseline )
@@ -311,19 +343,51 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None, get_initi
311
343
initial_guesses = self ._initial_guesses (t , m , sigma , band )
312
344
limits = self ._limits (t , m , sigma , band )
313
345
314
- least_squares = LeastSquares (
346
+ # least_squares = LeastSquares(
347
+ cost_function = MaximumLikelihood (
315
348
model = self ._lsq_model ,
316
349
parameters = limits ,
317
350
x = (t , band_idx , wave_cm ),
318
351
y = m ,
319
352
yerror = sigma ,
353
+ upper_mask = upper_mask ,
320
354
)
321
- minuit = self .Minuit (least_squares , name = self .names , ** initial_guesses )
355
+ minuit = self .Minuit (cost_function , name = self .names , ** initial_guesses )
322
356
# TODO: expose these parameters through function arguments
323
357
if print_level is not None :
324
358
minuit .print_level = print_level
325
- minuit .strategy = 2
326
- minuit .migrad (ncall = 10000 , iterate = 10 )
359
+ minuit .strategy = 0 # We will need to manually call .hesse() on convergence anyway
360
+
361
+ # Supposedly it is not the same as just setting iterate=10?..
362
+ for i in range (10 ):
363
+ minuit .migrad ()
364
+
365
+ if minuit .valid :
366
+ minuit .hesse ()
367
+ # hesse() may may drive it invalid
368
+ if minuit .valid :
369
+ break
370
+ else :
371
+ # That's what iterate is supposed to do?..
372
+ minuit .simplex ()
373
+ # FIXME: it may drive the fit valid, but we will not have Hesse run on last iteration
374
+
375
+ if debug :
376
+ # Expose everything we have to outside, unscaled, for easier debugging
377
+ self .minuit = minuit
378
+ self .mparams = {
379
+ "t" : t ,
380
+ "band_idx" : band_idx ,
381
+ "wave_cm" : wave_cm ,
382
+ "m" : m ,
383
+ "sigma" : sigma ,
384
+ "limits" : limits ,
385
+ "upper_mask" : upper_mask ,
386
+ "initial_guesses" : initial_guesses ,
387
+ "values" : minuit .values ,
388
+ "errors" : minuit .errors ,
389
+ "covariance" : minuit .covariance ,
390
+ }
327
391
328
392
if not minuit .valid and self .fail_on_divergence and not get_initial :
329
393
raise RuntimeError ("Fitting failed" )
@@ -338,15 +402,19 @@ def _eval_and_get_errors(self, *, t, m, sigma, band, print_level=None, get_initi
338
402
errors = np .array (minuit .errors )
339
403
340
404
self ._unscale_parameters (params , t_scaler , m_scaler )
341
- if self .with_baseline :
342
- self ._unscale_baseline_parameters (params , m_scaler )
343
405
344
406
# Unscale errors
345
407
self ._unscale_errors (errors , t_scaler , m_scaler )
346
- if self .with_baseline :
347
- self ._unscale_baseline_errors (errors , m_scaler )
348
408
349
- return np .r_ [params , reduced_chi2 ], errors
409
+ return_values = np .r_ [params , reduced_chi2 ], errors
410
+
411
+ if return_covariance :
412
+ # Unscale covaiance
413
+ cov = np .array (minuit .covariance )
414
+ self ._unscale_covariance (cov , t_scaler , m_scaler )
415
+ return_values += (cov ,)
416
+
417
+ return return_values
350
418
351
419
def fit_and_get_errors (self , t , m , sigma , band , * , sorted = None , check = True , ** kwargs ):
352
420
t , m , sigma , band = self ._normalize_input (t = t , m = m , sigma = sigma , band = band , sorted = sorted , check = check )
0 commit comments