-
Notifications
You must be signed in to change notification settings - Fork 30
PEtab SciML test suite - test_net and (some) test_ude cases #2947
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: jax_sciml
Are you sure you want to change the base?
Conversation
- update sciml testsuite submodule to point at main - fix a eqx LayerNorm deprecation warning - fix string formatting of bool in kwarg - updates to test code driven by updated sciml format
…ludes: - frozen nn layers - nns in observable formulae
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## jax_sciml #2947 +/- ##
==============================================
+ Coverage 36.61% 78.21% +41.60%
==============================================
Files 326 326
Lines 22613 22720 +107
Branches 1527 1527
==============================================
+ Hits 8279 17770 +9491
+ Misses 14309 4941 -9368
+ Partials 25 9 -16
Flags with carried forward coverage won't be shown. Click here to find out more.
🚀 New features to boost your workflow:
|
python/sdist/amici/jax/petab.py
Outdated
][:], | ||
dtype=jnp.float64, | ||
), | ||
) # ?? hardcoded dtype not ideal ?? could infer from env somehow ?? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it really necessary to set dtype
here? Usually jax infers float precision from https://docs.jax.dev/en/latest/config_options.html. Might be necessary to cast this as numpy array first if conversion from hdf5 is the problem
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried casting to a numpy array but the array persisted as float32s. I've defined the dtype
based on the current jax config settings which I think is better than hard coding.
python/sdist/amici/jax/petab.py
Outdated
petab.NOMINAL_VALUE, | ||
], | ||
) | ||
if "input" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this check seems a bit too unspecific. I think you want to construct a sequence of petab id's that are mapped to $nnId.inputs{[$inputArgumentIndex]{[$inputIndex]}}
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've updated this. The complication is that the values here could be pulled from a nominal value in the parameter table or from a value in the conditions table, depending on whether the id appears in the parameters table. That's my understanding anyway.
generalise frozen layers to networks across system Use stop_grad instead
python/sdist/amici/jax/petab.py
Outdated
dfs.append(df_sc) | ||
return pd.concat(dfs).sort_index() | ||
|
||
def apply_grad_filter(problem: JAXProblem,): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great solution! Only thing that I am a bit worried about is that in the current implementation stop_gradient
is only applied when calling problem
methods in the context of run_simulations
, which may lead to confusion when trying to compute gradients outside of that context. My interpretation of the petab problem definition is that setting estimate=0
means that gradient computation is permanently disabled and we should apply apply_grad_filter
during JaxProblem
instantiation.
just checking test failures:
|
this is probably not related to failures from forks, but rather CMAKE: #2949 (review) |
Implements support for petab sciml test suite: https://github.com/sebapersson/petab_sciml_testsuite/tree/main
Includes support for all
test_net
cases and a subset oftest_ude
. Test cases with frozen layers and networks in the observable formulae are not yet implemented.