|
16 | 16 | from typing import Union |
17 | 17 |
|
18 | 18 | from astropy.constants import R_sun, M_sun |
| 19 | +from matplotlib.patches import Circle |
19 | 20 | from matplotlib.pyplot import subplots, setp |
20 | | -from numpy import linspace, meshgrid, sin, cos, array, ndarray, asarray, squeeze |
| 21 | +from numpy import linspace, sin, cos, array, ndarray, asarray, squeeze, cross, newaxis, pi, where, nan, full, degrees |
| 22 | +from numpy.linalg import norm |
| 23 | +from scipy.spatial.transform.rotation import Rotation |
21 | 24 |
|
22 | 25 | from .transitmodel import TransitModel |
23 | | -from .numba.osmodel import create_star_xy, create_planet_xy, map_osm, xy_taylor_vt, luminosity_v, oblate_model_s |
24 | | -from ..orbits import i_from_ba |
| 26 | +from .numba.osmodel import create_star_xy, create_planet_xy, map_osm, xy_taylor_vt, luminosity_v, oblate_model_s, \ |
| 27 | + luminosity_v2 |
| 28 | +from ..orbits import as_from_rhop, i_from_baew |
| 29 | +from ..orbits.taylor_z import vajs_from_paiew, find_contact_point |
| 30 | +from ..utils.octasphere import octasphere |
25 | 31 |
|
26 | 32 |
|
27 | 33 | class OblateStarModel(TransitModel): |
@@ -54,55 +60,74 @@ def __init__(self, rstar: float = 1.0, wavelength: float = 510, sres: int = 80, |
54 | 60 | self._ts, self._xs, self._ys = create_star_xy(sres) |
55 | 61 | self._xp, self._yp = create_planet_xy(pres) |
56 | 62 |
|
57 | | - def visualize(self, k, b, alpha, rho, rperiod, tpole, phi, beta, ldc, ires: int = 256): |
58 | | - """Visualize the model for a set of parameters. |
59 | | -
|
60 | | - Parameters |
61 | | - ---------- |
62 | | - k |
63 | | - b |
64 | | - alpha |
65 | | - rho |
66 | | - rperiod |
67 | | - tpole |
68 | | - phi |
69 | | - beta |
70 | | - ldc |
71 | | - ires |
72 | | -
|
73 | | - Returns |
74 | | - ------- |
75 | | -
|
76 | | - """ |
77 | | - a = 4.5 |
78 | | - mstar, ostar, gpole, f, feff = map_osm(self.rstar, rho, rperiod, tpole, phi) |
79 | | - i = i_from_ba(b, a) |
80 | | - times = linspace(-1.1, 1.1) |
81 | | - ox, oy = xy_taylor_vt(times, alpha, -b, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) |
82 | | - |
83 | | - x = linspace(-1.1, 1.1, ires) |
84 | | - y = linspace(-1.1, 1.1, ires) |
85 | | - x, y = meshgrid(x, y) |
86 | | - sphi, cphi = sin(phi), cos(phi) |
87 | | - |
88 | | - l = luminosity_v(x.ravel()*self.rstar, y.ravel()*self.rstar, mstar, self.rstar, ostar, tpole, gpole, |
89 | | - f, sphi, cphi, beta, ldc, self.wavelength) |
90 | | - |
91 | | - fig, axs = subplots(1, 2, figsize=(13, 4)) |
92 | | - axs[0].imshow(l.reshape(x.shape), extent=(-1.1, 1.1, -1.1, 1.1), origin='lower') |
93 | | - axs[0].plot(ox, oy, 'w', lw=5, alpha=0.25) |
94 | | - axs[0].plot(ox, oy, 'k', lw=2) |
95 | | - |
96 | | - setp(axs[0], ylabel='y [R$_\star$]', xlabel='x [R$_\star$]') |
97 | | - |
98 | | - times = linspace(-0.35, 0.35, 500) |
99 | | - flux = oblate_model_s(times, array([k]), 0.0, 4.0, a, alpha, i, 0.0, 0.0, ldc, mstar, self.rstar, ostar, tpole, gpole, |
100 | | - f, feff, sphi, cphi, beta, self.wavelength, self.tres, self._ts, self._xs, self._ys, self._xp, self._yp, |
101 | | - self.lcids, self.pbids, self.nsamples, self.exptimes, self.npb) |
102 | | - |
103 | | - axs[1].plot(times, flux, 'k') |
104 | | - setp(axs[1], ylabel='Normalized flux', xlabel='Time - T$_0$') |
105 | | - fig.tight_layout() |
| 63 | + def visualize(self, k, p, rho, b, e, w, alpha, rperiod, tpole, istar, beta, ldc, figsize=(5, 5), ax=None, |
| 64 | + ntheta=18): |
| 65 | + if ax is None: |
| 66 | + fig, ax = subplots(figsize=figsize) |
| 67 | + ax.set_aspect(1.) |
| 68 | + else: |
| 69 | + fig, ax = None, ax |
| 70 | + |
| 71 | + a = as_from_rhop(rho, p) |
| 72 | + inc = i_from_baew(b, a, e, w) |
| 73 | + mstar, ostar, gpole, f, _ = map_osm(rstar=self.rstar, rho=rho, rperiod=rperiod, tpole=tpole, phi=0.0) |
| 74 | + |
| 75 | + # Plot the star |
| 76 | + # ------------- |
| 77 | + vertices_original, faces = octasphere(4) |
| 78 | + vertices = vertices_original.copy() |
| 79 | + vertices[:, 1] *= (1.0 - f) |
| 80 | + |
| 81 | + triangles = vertices[faces] |
| 82 | + centers = triangles.mean(1) |
| 83 | + normals = cross(triangles[:, 1] - triangles[:, 0], triangles[:, 2] - triangles[:, 0]) |
| 84 | + nlength = norm(normals, axis=1) |
| 85 | + normals /= nlength[:, newaxis] |
| 86 | + |
| 87 | + rotation = Rotation.from_rotvec((0.5 * pi - istar) * array([1, 0, 0])) |
| 88 | + rn = rotation.apply(normals) |
| 89 | + rc = rotation.apply(centers) |
| 90 | + |
| 91 | + mask = rn[:, 2] < 0.0 |
| 92 | + l = luminosity_v2(centers[mask], normals[mask], istar, mstar, self.rstar, ostar, tpole, gpole, beta, |
| 93 | + ldc, self.wavelength) |
| 94 | + ax.tripcolor(rc[mask, 0], rc[mask, 1], l, shading='gouraud') |
| 95 | + |
| 96 | + nphi = 180 |
| 97 | + theta = linspace(0 + 0.1, pi - 0.1, ntheta) |
| 98 | + phi = linspace(0, 2 * pi, nphi) |
| 99 | + for i in range(theta.size): |
| 100 | + y = (1.0 - f) * cos(theta[i]) |
| 101 | + x = cos(phi) * sin(theta[i]) |
| 102 | + z = sin(phi) * sin(theta[i]) |
| 103 | + v = rotation.apply(array([x, full(nphi, y), z]).T) |
| 104 | + m = v[:, 2] < 0.0 |
| 105 | + ax.plot(where(m, v[:, 0], nan), v[:, 1], 'k--', lw=1.5, alpha=0.25) |
| 106 | + |
| 107 | + # Plot the orbit |
| 108 | + # -------------- |
| 109 | + y0, vx, vy, ax_, ay, jx, jy, sx, sy = vajs_from_paiew(p, a, inc, e, w) |
| 110 | + c1 = find_contact_point(k, 1, y0, vx, vy, ax_, ay, jx, jy, sx, sy) |
| 111 | + c4 = find_contact_point(k, 4, y0, vx, vy, ax_, ay, jx, jy, sx, sy) |
| 112 | + time = linspace(2 * c1, 2 * c4, 100) |
| 113 | + |
| 114 | + ox, oy = xy_taylor_vt(time, alpha, y0, vx, vy, ax_, ay, jx, jy, sx, sy) |
| 115 | + ax.plot(ox, oy, 'k') |
| 116 | + |
| 117 | + pxy = xy_taylor_vt(array([0.0]), alpha, y0, vx, vy, ax_, ay, jx, jy, sx, sy) |
| 118 | + ax.add_artist(Circle(pxy, k, zorder=10, fc='k')) |
| 119 | + |
| 120 | + # Plot the info |
| 121 | + # ------------- |
| 122 | + ax.text(0.025, 0.95, f"i$_\star$ = {degrees(istar):.1f}$^\circ$", transform=ax.transAxes) |
| 123 | + ax.text(0.025, 0.90, f"i$_\mathrm{{p}}$ = {degrees(inc):.1f}$^\circ$", transform=ax.transAxes) |
| 124 | + ax.text(1 - 0.025, 0.95, fr"$\alpha$ = {degrees(alpha):.1f}$^\circ$", transform=ax.transAxes, ha='right') |
| 125 | + ax.text(0.025, 0.05, f"f = {f:.1f}", transform=ax.transAxes) |
| 126 | + |
| 127 | + setp(ax, xlim=(-1.1, 1.1), ylim=(-1.1, 1.1), xticks=[], yticks=[]) |
| 128 | + if fig is not None: |
| 129 | + fig.tight_layout() |
| 130 | + return ax |
106 | 131 |
|
107 | 132 | def evaluate_ps(self, k: Union[float, ndarray], rho: float, rperiod: float, tpole: float, phi: float, |
108 | 133 | beta: float, ldc: ndarray, t0: float, p: float, a: float, i: float, l: float = 0.0, |
|
0 commit comments