Skip to content

Commit 7ce644d

Browse files
committed
Init ML model
1 parent 8993b0b commit 7ce644d

File tree

5 files changed

+160
-0
lines changed

5 files changed

+160
-0
lines changed

.github/workflows/testing-and-coverage.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ jobs:
2020

2121
steps:
2222
- uses: actions/checkout@v4
23+
- name: Cache JAX compilation cache
24+
uses: actions/cache@v4
25+
with:
26+
path: /tmp/jax_cache
27+
key: jax-cache-${{ runner.os }}
28+
restore-keys: |
29+
jax-cache-
2330
- name: Set up Python ${{ matrix.python-version }}
2431
uses: actions/setup-python@v5
2532
with:

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ classifiers = [
1717
dynamic = ["version"]
1818
requires-python = ">=3.10"
1919
dependencies = [
20+
"flax>=0.11",
21+
"jax",
2022
"nested-pandas>=0.5,<0.6",
2123
"numpy>=2",
24+
"optax>=0.2.6",
2225
"scipy>=1",
2326
]
2427

src/uncle_val/models.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from collections.abc import Sequence
2+
3+
import jax.numpy as jnp
4+
from flax import nnx
5+
from jax.scipy import stats
6+
7+
8+
class MLPModel(nnx.Module):
    """Multi-layer Perceptron (MLP) model for the u function

    The network maps an input feature vector to exp(-MLP(x)), so the
    output is always strictly positive.

    Parameters
    ----------
    d_input : int
        Number of input parameters, e.g. length of theta
    d_middle : list of int
        Size of hidden layers, e.g. [64, 32, 16]
    d_output : int
        Number of output parameters, 1 for u, 2 for [u, l].
    rngs : flax.nnx.Rngs
        Random number generator for parameter initialization.
    """

    def __init__(
        self, d_input: int, *, d_middle: Sequence[int] = (300, 300, 400), d_output: int = 1, rngs: nnx.Rngs
    ):
        layers = []
        dims = [d_input] + list(d_middle) + [d_output]
        # Adjacent (in, out) size pairs; the slices are equal-length by
        # construction, so strict=True guards against future edits to `dims`.
        for i, (d1, d2) in enumerate(zip(dims[:-1], dims[1:], strict=True)):
            layers.append(nnx.Linear(d1, d2, rngs=rngs, kernel_init=nnx.initializers.normal()))
            if i < len(dims) - 2:  # not the last layer
                layers.append(nnx.relu)
                layers.append(nnx.Dropout(0.2, rngs=rngs))
        self.layers = nnx.Sequential(*layers)

    def __call__(self, x):
        """Compute the output of the model, exp(-MLP(x)); always positive."""
        return jnp.exp(-self.layers(x))
38+
39+
40+
def chi2_lc_train_step(model: nnx.Module, optimizer: nnx.Optimizer, theta, flux, err) -> jnp.ndarray:
    """Training step on a single light curve, with chi2 probability based loss.

    This gets a single light curve, gets u=model(theta), computes chi-squared
    statistics for a constant-flux model using `flux` and `err`, and uses
    minus logarithm of chi-squared probability as the loss function.

    Parameters
    ----------
    model : flax.nnx.Module
        Model to train, input vector size is d_input.
    optimizer : flax.nnx.Optimizer
        Optimizer to use for training
    theta : array-like
        Input parameter vector for the model, (n_obs, d_input).
    flux : array-like
        Flux vector, (n_obs,).
    err : array-like
        Error vector, (n_obs,).

    Returns
    -------
    jax.Array
        Scalar loss value: minus the log-probability of the observed
        chi-squared statistic under a chi2 distribution with
        n_obs - 1 degrees of freedom.
    """

    def minus_lnprob_chi2(model):
        # The model predicts the error-underestimation factor u per observation.
        u = model(theta)[:, 0]
        total_err = u * err
        # Inverse-variance weighted mean: the best-fit constant-flux model.
        avg_flux = jnp.average(flux, weights=total_err**-2)
        chi2 = jnp.sum(jnp.square((flux - avg_flux) / total_err))
        # n_obs - 1 dof: one parameter (the mean flux) is fitted away.
        lnprob = stats.chi2.logpdf(chi2, jnp.size(flux) - 1)
        return -lnprob

    loss, grads = nnx.value_and_grad(minus_lnprob_chi2)(model)
    optimizer.update(model, grads)

    return loss

tests/uncle_val/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import jax
2+
3+
4+
def set_jax_compilation_cache():
5+
"""Setup Jax compilation cache to speed up tests."""
6+
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
7+
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
8+
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
9+
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")
10+
11+
12+
set_jax_compilation_cache()

tests/uncle_val/test_models.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import jax.numpy as jnp
2+
import numpy as np
3+
import optax
4+
from flax import nnx
5+
from numpy.testing import assert_allclose
6+
from uncle_val.datasets import fake_non_variable_lcs
7+
from uncle_val.models import MLPModel, chi2_lc_train_step
8+
9+
10+
def test_mlp_model():
    """Fit MLPModel for a constant u function"""
    # Seed NumPy and flax RNGs from a single source for reproducibility.
    np_rng = np.random.default_rng(42)
    nnx_rngs = nnx.Rngs(int(np_rng.integers(1 << 63)))

    train_steps = 2000
    n_obj = 1000
    # Random number of observations per object (light-curve length).
    n_src = np_rng.integers(30, 150, size=n_obj)
    # True constant error-underestimation factor the model should recover.
    u = 2.0

    nf = fake_non_variable_lcs(
        n_obj=n_obj,
        n_src=n_src,
        err=None,
        u=u,
        rng=np_rng,
    )
    # Standardize log-flux and log-error to zero mean / unit variance;
    # these become the two input features of the model.
    ln_fluxes = np.log(nf["objectForcedSource.psfFlux"])
    nf["objectForcedSource.norm_flux"] = (ln_fluxes - np.mean(ln_fluxes)) / np.std(ln_fluxes)
    ln_errs = np.log(nf["objectForcedSource.psfFluxErr"])
    nf["objectForcedSource.norm_err"] = (ln_errs - np.mean(ln_errs)) / np.std(ln_errs)

    # Extract the per-object list columns as Arrow fields so that a single
    # object's values can be pulled out by index below.
    struct_array = nf["objectForcedSource"].array.struct_array.combine_chunks()
    flux_arr = struct_array.field("psfFlux")
    err_arr = struct_array.field("psfFluxErr")
    norm_flux_arr = struct_array.field("norm_flux")
    norm_err_arr = struct_array.field("norm_err")

    # Two input features (norm_flux, norm_err), single output u.
    model = MLPModel(
        d_input=2,
        d_output=1,
        rngs=nnx_rngs,
    )
    optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

    # JIT-compile the per-light-curve training step once, reuse in the loop.
    step = nnx.jit(chi2_lc_train_step)

    # Train on randomly chosen light curves, one object per step.
    for idx in np_rng.choice(len(flux_arr), train_steps):
        flux = jnp.asarray(flux_arr[idx].values)
        err = jnp.asarray(err_arr[idx].values)
        norm_flux = jnp.asarray(norm_flux_arr[idx].values)
        norm_err = jnp.asarray(norm_err_arr[idx].values)

        theta = jnp.stack([norm_flux, norm_err], axis=-1)
        step(
            model=model,
            optimizer=optimizer,
            theta=theta,
            flux=flux,
            err=err,
        )

    # The trained model should predict the constant true u within 10%.
    # NOTE(review): only the last loop iteration's `theta` is checked —
    # presumably representative of the whole dataset; confirm intended.
    assert_allclose(np.asarray(model(theta)), u, rtol=0.1)

0 commit comments

Comments
 (0)