Updating Healpix CUDA primitive #290
Conversation
Hello @matt-graham @jasonmcewen @CosmoMatt

Just a quick PR to wrap up a few things:

1. Updated the binding API to the newest [FFI](https://docs.jax.dev/en/latest/ffi.html).
2. Added a vmap implementation of the CUDA primitive.
3. Added a transpose rule, which allows jacfwd and jacrev (and consequently grad as well).
4. Added more tests: https://github.com/astro-informatics/s2fft/blob/ASKabalan/tests/test_healpix_ffts.py#L100
5. Removed two files which are no longer needed with the FFI API ([kernel helpers](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h)), so maybe they should also be removed from the license section.
6. Constrained nanobind to `nanobind>=2.0,<2.6` because of a regression: [[BUG]: Regression when using scikit build tools and nanobind wjakob/nanobind#982](https://github.com/wjakob/nanobind/issues/982).

And finally, I added cudastreamhandler, which is used to split the XLA-provided stream for the vmap lowering (this header is my own work).

There is an issue with building pyssht; I'm not sure this is my fault. I will check the failing workflows when I get the chance, but in the meantime a review is appreciated.
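For readers unfamiliar with the new interface, below is a minimal sketch of what an FFI-based binding looks like on the Python side, following the linked JAX FFI docs. The extension module name `_s2fft_ext`, the target name, and the output shape are illustrative assumptions, not the actual code in this PR:

```python
import jax
import jax.numpy as jnp
import numpy as np

# The compiled extension would expose its XLA FFI handlers as PyCapsules,
# registered once at import time (module name is hypothetical):
# import _s2fft_ext
# for name, capsule in _s2fft_ext.registrations().items():
#     jax.ffi.register_ffi_target(name, capsule, platform="CUDA")

def healpix_fft_cuda_ffi(f, L, nside):
    # Placeholder output shape/dtype; the real primitive defines its own.
    out_type = jax.ShapeDtypeStruct(f.shape, jnp.complex128)
    call = jax.ffi.ffi_call("healpix_fft_cuda", out_type)
    # Static parameters are passed to the kernel as FFI attributes.
    return call(f, L=np.int64(L), nside=np.int64(nside))
```

Recent JAX releases also accept a `vmap_method` argument to `jax.ffi.ffi_call`, though a custom batching rule gives more control over how the batch is dispatched.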
Hi @ASKabalan, sorry for the delay in getting back to you.
This all sounds great - thanks for picking up #237 in particular and for the updates to use the newer FFI interface.
With regards to the failing workflows - this was probably due to #292, which was fixed in #293. If you merge the latest main into here, that should hopefully resolve the upstream dependency build problems that were causing the test workflows to fail.
I've added some initial review comments below. I will have a closer look next week and try testing this out, but I don't have access to a GPU machine at the moment.
tests/test_healpix_ffts.py (Outdated)

```python
flm_hp = samples.flm_2d_to_hp(flm, L)
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
```
I think we could use `s2fft.inverse(flm, L=L, reality=False, method="jax", sampling="healpix")` here instead of going via healpy. The rationale is that I would have a slight preference for minimising the number of additional tests that depend on healpy, as we are no longer requiring it as a direct dependency for the package, and in the long run it might be possible to also remove it as a test dependency.
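Concretely, the change would look something like the sketch below (note that with `sampling="healpix"` an `nside` argument is presumably also needed; that is an assumption, not part of the quoted suggestion):

```python
# Before: generate the test map via healpy.
# flm_hp = samples.flm_2d_to_hp(flm, L)
# f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)

# After: generate it with s2fft's own inverse transform, dropping the
# healpy dependency from this test.
f = s2fft.inverse(flm, L=L, nside=nside, reality=False, method="jax", sampling="healpix")
```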
Co-authored-by: Matt Graham <[email protected]>
I've tried building, installing and running this on a system with CUDA 12.6 and an NVIDIA A100. When running the HEALPix FFT tests, the tests consistently hang when trying to run the first forward (FFT) test. Running just the IFFT tests, the tests for both sets of test parameters pass. Trying to dig into this a bit, running the following locally

```python
import healpy
import jax
import s2fft
import numpy

jax.config.update("jax_enable_x64", True)

seed = 20250416
nside = 4
L = 2 * nside
reality = False
rng = numpy.random.default_rng(seed)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
flm_hp = s2fft.sampling.s2_samples.flm_2d_to_hp(flm, L)
f = healpy.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
flm_cuda = s2fft.utils.healpix_ffts.healpix_fft_cuda(
    f=f, L=L, nside=nside, reality=reality
).block_until_ready()
```

raises an error, so it looks like there is some memory addressing issue somewhere in the CUDA implementation.
Thank you, I was able to reproduce with 12.4.1 but not locally with 12.4. I will take a look.
Codecov Report

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #290      +/-   ##
==========================================
- Coverage   96.55%   96.07%   -0.48%
==========================================
  Files          32       32
  Lines        3450     3469      +19
==========================================
+ Hits         3331     3333       +2
- Misses        119      136      +17
```
@matt-graham Hey, I would suggest dropping Python 3.8 from the test suite, since JAX no longer supports it anyway.
Hi @ASKabalan. Do you mean
Yes, agreed we should drop Python 3.8 from the test matrix - we have an open pull request #305 to update to only supporting Python 3.11+, but this is partially blocked by #212, as the tests currently exit with fatal errors when running on macOS / Python 3.9+ due to an incompatibility between the OpenMP runtimes bundled with the macOS wheels of our dependencies.
Add comprehensive documentation and fix dependency issues for CUDA FFT integration.

This commit introduces extensive docstrings and inline comments across the C++ and Python codebase, particularly for the CUDA FFT implementation. It also addresses a dependency issue to ensure proper installation and functionality. Key changes include:

- No more cudaMalloc: all memory is allocated in Python by XLA.
- Added detailed docstrings to C++ header files.
- Enhanced inline comments in C++ source files to explain complex logic and algorithms.
- Relaxed the JAX version dependency, resolving installation issues.
- Refined docstrings and comments in Python files for clarity and consistency.
- Cleaned up debug print statements.
@matt-graham I fixed the issue with CUDA 12.4 and above.
Hi @ASKabalan. Thanks for the updates, this is looking great.
I have added some initial comments from a quick high-level check.
I also tried running some of our benchmarks via the scripts in the `benchmarks` directory to check everything is running as expected. Results are below (with `method="jax"` corresponding to the current native JAX implementation and `method="jax_cuda"` using the CUDA primitive).
forward

```
(method: jax, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times): 0.0016s, max(run times): 0.0016s, compile time: 3.9s, peak memory: 1.1e+04B, max(abs(error)): 0.025, floating point ops: 1.2e+06, mem access: 4.3e+06B
(method: jax, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times): 0.0035s, max(run times): 0.0036s, compile time: 6.9s, peak memory: 1.1e+04B, max(abs(error)): 0.033, floating point ops: 5.1e+06, mem access: 1.9e+07B
(method: jax, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times): 0.010s, max(run times): 0.010s, compile time: 13.s, peak memory: 1.1e+04B, max(abs(error)): 0.032, floating point ops: 2.2e+07, mem access: 8.2e+07B
(method: jax, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times): 0.050s, max(run times): 0.050s, compile time: 26.s, peak memory: 1.0e+04B, max(abs(error)): 0.0095, floating point ops: 9.2e+07, mem access: 3.2e+08B
(method: jax_cuda, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times): 0.0015s, max(run times): 0.0015s, compile time: 0.73s, peak memory: 8.6e+03B, max(abs(error)): 0.036, floating point ops: 5.3e+05, mem access: 3.9e+06B
(method: jax_cuda, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times): 0.0032s, max(run times): 0.0033s, compile time: 0.87s, peak memory: 8.6e+03B, max(abs(error)): 0.67, floating point ops: 2.1e+06, mem access: 1.5e+07B
(method: jax_cuda, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times): 0.0096s, max(run times): 0.0096s, compile time: 0.92s, peak memory: 8.6e+03B, max(abs(error)): 0.61, floating point ops: 8.1e+06, mem access: 6.6e+07B
(method: jax_cuda, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times): 0.049s, max(run times): 0.049s, compile time: 1.3s, peak memory: 8.6e+03B, max(abs(error)): 0.71, floating point ops: 3.1e+07, mem access: 2.5e+08B
```

inverse

```
(method: jax, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times): 0.0021s, max(run times): 0.0021s, compile time: 5.3s, peak memory: 8.8e+03B, floating point ops: 1.2e+06, mem access: 4.5e+07B
(method: jax, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times): 0.0046s, max(run times): 0.0046s, compile time: 12.s, peak memory: 8.8e+03B, floating point ops: 5.3e+06, mem access: 3.5e+08B
(method: jax, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times): 0.011s, max(run times): 0.011s, compile time: 27.s, peak memory: 8.8e+03B, floating point ops: 2.2e+07, mem access: 2.7e+09B
(method: jax, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times): 0.056s, max(run times): 0.056s, compile time: 66.s, peak memory: 8.8e+03B, floating point ops: 9.4e+07, mem access: 2.2e+10B
(method: jax_cuda, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times): 0.0017s, max(run times): 0.0017s, compile time: 0.66s, peak memory: 8.6e+03B, floating point ops: 5.1e+05, mem access: 3.7e+06B
(method: jax_cuda, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times): 0.0037s, max(run times): 0.0038s, compile time: 0.67s, peak memory: 8.6e+03B, floating point ops: 2.0e+06, mem access: 1.5e+07B
(method: jax_cuda, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times): 0.0094s, max(run times): 0.0094s, compile time: 0.78s, peak memory: 8.6e+03B, floating point ops: 7.5e+06, mem access: 5.8e+07B
(method: jax_cuda, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times): 0.052s, max(run times): 0.052s, compile time: 0.88s, peak memory: 8.6e+03B, floating point ops: 3.0e+07, mem access: 2.3e+08B
```
In terms of the run and compilation times this all looks great - run times are the same or slightly improved for a corresponding bandlimit `L`, and compilation times are massively reduced (and close to constant in `L` over the range tested). However, there seems to be something slightly odd going on with the round-trip errors (indicated by the `max(abs(error))` entries), which are significantly larger for the CUDA primitive version than for the native JAX implementation. For the HEALPix sampling scheme we would expect a non-negligible round-trip error, particularly without iterative refinement, but the size of the errors here seems too large, and in particular I wouldn't expect a significant difference in the errors compared to the native JAX version. Further, when we use iterative refinement with 3 iterations, the errors get larger with the CUDA primitive version rather than smaller as with the native JAX version.
forward

```
(method: jax_cuda, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: 3):
    min(run times): 0.012s, max(run times): 0.012s, compile time: 2.4s, peak memory: 1.1e+04B, max(abs(error)): 49., floating point ops: 3.9e+06, mem access: 2.9e+07B
(method: jax_cuda, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: 3):
    min(run times): 0.027s, max(run times): 0.027s, compile time: 2.6s, peak memory: 1.1e+04B, max(abs(error)): 68., floating point ops: 1.5e+07, mem access: 1.2e+08B
(method: jax_cuda, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: 3):
    min(run times): 0.074s, max(run times): 0.074s, compile time: 2.7s, peak memory: 1.1e+04B, max(abs(error)): 1.1e+02, floating point ops: 5.5e+07, mem access: 4.9e+08B
(method: jax_cuda, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: 3):
    min(run times): 0.39s, max(run times): 0.39s, compile time: 3.4s, peak memory: 1.0e+04B, max(abs(error)): 1.7e+02, floating point ops: 2.1e+08, mem access: 1.9e+09B
```
I haven't yet figured out what is causing this. The errors seem to be larger at higher bandlimits, which might explain why the tests are not picking this up, but I need to investigate this in more detail.
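For context, a rough sketch of the round-trip error being discussed (not the benchmark script itself; the exact arguments are assumptions): `forward(inverse(flm))` should approximately recover `flm`, and iterative refinement should shrink the residual, not grow it.

```python
import numpy as np
import s2fft

L = 64
nside = L // 2  # L_to_nside_ratio: 2, as in the benchmarks above
rng = np.random.default_rng(0)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=True)

# Round trip: synthesise a HEALPix map, then transform back.
f = s2fft.inverse(flm, L=L, nside=nside, sampling="healpix", method="jax", reality=True)
flm_rt = s2fft.forward(f, L=L, nside=nside, sampling="healpix", method="jax", reality=True)

print(np.max(np.abs(flm_rt - flm)))  # the max(abs(error)) quantity reported above
```

The benchmarks select the CUDA path via `method="jax_cuda"` and refinement via `n_iter`; the sketch uses the native JAX path for simplicity.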
lib/include/kernel_helpers.h (Outdated)
With this file removed we can remove the comment in the README (lines 350 to 352 in d77e9cb):

> The file [`lib/include/kernel_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h) is adapted from [code](https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281f3f5dbff579e8f5b00/lib/kernel_helpers.h) in [a tutorial on extending JAX](https://github.com/dfm/extending-jax) by [Dan Foreman-Mackey](https://github.com/dfm) and licensed under a [MIT license](https://github.com/dfm/extending-jax/blob/371dca93c6405368fa8e71690afd3968d75f4bac/LICENSE).
With this file removed we can remove the comment in the README (lines 354 to 357 in d77e9cb):

> The file [`lib/include/kernel_nanobind_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_nanobind_helpers.h) is adapted from [code](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/jaxlib/kernel_nanobind_helpers.h) by the [JAX](https://github.com/jax-ml/jax) authors and licensed under an [Apache-2.0 license](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/LICENSE).
Co-authored-by: Matt Graham <[email protected]>
I have been trying to diagnose what is causing the numerical issues here. I have not isolated the precise cause yet, but have somewhat narrowed things down.
The last point in particular makes me think this is something to do with unsafe memory access. I have been looking through the code but haven't spotted anything obvious so far. Unfortunately I also don't seem to be able to get the version before the changes in this PR to run - while I can build the CUDA extension module, when running the tests on a GPU the process hangs indefinitely on reaching any calls to the custom primitives.
After a bit more investigation, I suspect the non-determinism and inconsistency issues are arising due to a race condition in the application of FFT shifting in the forward transform kernel. Specifically, in s2fft/lib/src/s2fft_kernels.cu, lines 339 to 344 (at c27dc7e), each thread assumes the (normalized) complex value it stored earlier (line 322) has already been read by the thread handling the shifted destination, and so can safely write to that location (line 335) only if the two threads are synchronised. That holds within a block but not across blocks, so when the read and write fall in different blocks the in-place shift can race.

Supporting this hypothesis: if we increase the block size to 1024 (line 438), thus decreasing the number of blocks in the grid and hence the scope for cross-block race conditions, the issue goes away.

A simple but potentially memory-inefficient solution would be to write out the shifted values to a different array rather than updating in place.
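A conceptual sketch of that out-of-place fix, in plain NumPy rather than the actual CUDA kernel, purely to illustrate the idea: every shifted value lands in a fresh output buffer, so no element is ever read after another thread may have overwritten it.

```python
import numpy as np

def fftshift_ring_out_of_place(ring: np.ndarray) -> np.ndarray:
    """Out-of-place FFT shift of a single ring's spectrum.

    Writing into a separate output buffer removes the read-after-write
    hazard that an in-place shift has when reads and writes are spread
    over threads in different CUDA blocks.
    """
    n = ring.shape[0]
    out = np.empty_like(ring)
    for i in range(n):  # in the kernel, each iteration would be one thread
        out[i] = ring[(i + n // 2) % n]
    return out

# Matches numpy.fft.fftshift for even-length inputs.
x = np.arange(8.0)
assert np.allclose(fftshift_ring_out_of_place(x), np.fft.fftshift(x))
```

The cost is one extra array of the same size as the input, which is presumably what "potentially memory inefficient" refers to.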
Adding a few updates.

A batching rule seems to be very important for two things:

- being able to use jacrev / jacfwd;
- and because, while in most cases a HEALPix map fits on a single GPU, sometimes we want to batch the spherical transform (see the sketch below).

I will be doing that next.
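For illustration, the batched use case would look something like this, assuming the primitive is exposed as `s2fft.utils.healpix_ffts.healpix_fft_cuda` (as in the reproduction snippet above); the batch size and map dtype are made-up assumptions:

```python
import jax
import jax.numpy as jnp
from s2fft.utils.healpix_ffts import healpix_fft_cuda

jax.config.update("jax_enable_x64", True)

nside = 64
L = 2 * nside
# A stack of 16 HEALPix maps, each with 12 * nside**2 pixels.
f_batch = jnp.zeros((16, 12 * nside**2))

# vmap relies on the primitive's batching rule to map the transform
# over the leading batch axis.
batched_fft = jax.vmap(
    lambda f: healpix_fft_cuda(f=f, L=L, nside=nside, reality=False)
)
ftm_batch = batched_fft(f_batch)
```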