Test jaxpm forward modeling on flip#108
Conversation
There was a problem hiding this comment.
Pull request overview
Adds an initial JAX/JaxPM-based forward-modeling pipeline under flip.simulation, plus a demo notebook and a dedicated test suite to validate field generation, interpolation/projection utilities, and likelihood evaluation with gradients.
Changes:
- Introduces
flip.simulationsubmodule (forward model, painters/utilities, Gaussian likelihood, and ajaxopt-based fitter). - Adds an end-to-end demo notebook for generating fields, extracting LOS velocities, and differentiating through the likelihood.
- Adds
pytestcoverage for core simulation/painter/likelihood behaviors (includingjax.gradthrough the pipeline).
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
flip/simulation/generator.py |
Implements differentiable IC generation + 1LPT (and optional PM) and exposes ForwardModel. |
flip/simulation/painter.py |
Adds CIC interpolation, coordinate conversions, LOS projection, and RSD shifting utilities. |
flip/simulation/likelihood.py |
Adds Gaussian LOS-velocity likelihood and VelocityFieldLikelihood wrapper around the forward model. |
flip/simulation/fitter.py |
Adds SimulationFitter wrapper around jaxopt solvers to optimize cosmological parameters. |
flip/simulation/cosmo_utils.py |
Adds adapters between flip params and jax_cosmo cosmologies plus P(k) callable builders. |
flip/simulation/__init__.py |
Exposes simulation APIs and implements lazy-loading for optional dependencies. |
test/test_simulation.py |
New tests for generator/painter/likelihood correctness and gradient flow (skips if deps missing). |
notebooks/simulation_forward_model.ipynb |
End-to-end forward-modeling demonstration and usage guide. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| except ImportError: | ||
| import numpy as jnp |
There was a problem hiding this comment.
painter.py falls back to importing numpy as jnp when JAX is unavailable, but apply_rsd uses JAX-only APIs (e.g., positions_box.at[...] and % on DeviceArrays). This will raise runtime errors in a non-JAX environment while the module still imports successfully. Either require JAX explicitly here (raise ImportError) or provide a pure-numpy implementation / guard apply_rsd behind a JAX-availability check.
| except ImportError: | |
| import numpy as jnp | |
| except ImportError as exc: | |
| raise ImportError( | |
| "flip.simulation.painter requires JAX; please install the 'jax' and " | |
| "'jaxlib' packages to use this module." | |
| ) from exc |
| def _to_dict(self, params_array): | ||
| return {k: float(params_array[i]) for i, k in enumerate(self._param_names)} | ||
|
|
There was a problem hiding this comment.
SimulationFitter converts JAX arrays to Python floats in _to_dict (via float(params_array[i])). This breaks jaxopt/JAX transformations (jit/grad) and can raise ConcretizationTypeError, preventing gradients from flowing through the optimizer. Keep values as JAX scalars inside the objective path and only cast to Python floats when returning the final result to the caller.
| simulated_los = painter.compute_los_velocity( | ||
| velocities_3d, self.positions_cartesian | ||
| ) |
There was a problem hiding this comment.
VelocityFieldLikelihood projects LOS velocities using self.positions_cartesian, but the docstring says these positions are in the box frame (origin at box corner) while compute_los_velocity assumes the observer is at the Cartesian origin. This makes the LOS direction inconsistent with the intended observer-at-center convention and yields incorrect simulated_los/likelihood. Consider storing both box-frame and observer-centred positions (or subtracting the observer position/box center before LOS projection) and clarify the expected coordinate frame.
|
|
||
| a_arr = jnp.atleast_1d(a) | ||
| E = jnp.sqrt(background.Esqr(cosmo, a_arr))[0] | ||
| H_a = 100.0 * cosmo.h * E # km/s / (Mpc/h) |
There was a problem hiding this comment.
apply_rsd computes H_a as 100.0 * cosmo.h * E, but positions are in Mpc/h. In Mpc/h units, H(a) should be 100 * E (the h factor cancels); otherwise the RSD displacement is off by a factor of h. Please fix the unit conversion to be consistent with generator._H0_UNIT.
| H_a = 100.0 * cosmo.h * E # km/s / (Mpc/h) | |
| # H(a) in km/s per (Mpc/h): H0_unit = 100 km/s/(Mpc/h), so no extra factor of h | |
| H_a = 100.0 * E |
| except ImportError: | ||
| import numpy as jnp |
There was a problem hiding this comment.
generator.py swallows ImportError for jax/jaxpm and only defines jnp=numpy in the except block, but the module still defines functions that reference jax/jc/jaxpm symbols. Importing flip.simulation.generator without optional deps will then fail later with NameError rather than a clear ImportError. Consider raising a helpful ImportError in this except block (or defining stubs that error on first use) to match the package’s “errors raised at first use” behavior.
| except ImportError: | |
| import numpy as jnp | |
| except ImportError as exc: | |
| raise ImportError( | |
| "flip.simulation.generator requires the optional dependencies " | |
| "'jax', 'jax_cosmo', and 'jaxpm'. Install them via " | |
| "`pip install flip[simulation]` or by installing these packages " | |
| "manually." | |
| ) from exc |
| def apply_rsd(positions_box, velocities, cosmo, a, box_size, los_axis=2): | ||
| """Shift positions along the LOS by the plane-parallel RSD displacement. | ||
|
|
||
| Applies x_rsd = x + v_los / (a H(a)) in Mpc/h, then wraps periodically. | ||
|
|
There was a problem hiding this comment.
apply_rsd introduces nontrivial unit handling (v/(aH)) plus periodic wrapping, but there are currently no tests covering this behavior. Adding a small unit test (e.g., known shift along los_axis and wrap at boundaries) would help prevent regressions, especially around the H(a) unit conversion.
| def run(self): | ||
| """Run the optimization and return the best-fit parameter dict. | ||
|
|
||
| The raw jaxopt result is stored in :attr:`result` after completion. | ||
|
|
||
| Returns: | ||
| dict: Best-fit parameters with the same keys as ``initial_params``. | ||
| """ | ||
| initial_array = self._to_array(self.initial_params) |
There was a problem hiding this comment.
SimulationFitter is new functionality and is not covered by tests. Consider adding a lightweight test (with pytest.importorskip("jaxopt")) that optimizes a simple convex likelihood and asserts convergence, to ensure the dict↔array mapping and solver wiring work as expected.
| # painter only needs jax.numpy — fails gracefully without JAX | ||
| from .painter import ( | ||
| _cic_read, | ||
| apply_rsd, | ||
| cartesian_to_box_frame, |
There was a problem hiding this comment.
init.py states “painter only needs jax.numpy — fails gracefully without JAX”, but painter.py currently falls back to NumPy while still containing JAX-only code paths (e.g., .at in apply_rsd). Please align this comment/docstring with the actual dependency behavior (either require JAX or make the fallback truly work).
Vibe coded structure
Will be improved to perform a full PM forward modeling of the data fields in flip.data_vector.
The idea is to potentially include all potential PV methods in flip. Might be extended with SBI.