
Conversation

BSnelling (Collaborator)

Implements support for the petab sciml test suite: https://github.com/sebapersson/petab_sciml_testsuite/tree/main

Includes support for all test_net cases and a subset of test_ude. Test cases with frozen layers and networks in the observable formulae are not yet implemented.

- update sciml testsuite submodule to point at main
- fix an eqx LayerNorm deprecation warning
- fix string formatting of bool in kwarg
- updates to test code driven by updated sciml format

Excludes:

- frozen nn layers
- nns in observable formulae
@BSnelling BSnelling requested a review from a team as a code owner September 2, 2025 13:33
@BSnelling BSnelling marked this pull request as draft September 2, 2025 13:39

codecov bot commented Sep 2, 2025

Codecov Report

❌ Patch coverage is 98.97959% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 78.21%. Comparing base (b73cf60) to head (0115d07).

| Files with missing lines | Patch % | Lines |
|--------------------------|---------|-------|
| python/sdist/amici/petab/petab_import.py | 90.90% | 1 Missing ⚠️ |
Additional details and impacted files


```diff
@@              Coverage Diff               @@
##           jax_sciml    #2947       +/-   ##
==============================================
+ Coverage      36.61%   78.21%   +41.60%
==============================================
  Files            326      326
  Lines          22613    22720      +107
  Branches        1527     1527
==============================================
+ Hits            8279    17770     +9491
+ Misses         14309     4941     -9368
+ Partials          25        9       -16
```
| Flag | Coverage Δ |
|------|------------|
| cpp | 74.63% <27.55%> (?) |
| cpp_python | 33.16% <4.08%> (-0.33%) ⬇️ |
| petab | 14.91% <97.95%> (?) |
| python | 72.87% <27.55%> (+40.95%) ⬆️ |
| sbmlsuite-jax | 32.37% <12.24%> (-0.30%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Files with missing lines | Coverage Δ |
|--------------------------|------------|
| python/sdist/amici/de_model.py | 93.19% <100.00%> (+8.38%) ⬆️ |
| python/sdist/amici/jax/model.py | 80.57% <100.00%> (+16.03%) ⬆️ |
| python/sdist/amici/jax/nn.py | 97.61% <100.00%> (+83.92%) ⬆️ |
| python/sdist/amici/jax/ode_export.py | 88.33% <ø> (+5.00%) ⬆️ |
| python/sdist/amici/jax/petab.py | 81.49% <100.00%> (+63.03%) ⬆️ |
| python/sdist/amici/petab/parameter_mapping.py | 57.07% <100.00%> (-6.97%) ⬇️ |
| python/sdist/amici/petab/petab_import.py | 85.71% <90.90%> (-14.29%) ⬇️ |

... and 253 files with indirect coverage changes


```python
    ][:],
    dtype=jnp.float64,
),
)  # ?? hardcoded dtype not ideal ?? could infer from env somehow ??
```
Member

Is it really necessary to set the dtype here? Usually jax infers float precision from https://docs.jax.dev/en/latest/config_options.html. It might be necessary to cast this to a numpy array first if the conversion from hdf5 is the problem.

BSnelling (Collaborator, Author)

I tried casting to a numpy array, but the array persisted as float32s. I've now defined the dtype based on the current jax config settings, which I think is better than hard-coding it.
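
For illustration, a minimal sketch of that approach (the hdf5 file and dataset names below are hypothetical, not from this PR): `jax.dtypes.canonicalize_dtype` resolves to float64 only when x64 mode is enabled, so nothing is hard-coded.

```python
import h5py
import numpy as np
import jax
import jax.numpy as jnp

# Sketch only: pick the float dtype from the current JAX config instead
# of hard-coding float64. canonicalize_dtype returns float64 iff
# jax_enable_x64 is set, float32 otherwise.
dtype = jax.dtypes.canonicalize_dtype(np.float64)

# Hypothetical hdf5 layout; going via numpy first avoids handing an
# h5py Dataset object directly to jnp.asarray.
with h5py.File("weights.h5", "r") as f:
    weights = jnp.asarray(np.asarray(f["layer0/weight"]), dtype=dtype)
```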

```python
        petab.NOMINAL_VALUE,
    ],
)
if "input"
```
Member

This check seems a bit too unspecific. I think you want to construct a sequence of petab ids that are mapped to $nnId.inputs{[$inputArgumentIndex]{[$inputIndex]}}?

BSnelling (Collaborator, Author)

I've updated this. The complication is that the values here can be pulled either from a nominal value in the parameter table or from a value in the conditions table, depending on whether the id appears in the parameters table. That's my understanding, anyway.
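
A hedged sketch of that lookup order (the helper and its signature are hypothetical; `nominalValue` is the standard PEtab parameter-table column): prefer the parameter table when the id appears there, otherwise fall back to the conditions table.

```python
import pandas as pd

# Hypothetical helper illustrating the lookup described above; not the
# PR's actual implementation.
def resolve_input_value(
    input_id: str,
    parameter_df: pd.DataFrame,  # indexed by parameterId
    condition_df: pd.DataFrame,  # indexed by conditionId
    condition_id: str,
) -> float:
    if input_id in parameter_df.index:
        # id is listed in the parameter table: use its nominal value
        return parameter_df.loc[input_id, "nominalValue"]
    # otherwise the value is set per condition in the conditions table
    return condition_df.loc[condition_id, input_id]
```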

generalise frozen layers to networks across system

Use stop_grad instead
```python
        dfs.append(df_sc)
    return pd.concat(dfs).sort_index()


def apply_grad_filter(problem: JAXProblem):
```
Member

This is a great solution! The only thing I'm a bit worried about is that in the current implementation stop_gradient is only applied when calling problem methods in the context of run_simulations, which may lead to confusion when trying to compute gradients outside of that context. My interpretation of the petab problem definition is that setting estimate=0 means gradient computation is permanently disabled, so we should apply apply_grad_filter during JAXProblem instantiation.
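
For illustration, a minimal sketch of the permanent variant (names and the flat-dict layout are hypothetical): applying stop_gradient once at construction makes every later gradient call treat non-estimated parameters as constants, independent of the calling context.

```python
import jax

# Sketch only: freeze parameters with estimate=0 at construction time.
# jax.lax.stop_gradient zeroes the gradient through these leaves for any
# downstream jax.grad/jax.jacfwd call, not just within run_simulations.
def freeze_non_estimated(parameters: dict, estimate: dict) -> dict:
    return {
        name: value if estimate[name] else jax.lax.stop_gradient(value)
        for name, value in parameters.items()
    }
```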

@FFroehlich (Member)

just checking test failures:

  • Notebook tests: these also fail on the base branch, albeit for different reasons, so this is not related to the changes here
  • macOS: this looks like an issue with PRs from a fork
  • doc tests: also a problem in the base branch; h5py is missing from the doc requirements
  • sbml jax: unrelated, also failing in the base branch

@FFroehlich (Member)

> macOS: this looks like an issue with PRs from a fork

This is probably not related to failures from forks, but rather to CMake: #2949 (review)
