Improve memory issues#180
Conversation
Changed key creation step inline with jax-ml/jax#17432 to fix memory leak when repeatedly calling estimators.
No changes to functionality, just tidying up previously added .DS_Store files.
Same as last commit, trying to remove remaining .DS_Store files without more appearing
Finally seem to have removed all of the .DS_Store files
There was a problem hiding this comment.
Hi,
thanks a lot for the improvements! I like them a lot 🙂
I only left some very minor comments, corresponding to some code clarity improvements. I think we can merge it as soon as they are addressed and the PR passes the GitHub CI checks 🙂
(I think that running pre-commit as described here will resolve automatically all (or at least most) of the CI's complaints.)
Thank you once again!
| from bmi.interface import BaseModel, IMutualInformationPointEstimator | ||
| from bmi.utils import ProductSpace | ||
|
|
||
| import gc # new |
There was a problem hiding this comment.
The comment # new will become redundant in a few months, so I'd suggest removing it from this PR.
| ) | ||
| keys = jax.random.split(rng, max_n_steps) | ||
| for n_step, key in enumerate(keys, start=1): | ||
| # keys = jax.random.split(rng, max_n_steps) |
There was a problem hiding this comment.
Very nice fix! Would you mind adding more in-text context, e.g.,
# We don't use
# keys = jax.random.split(rng, max_n_steps)
# because of memory leaks. See:
# https://github.com/jax-ml/jax/issues/17432for future reference, so we don't forget the reason why to avoid this in the future?
| self.layers.append(eqx.nn.Linear(dims[-1], 1, key=key_final)) | ||
|
|
||
| def __call__(self, x: Point, y: Point) -> jax.Array: | ||
| # print(f"Critic - x shape {x.shape}, y shape {y.shape}") |
There was a problem hiding this comment.
I think this can be removed.
| from bmi.interface import BaseModel, EstimateResult, IMutualInformationPointEstimator | ||
| from bmi.utils import ProductSpace | ||
|
|
||
| import gc # new |
There was a problem hiding this comment.
The comment can be removed here as well.
| xs_batch.delete() | ||
| ys_batch_paired.delete() | ||
| ys_batch_unpaired.delete() | ||
| # xs_test.delete() |
There was a problem hiding this comment.
Can # xs_test.detele() be removed as well?
| ys_batch_unpaired.delete() | ||
| # xs_test.delete() | ||
| # ys_test_unpaired.delete() | ||
| del xs_batch, ys_batch_paired, ys_batch_unpaired #, xs_test, ys_test_unpaired |
There was a problem hiding this comment.
(Similarly here: the commented out variables should be removed).
| del xs_batch, ys_batch_paired, ys_batch_unpaired #, xs_test, ys_test_unpaired | ||
|
|
||
| training_log.finish() | ||
| jax.clear_caches() # clears jit/compilation & staging caches |
There was a problem hiding this comment.
Very nice fix and explanation!
Removed '# new' from 'gc' imports and previous changes. Updated information regarding memory leaks in neural network basic training file and mentioned its change in MINE estimation. Removed deprecated for loops.
Changed key creation step in line with jax-ml/jax#17432 to fix memory leak when repeatedly calling estimators i.e.
bmi.estimators.HistogramEstimator(). The updates are only to the estimator computation functions where the keys, x and y are directly used i.e. the functionestimate()within classHistogramEstimatorParams(BaseModel).Files for issue replication:
bmi_estimate_mi.py
memory_test.py