Skip to content

Test jaxpm forward modeling on flip#108

Open
corentinravoux wants to merge 1 commit into
mainfrom
simulation/jaxpm-forward-modeling
Open

Test jaxpm forward modeling on flip#108
corentinravoux wants to merge 1 commit into
mainfrom
simulation/jaxpm-forward-modeling

Conversation

@corentinravoux

Copy link
Copy Markdown
Owner

Vibe coded structure
Will be improved to perform a full PM forward modeling of the data fields in flip.data_vector.
The idea is to potentially include all potential PV methods in flip. Might be extended with SBI.

Copilot AI review requested due to automatic review settings April 2, 2026 10:06

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an initial JAX/JaxPM-based forward-modeling pipeline under flip.simulation, plus a demo notebook and a dedicated test suite to validate field generation, interpolation/projection utilities, and likelihood evaluation with gradients.

Changes:

  • Introduces flip.simulation submodule (forward model, painters/utilities, Gaussian likelihood, and a jaxopt-based fitter).
  • Adds an end-to-end demo notebook for generating fields, extracting LOS velocities, and differentiating through the likelihood.
  • Adds pytest coverage for core simulation/painter/likelihood behaviors (including jax.grad through the pipeline).

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
flip/simulation/generator.py Implements differentiable IC generation + 1LPT (and optional PM) and exposes ForwardModel.
flip/simulation/painter.py Adds CIC interpolation, coordinate conversions, LOS projection, and RSD shifting utilities.
flip/simulation/likelihood.py Adds Gaussian LOS-velocity likelihood and VelocityFieldLikelihood wrapper around the forward model.
flip/simulation/fitter.py Adds SimulationFitter wrapper around jaxopt solvers to optimize cosmological parameters.
flip/simulation/cosmo_utils.py Adds adapters between flip params and jax_cosmo cosmologies plus P(k) callable builders.
flip/simulation/__init__.py Exposes simulation APIs and implements lazy-loading for optional dependencies.
test/test_simulation.py New tests for generator/painter/likelihood correctness and gradient flow (skips if deps missing).
notebooks/simulation_forward_model.ipynb End-to-end forward-modeling demonstration and usage guide.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +22 to +23
except ImportError:
import numpy as jnp

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

painter.py falls back to importing numpy as jnp when JAX is unavailable, but apply_rsd uses JAX-only APIs (e.g., positions_box.at[...] and % on DeviceArrays). This will raise runtime errors in a non-JAX environment while the module still imports successfully. Either require JAX explicitly here (raise ImportError) or provide a pure-numpy implementation / guard apply_rsd behind a JAX-availability check.

Suggested change
except ImportError:
import numpy as jnp
except ImportError as exc:
raise ImportError(
"flip.simulation.painter requires JAX; please install the 'jax' and "
"'jaxlib' packages to use this module."
) from exc

Copilot uses AI. Check for mistakes.
Comment thread flip/simulation/fitter.py
Comment on lines +90 to +92
def _to_dict(self, params_array):
return {k: float(params_array[i]) for i, k in enumerate(self._param_names)}

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SimulationFitter converts JAX arrays to Python floats in _to_dict (via float(params_array[i])). This breaks jaxopt/JAX transformations (jit/grad) and can raise ConcretizationTypeError, preventing gradients from flowing through the optimizer. Keep values as JAX scalars inside the objective path and only cast to Python floats when returning the final result to the caller.

Copilot uses AI. Check for mistakes.
Comment on lines +159 to +161
simulated_los = painter.compute_los_velocity(
velocities_3d, self.positions_cartesian
)

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VelocityFieldLikelihood projects LOS velocities using self.positions_cartesian, but the docstring says these positions are in the box frame (origin at box corner) while compute_los_velocity assumes the observer is at the Cartesian origin. This makes the LOS direction inconsistent with the intended observer-at-center convention and yields incorrect simulated_los/likelihood. Consider storing both box-frame and observer-centred positions (or subtracting the observer position/box center before LOS projection) and clarify the expected coordinate frame.

Copilot uses AI. Check for mistakes.

a_arr = jnp.atleast_1d(a)
E = jnp.sqrt(background.Esqr(cosmo, a_arr))[0]
H_a = 100.0 * cosmo.h * E # km/s / (Mpc/h)

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apply_rsd computes H_a as 100.0 * cosmo.h * E, but positions are in Mpc/h. In Mpc/h units, H(a) should be 100 * E (the h factor cancels); otherwise the RSD displacement is off by a factor of h. Please fix the unit conversion to be consistent with generator._H0_UNIT.

Suggested change
H_a = 100.0 * cosmo.h * E # km/s / (Mpc/h)
# H(a) in km/s per (Mpc/h): H0_unit = 100 km/s/(Mpc/h), so no extra factor of h
H_a = 100.0 * E

Copilot uses AI. Check for mistakes.
Comment on lines +40 to +41
except ImportError:
import numpy as jnp

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generator.py swallows ImportError for jax/jaxpm and only defines jnp=numpy in the except block, but the module still defines functions that reference jax/jc/jaxpm symbols. Importing flip.simulation.generator without optional deps will then fail later with NameError rather than a clear ImportError. Consider raising a helpful ImportError in this except block (or defining stubs that error on first use) to match the package’s “errors raised at first use” behavior.

Suggested change
except ImportError:
import numpy as jnp
except ImportError as exc:
raise ImportError(
"flip.simulation.generator requires the optional dependencies "
"'jax', 'jax_cosmo', and 'jaxpm'. Install them via "
"`pip install flip[simulation]` or by installing these packages "
"manually."
) from exc

Copilot uses AI. Check for mistakes.
Comment on lines +172 to +176
def apply_rsd(positions_box, velocities, cosmo, a, box_size, los_axis=2):
"""Shift positions along the LOS by the plane-parallel RSD displacement.

Applies x_rsd = x + v_los / (a H(a)) in Mpc/h, then wraps periodically.

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apply_rsd introduces nontrivial unit handling (v/(aH)) plus periodic wrapping, but there are currently no tests covering this behavior. Adding a small unit test (e.g., known shift along los_axis and wrap at boundaries) would help prevent regressions, especially around the H(a) unit conversion.

Copilot uses AI. Check for mistakes.
Comment thread flip/simulation/fitter.py
Comment on lines +96 to +104
def run(self):
"""Run the optimization and return the best-fit parameter dict.

The raw jaxopt result is stored in :attr:`result` after completion.

Returns:
dict: Best-fit parameters with the same keys as ``initial_params``.
"""
initial_array = self._to_array(self.initial_params)

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SimulationFitter is new functionality and is not covered by tests. Consider adding a lightweight test (with pytest.importorskip("jaxopt")) that optimizes a simple convex likelihood and asserts convergence, to ensure the dict↔array mapping and solver wiring work as expected.

Copilot uses AI. Check for mistakes.
Comment on lines +66 to +70
# painter only needs jax.numpy — fails gracefully without JAX
from .painter import (
_cic_read,
apply_rsd,
cartesian_to_box_frame,

Copilot AI Apr 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init.py states “painter only needs jax.numpy — fails gracefully without JAX”, but painter.py currently falls back to NumPy while still containing JAX-only code paths (e.g., .at in apply_rsd). Please align this comment/docstring with the actual dependency behavior (either require JAX or make the fallback truly work).

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants