|
| 1 | +from collections.abc import Sequence |
| 2 | + |
| 3 | +import jax.numpy as jnp |
| 4 | +from flax import nnx |
| 5 | +from jax.scipy import stats |
| 6 | + |
| 7 | + |
class MLPModel(nnx.Module):
    """Multi-layer Perceptron (MLP) model for the u function.

    The network maps an input parameter vector to exp(-MLP(x)), so the
    output is strictly positive.

    Parameters
    ----------
    d_input : int
        Number of input parameters, e.g. length of theta
    d_middle : list of int
        Size of hidden layers, e.g. [64, 32, 16]
    d_output : int
        Number of output parameters, 1 for u, 2 for [u, l].
    rngs : flax.nnx.Rngs
        Random number generator for parameter initialization.
    """

    def __init__(
        self, d_input: int, *, d_middle: Sequence[int] = (300, 300, 400), d_output: int = 1, rngs: nnx.Rngs
    ):
        layers = []
        dims = [d_input] + list(d_middle) + [d_output]
        # dims[:-1] and dims[1:] always have equal length; strict=True makes
        # any future mismatch fail loudly instead of silently truncating.
        for i, (d1, d2) in enumerate(zip(dims[:-1], dims[1:], strict=True)):
            layers.append(nnx.Linear(d1, d2, rngs=rngs, kernel_init=nnx.initializers.normal()))
            if i < len(dims) - 2:  # hidden layers get nonlinearity + dropout; the output layer stays linear
                layers.append(nnx.relu)
                layers.append(nnx.Dropout(0.2, rngs=rngs))
        self.layers = nnx.Sequential(*layers)

    def __call__(self, x):
        """Compute the output of the model, exp(-MLP(x))."""
        return jnp.exp(-self.layers(x))
| 38 | + |
| 39 | + |
def chi2_lc_train_step(model: nnx.Module, optimizer: nnx.Optimizer, theta, flux, err):
    """Training step on a single light curve, with chi2 probability based loss.

    This gets a single light curve, gets u=model(theta), computes chi-squared
    statistics for a constant-flux model using `flux` and `err`, and uses
    minus logarithm of chi-squared probability as the loss function.

    Parameters
    ----------
    model : flax.nnx.Module
        Model to train, input vector size is d_input.
    optimizer : flax.nnx.Optimizer
        Optimizer to use for training
    theta : array-like
        Input parameter vector for the model, (n_obs, d_input).
    flux : array-like
        Flux vector, (n_obs,).
    err : array-like
        Error vector, (n_obs,).

    Returns
    -------
    loss : jax.Array
        Scalar loss value: minus log-probability of the chi-squared statistic.
    """

    def minus_lnprob_chi2(model):
        # Per-observation error-scaling factor predicted by the model;
        # only the first output column is used as u.
        u = model(theta)[:, 0]
        total_err = u * err
        # Inverse-variance weighted mean flux: best-fit constant-flux model.
        avg_flux = jnp.average(flux, weights=total_err**-2)
        chi2 = jnp.sum(jnp.square((flux - avg_flux) / total_err))
        # n_obs - 1 degrees of freedom: one parameter (the mean) is fitted.
        lnprob = stats.chi2.logpdf(chi2, jnp.size(flux) - 1)
        return -lnprob

    loss, grads = nnx.value_and_grad(minus_lnprob_chi2)(model)
    optimizer.update(model, grads)

    return loss
0 commit comments