2828# along with this program. If not, see <https://www.gnu.org/licenses/>.
2929from typing import Union , Optional , List
3030
31- from numpy import ndarray , array , squeeze , atleast_2d , atleast_1d , zeros , asarray
31+ from numpy import ndarray , array , squeeze , atleast_2d , atleast_1d , zeros , asarray , inf
3232
33- from .numba .ma_quadratic_nb import quadratic_model_interpolated_pv , calculate_interpolation_tables , \
34- quadratic_model_interpolated_v , quadratic_model_interpolated_s , quadratic_model_direct_v , quadratic_model_direct_s , \
35- quadratic_model_direct_pv
33+ from .numba .ma_quadratic_nb import quadratic_model_pv , calculate_interpolation_tables , quadratic_model_v , quadratic_model_s
3634from .transitmodel import TransitModel
3735
3836__all__ = ['QuadraticModel' ]
@@ -41,7 +39,7 @@ class QuadraticModel(TransitModel):
4139 """Transit model with quadratic limb darkening (Mandel & Agol, ApJ 580, L171-L175 2002).
4240 """
4341
44- def __init__ (self , interpolate : bool = True , klims : tuple = (0.005 , 0.5 ), nk : int = 256 , nz : int = 256 ):
42+ def __init__ (self , interpolate : bool = True , klims : tuple = (0.01 , 0.5 ), nk : int = 256 , nz : int = 256 ):
4543 """Transit model with quadratic limb darkening (Mandel & Agol, ApJ 580, L171-L175 2002).
4644
4745 Parameters
@@ -61,13 +59,17 @@ def __init__(self, interpolate: bool = True, klims: tuple = (0.005, 0.5), nk: in
6159 # Interpolation tables for the model components
6260 # ---------------------------------------------
6361 if interpolate :
62+ self ._interpolation_initialised = True
6463 self .ed , self .le , self .ld , self .kt , self .zt = calculate_interpolation_tables (klims [0 ], klims [1 ], nk , nz )
6564 self .klims = klims
6665 self .nk = nk
6766 self .nz = nz
6867 else :
69- self .ed , self .le , self .ld , self .kt , self .zt = None , None , None , None , None
70- self .klims , self .nk , self .nz = None , None , None
68+ self ._interpolation_initialised = False
69+ self .ed , self .le , self .ld , self .kt , self .zt = zeros ((2 ,2 )), zeros ((2 ,2 )), zeros ((2 ,2 )), zeros (2 ), zeros (2 )
70+ self .klims = klims
71+ self .nk = 0
72+ self .nz = 0
7173
7274 def evaluate (self , k : Union [float , ndarray ], ldc : Union [ndarray , List ], t0 : Union [float , ndarray ], p : Union [float , ndarray ],
7375 a : Union [float , ndarray ], i : Union [float , ndarray ], e : Optional [Union [float , ndarray ]] = None ,
@@ -125,15 +127,9 @@ def evaluate(self, k: Union[float, ndarray], ldc: Union[ndarray, List], t0: Unio
125127 e = zeros (npv ) if e is None else e
126128 w = zeros (npv ) if w is None else w
127129
128- if self .interpolate :
129- flux = quadratic_model_interpolated_v (self .time , k , t0 , p , a , i , e , w , ldc ,
130- self .lcids , self .pbids , self .nsamples , self .exptimes , self .npb ,
131- self ._es , self ._ms , self ._tae , self .ed , self .ld , self .le ,
132- self .kt , self .zt )
133- else :
134- flux = quadratic_model_direct_v (self .time , k , t0 , p , a , i , e , w , ldc ,
135- self .lcids , self .pbids , self .nsamples , self .exptimes , self .npb ,
136- self ._es , self ._ms , self ._tae )
130+ flux = quadratic_model_v (self .time , k , t0 , p , a , i , e , w , ldc , self .lcids , self .pbids , self .nsamples , self .exptimes , self .npb ,
131+ self .ed , self .ld , self .le , self .kt , self .zt , self .interpolate )
132+
137133 return squeeze (flux )
138134
139135 def evaluate_ps (self , k : Union [float , ndarray ], ldc : ndarray , t0 : float , p : float , a : float , i : float ,
@@ -153,7 +149,7 @@ def evaluate_ps(self, k: Union[float, ndarray], ldc: ndarray, t0: float, p: floa
153149 a : float
154150 Orbital semi-major axis divided by the stellar radius as a float.
155151 i : float
156- Orbital inclination(s) as a float.
152+ Orbital inclination as a float.
157153 e : float, optional
158154 Orbital eccentricity as a float.
159155 w : float, optional
@@ -179,15 +175,8 @@ def evaluate_ps(self, k: Union[float, ndarray], ldc: ndarray, t0: float, p: floa
179175 if ldc .size != 2 * self .npb :
180176 raise ValueError ("The quadratic model needs two limb darkening coefficients per passband" )
181177
182- if self .interpolate :
183- flux = quadratic_model_interpolated_s (self .time , k , t0 , p , a , i , e , w , ldc ,
184- self .lcids , self .pbids , self .nsamples , self .exptimes , self .npb ,
185- self ._es , self ._ms , self ._tae , self .ed , self .ld , self .le ,
186- self .kt , self .zt )
187- else :
188- flux = quadratic_model_direct_s (self .time , k , t0 , p , a , i , e , w , ldc ,
189- self .lcids , self .pbids , self .nsamples , self .exptimes , self .npb ,
190- self ._es , self ._ms , self ._tae )
178+ flux = quadratic_model_s (self .time , k , t0 , p , a , i , e , w , ldc , self .lcids , self .pbids , self .nsamples , self .exptimes , self .npb ,
179+ self .ed , self .ld , self .le , self .kt , self .zt , self .interpolate )
191180 return squeeze (flux )
192181
193182 def evaluate_pv (self , pvp : ndarray , ldc : ndarray , copy : bool = True ) -> ndarray :
@@ -198,7 +187,7 @@ def evaluate_pv(self, pvp: ndarray, ldc: ndarray, copy: bool = True) -> ndarray:
198187 pvp: ndarray
199188 Parameter array with a shape `(npv, npar)` where `npv` is the number of parameter vectors, and each row
200189 contains a set of parameters `[k, t0, p, a, i, e, w]`. The radius ratios can also be given per passband,
201- in which case the row should be structured as `[k_0, k_1, k_2, ..., k_npb, t0, p, a, i , e, w]`.
190+ in which case the row should be structured as `[k_0, k_1, k_2, ..., k_npb, t0, p, a, b , e, w]`.
202191 ldc: ndarray
203192 Limb darkening coefficient array with shape `(npv, 2*npb)`, where `npv` is the number of parameter vectors
204193 and `npb` is the number of passbands.
@@ -220,13 +209,8 @@ def evaluate_pv(self, pvp: ndarray, ldc: ndarray, copy: bool = True) -> ndarray:
220209 if self .time is None :
221210 raise ValueError ("Need to set the data before calling the transit model." )
222211
223- if self .interpolate :
224- flux = quadratic_model_interpolated_pv (self .time , pvp , ldc , self .lcids , self .pbids , self .nsamples , self .exptimes ,
225- self .npb , self ._es , self ._ms , self ._tae , self .ed , self .ld , self .le ,
226- self .kt , self .zt )
227- else :
228- flux = quadratic_model_direct_pv (self .time , pvp , ldc , self .lcids , self .pbids , self .nsamples , self .exptimes ,
229- self .npb , self ._es , self ._ms , self ._tae )
212+ flux = quadratic_model_pv (self .time , pvp , ldc , self .lcids , self .pbids , self .nsamples , self .exptimes ,
213+ self .npb , self .ed , self .ld , self .le , self .kt , self .zt , self .interpolate )
230214 return squeeze (flux )
231215
232216 def to_opencl (self ):
0 commit comments