Skip to content

Commit 8c08d91

Browse files
committed
- Improvement: Improved oblate star model visualisation.
1 parent 4d30a4b commit 8c08d91

9 files changed

Lines changed: 425 additions & 81 deletions

File tree

.idea/inspectionProfiles/Project_Default.xml

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

notebooks/osmodel_example_1.ipynb

Lines changed: 12 additions & 19 deletions
Large diffs are not rendered by default.

notebooks/osmodel_visualization_example.ipynb

Lines changed: 142 additions & 0 deletions
Large diffs are not rendered by default.

pytransit/models/numba/osmodel.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,72 @@ def luminosity_v(xs, ys, mstar, rstar, ostar, tpole, gpole, f, sphi, cphi, beta,
154154
l[i] = planck(wavelength, t)*(1. - ldc[0]*(1. - mu) - ldc[1]*(1. - mu)**2)
155155
return l
156156

157+
@njit
158+
def luminosity_v2(ps, normals, istar, mstar, rstar, ostar, tpole, gpole, beta, ldc, wavelength):
159+
npt = ps.shape[0]
160+
l = zeros(npt)
161+
dc = zeros(3)
162+
163+
vx = 0.0
164+
vy = -cos(istar)
165+
vz = -sin(istar)
166+
167+
for i in range(npt):
168+
px, py, pz = ps[i] * rstar # Position vector components
169+
nx, ny, nz = normals[i] # Normal vector components
170+
171+
mu = vy*ny + vz*nz
172+
173+
lp2 = (px**2 + py**2 + pz**2) # Squared distance from center
174+
lc = sqrt(px**2 + pz**2) # Centrifugal vector length
175+
cx, cz = px/lc, pz/lc # Normalized centrifugal vector
176+
177+
gg = -G * mstar / lp2 # Newtionian surface gravity component
178+
gc = ostar * ostar * lc # Centrifugal surface gravity component
179+
180+
gx = gg*nx + gc*cx # Surface gravity x component
181+
gy = gg*ny # Surface gravity y component
182+
gz = gg*nz + gc*cz # Surface gravity z component
183+
184+
g = sqrt((gx**2 + gy**2 + gz**2)) # Surface gravity
185+
t = tpole*g**beta / gpole**beta # Temperature [K]
186+
l[i] = planck(wavelength, t) # Thermal radiation
187+
l[i] *= (1.-ldc[0]*(1.-mu) - ldc[1]*(1.-mu)**2) # Quadratic limb darkening
188+
189+
190+
return l
191+
192+
193+
@njit
194+
def luminosity_s2(p, normal, istar, mstar, rstar, ostar, tpole, gpole, beta, ldc, wavelength):
195+
196+
vx = 0.0
197+
vy = -cos(istar)
198+
vz = -sin(istar)
199+
200+
px, py, pz = p * rstar # Position vector components
201+
nx, ny, nz = normal # Normal vector components
202+
203+
mu = vy*ny + vz*nz
204+
205+
lp2 = (px**2 + py**2 + pz**2) # Squared distance from center
206+
lc = sqrt(px**2 + pz**2) # Centrifugal vector length
207+
cx, cz = px/lc, pz/lc # Normalized centrifugal vector
208+
209+
gg = -G * mstar / lp2 # Newtionian surface gravity component
210+
gc = ostar * ostar * lc # Centrifugal surface gravity component
211+
212+
gx = gg*nx + gc*cx # Surface gravity x component
213+
gy = gg*ny # Surface gravity y component
214+
gz = gg*nz + gc*cz # Surface gravity z component
215+
216+
g = sqrt((gx**2 + gy**2 + gz**2)) # Surface gravity
217+
t = tpole*g**beta / gpole**beta # Temperature [K]
218+
l = planck(wavelength, t)
219+
l *= (1.-ldc[0]*(1.-mu) - ldc[1]*(1.-mu)**2) # Quadratic limb darkening
220+
221+
return l
222+
157223

158224
def create_star_xy(res: int = 64):
159225
st = linspace(-1., 1., res)

pytransit/models/osmodel.py

Lines changed: 77 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,18 @@
1616
from typing import Union
1717

1818
from astropy.constants import R_sun, M_sun
19+
from matplotlib.patches import Circle
1920
from matplotlib.pyplot import subplots, setp
20-
from numpy import linspace, meshgrid, sin, cos, array, ndarray, asarray, squeeze
21+
from numpy import linspace, sin, cos, array, ndarray, asarray, squeeze, cross, newaxis, pi, where, nan, full, degrees
22+
from numpy.linalg import norm
23+
from scipy.spatial.transform.rotation import Rotation
2124

2225
from .transitmodel import TransitModel
23-
from .numba.osmodel import create_star_xy, create_planet_xy, map_osm, xy_taylor_vt, luminosity_v, oblate_model_s
24-
from ..orbits import i_from_ba
26+
from .numba.osmodel import create_star_xy, create_planet_xy, map_osm, xy_taylor_vt, luminosity_v, oblate_model_s, \
27+
luminosity_v2
28+
from ..orbits import as_from_rhop, i_from_baew
29+
from ..orbits.taylor_z import vajs_from_paiew, find_contact_point
30+
from ..utils.octasphere import octasphere
2531

2632

2733
class OblateStarModel(TransitModel):
@@ -54,55 +60,74 @@ def __init__(self, rstar: float = 1.0, wavelength: float = 510, sres: int = 80,
5460
self._ts, self._xs, self._ys = create_star_xy(sres)
5561
self._xp, self._yp = create_planet_xy(pres)
5662

57-
def visualize(self, k, b, alpha, rho, rperiod, tpole, phi, beta, ldc, ires: int = 256):
58-
"""Visualize the model for a set of parameters.
59-
60-
Parameters
61-
----------
62-
k
63-
b
64-
alpha
65-
rho
66-
rperiod
67-
tpole
68-
phi
69-
beta
70-
ldc
71-
ires
72-
73-
Returns
74-
-------
75-
76-
"""
77-
a = 4.5
78-
mstar, ostar, gpole, f, feff = map_osm(self.rstar, rho, rperiod, tpole, phi)
79-
i = i_from_ba(b, a)
80-
times = linspace(-1.1, 1.1)
81-
ox, oy = xy_taylor_vt(times, alpha, -b, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
82-
83-
x = linspace(-1.1, 1.1, ires)
84-
y = linspace(-1.1, 1.1, ires)
85-
x, y = meshgrid(x, y)
86-
sphi, cphi = sin(phi), cos(phi)
87-
88-
l = luminosity_v(x.ravel()*self.rstar, y.ravel()*self.rstar, mstar, self.rstar, ostar, tpole, gpole,
89-
f, sphi, cphi, beta, ldc, self.wavelength)
90-
91-
fig, axs = subplots(1, 2, figsize=(13, 4))
92-
axs[0].imshow(l.reshape(x.shape), extent=(-1.1, 1.1, -1.1, 1.1), origin='lower')
93-
axs[0].plot(ox, oy, 'w', lw=5, alpha=0.25)
94-
axs[0].plot(ox, oy, 'k', lw=2)
95-
96-
setp(axs[0], ylabel='y [R$_\star$]', xlabel='x [R$_\star$]')
97-
98-
times = linspace(-0.35, 0.35, 500)
99-
flux = oblate_model_s(times, array([k]), 0.0, 4.0, a, alpha, i, 0.0, 0.0, ldc, mstar, self.rstar, ostar, tpole, gpole,
100-
f, feff, sphi, cphi, beta, self.wavelength, self.tres, self._ts, self._xs, self._ys, self._xp, self._yp,
101-
self.lcids, self.pbids, self.nsamples, self.exptimes, self.npb)
102-
103-
axs[1].plot(times, flux, 'k')
104-
setp(axs[1], ylabel='Normalized flux', xlabel='Time - T$_0$')
105-
fig.tight_layout()
63+
def visualize(self, k, p, rho, b, e, w, alpha, rperiod, tpole, istar, beta, ldc, figsize=(5, 5), ax=None,
64+
ntheta=18):
65+
if ax is None:
66+
fig, ax = subplots(figsize=figsize)
67+
ax.set_aspect(1.)
68+
else:
69+
fig, ax = None, ax
70+
71+
a = as_from_rhop(rho, p)
72+
inc = i_from_baew(b, a, e, w)
73+
mstar, ostar, gpole, f, _ = map_osm(rstar=self.rstar, rho=rho, rperiod=rperiod, tpole=tpole, phi=0.0)
74+
75+
# Plot the star
76+
# -------------
77+
vertices_original, faces = octasphere(4)
78+
vertices = vertices_original.copy()
79+
vertices[:, 1] *= (1.0 - f)
80+
81+
triangles = vertices[faces]
82+
centers = triangles.mean(1)
83+
normals = cross(triangles[:, 1] - triangles[:, 0], triangles[:, 2] - triangles[:, 0])
84+
nlength = norm(normals, axis=1)
85+
normals /= nlength[:, newaxis]
86+
87+
rotation = Rotation.from_rotvec((0.5 * pi - istar) * array([1, 0, 0]))
88+
rn = rotation.apply(normals)
89+
rc = rotation.apply(centers)
90+
91+
mask = rn[:, 2] < 0.0
92+
l = luminosity_v2(centers[mask], normals[mask], istar, mstar, self.rstar, ostar, tpole, gpole, beta,
93+
ldc, self.wavelength)
94+
ax.tripcolor(rc[mask, 0], rc[mask, 1], l, shading='gouraud')
95+
96+
nphi = 180
97+
theta = linspace(0 + 0.1, pi - 0.1, ntheta)
98+
phi = linspace(0, 2 * pi, nphi)
99+
for i in range(theta.size):
100+
y = (1.0 - f) * cos(theta[i])
101+
x = cos(phi) * sin(theta[i])
102+
z = sin(phi) * sin(theta[i])
103+
v = rotation.apply(array([x, full(nphi, y), z]).T)
104+
m = v[:, 2] < 0.0
105+
ax.plot(where(m, v[:, 0], nan), v[:, 1], 'k--', lw=1.5, alpha=0.25)
106+
107+
# Plot the orbit
108+
# --------------
109+
y0, vx, vy, ax_, ay, jx, jy, sx, sy = vajs_from_paiew(p, a, inc, e, w)
110+
c1 = find_contact_point(k, 1, y0, vx, vy, ax_, ay, jx, jy, sx, sy)
111+
c4 = find_contact_point(k, 4, y0, vx, vy, ax_, ay, jx, jy, sx, sy)
112+
time = linspace(2 * c1, 2 * c4, 100)
113+
114+
ox, oy = xy_taylor_vt(time, alpha, y0, vx, vy, ax_, ay, jx, jy, sx, sy)
115+
ax.plot(ox, oy, 'k')
116+
117+
pxy = xy_taylor_vt(array([0.0]), alpha, y0, vx, vy, ax_, ay, jx, jy, sx, sy)
118+
ax.add_artist(Circle(pxy, k, zorder=10, fc='k'))
119+
120+
# Plot the info
121+
# -------------
122+
ax.text(0.025, 0.95, f"i$_\star$ = {degrees(istar):.1f}$^\circ$", transform=ax.transAxes)
123+
ax.text(0.025, 0.90, f"i$_\mathrm{{p}}$ = {degrees(inc):.1f}$^\circ$", transform=ax.transAxes)
124+
ax.text(1 - 0.025, 0.95, fr"$\alpha$ = {degrees(alpha):.1f}$^\circ$", transform=ax.transAxes, ha='right')
125+
ax.text(0.025, 0.05, f"f = {f:.1f}", transform=ax.transAxes)
126+
127+
setp(ax, xlim=(-1.1, 1.1), ylim=(-1.1, 1.1), xticks=[], yticks=[])
128+
if fig is not None:
129+
fig.tight_layout()
130+
return ax
106131

107132
def evaluate_ps(self, k: Union[float, ndarray], rho: float, rperiod: float, tpole: float, phi: float,
108133
beta: float, ldc: ndarray, t0: float, p: float, a: float, i: float, l: float = 0.0,

pytransit/utils/octasphere.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# PyTransit: fast and easy exoplanet transit modelling in Python.
2+
# Copyright (C) 2010-2021 Hannu Parviainen
3+
#
4+
# This program is free software: you can redistribute it and/or modify
5+
# it under the terms of the GNU General Public License as published by
6+
# the Free Software Foundation, either version 3 of the License, or
7+
# (at your option) any later version.
8+
#
9+
# This program is distributed in the hope that it will be useful,
10+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
# GNU General Public License for more details.
13+
#
14+
# You should have received a copy of the GNU General Public License
15+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
16+
17+
# This script can generate spheres, rounded cubes, and capsules.
18+
# For more information, see https://prideout.net/blog/octasphere/
19+
# Copyright (c) 2019 Philip Rideout
20+
# Distributed under the MIT License, see bottom of file.
21+
22+
from math import sin, cos, acos, pi
23+
from numpy import empty, array, vstack, cross, dot
24+
from pyrr import quaternion
25+
26+
27+
def octasphere(ndivisions: int):
28+
"""Creates a unit sphere using octagon subdivision.
29+
30+
Creates a unit sphere using octagon subdivision. Modified slightly from the original code
31+
by Philip Rideout (https://prideout.net/blog/octasphere).
32+
"""
33+
34+
n = 2**ndivisions + 1
35+
num_verts = n * (n + 1) // 2
36+
verts = empty((num_verts, 3))
37+
j = 0
38+
for i in range(n):
39+
theta = pi * 0.5 * i / (n - 1)
40+
point_a = [0, sin(theta), cos(theta)]
41+
point_b = [cos(theta), sin(theta), 0]
42+
num_segments = n - 1 - i
43+
j = compute_geodesic(verts, j, point_a, point_b, num_segments)
44+
assert len(verts) == num_verts
45+
46+
num_faces = (n - 2) * (n - 1) + n - 1
47+
faces = empty((num_faces, 3), dtype='int')
48+
f, j0 = 0, 0
49+
for col_index in range(n-1):
50+
col_height = n - 1 - col_index
51+
j1 = j0 + 1
52+
j2 = j0 + col_height + 1
53+
j3 = j0 + col_height + 2
54+
for row in range(col_height - 1):
55+
faces[f + 0] = [j0 + row, j1 + row, j2 + row]
56+
faces[f + 1] = [j2 + row, j1 + row, j3 + row]
57+
f = f + 2
58+
row = col_height - 1
59+
faces[f] = [j0 + row, j1 + row, j2 + row]
60+
f = f + 1
61+
j0 = j2
62+
63+
euler_angles = array([
64+
[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
65+
[1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3],
66+
]) * pi * 0.5
67+
quats = (quaternion.create_from_eulers(e) for e in euler_angles)
68+
69+
offset, combined_verts, combined_faces = 0, [], []
70+
for quat in quats:
71+
rotated_verts = [quaternion.apply_to_vector(quat, v) for v in verts]
72+
rotated_faces = faces + offset
73+
combined_verts.append(rotated_verts)
74+
combined_faces.append(rotated_faces)
75+
offset = offset + len(verts)
76+
77+
return vstack(combined_verts), vstack(combined_faces)
78+
79+
80+
def compute_geodesic(dst, index, point_a, point_b, num_segments):
81+
"""Given two points on a unit sphere, returns a sequence of surface
82+
points that lie between them along a geodesic curve."""
83+
84+
angle_between_endpoints = acos(dot(point_a, point_b))
85+
rotation_axis = cross(point_a, point_b)
86+
dst[index] = point_a
87+
index = index + 1
88+
if num_segments == 0:
89+
return index
90+
dtheta = angle_between_endpoints / num_segments
91+
for point_index in range(1, num_segments):
92+
theta = point_index * dtheta
93+
q = quaternion.create_from_axis_rotation(rotation_axis, theta)
94+
dst[index] = quaternion.apply_to_vector(q, point_a)
95+
index = index + 1
96+
dst[index] = point_b
97+
return index + 1

pytransit/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616

1717
from semantic_version import Version
1818

19-
__version__ = Version('2.5.3')
19+
__version__ = Version('2.5.4')

requirements.txt

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
numpy>=1.19.0
22
scipy>=1.5.0
3-
pandas~=1.1.1
4-
xarray~=0.12.3
3+
pandas~=1.2.2
4+
xarray~=0.12.1
55
tables
66
uncertainties~=3.1.1
7-
numba~=0.50.1
8-
astropy~=4.0.1.post1
7+
numba~=0.51.2
8+
astropy~=4.2
99
matplotlib~=3.3.1
10-
tqdm~=4.48.2
10+
tqdm~=4.31.1
1111
semantic_version>=2.8
12-
setuptools~=49.6.0
12+
setuptools~=52.0.0
1313
deprecated~=1.2.10
14-
seaborn
15-
emcee
14+
seaborn~=0.11.1
15+
emcee~=0.0.0
16+
pytransit~=2.5.2
17+
ldtk~=1.0
18+
pyopencl~=2019.1.2
19+
corner~=2.0.1
20+
celerite~=0.4.0
21+
pyrr~=0.10.3

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
'pytransit.utils', 'pytransit.param', 'pytransit.contamination','pytransit.lpf', 'pytransit.lpf.tess',
3232
'pytransit.lpf.baselines','pytransit.lpf.loglikelihood'],
3333
package_data={'':['*.cl'], 'pytransit.contamination':['data/*']},
34-
install_requires=["numpy", "numba", "scipy", "pandas", "xarray", "tables", "semantic_version","deprecated", "uncertainties"],
34+
install_requires=["numpy", "numba", "scipy", "pandas", "xarray", "tables", "semantic_version", "deprecated", "uncertainties"],
3535
extras_require={'celerite': ["celerite","pybind11"]},
3636
include_package_data=True,
3737
license='GPLv2',

0 commit comments

Comments
 (0)