Updating Healpix CUDA primitive #290


Open · wants to merge 23 commits into main

Conversation

@ASKabalan (Collaborator) commented Mar 26, 2025

This PR adds a few updates:

  • Updating to the newest custom call API (API 4) using FFI
  • Implementing a gradient rule for the HEALPix CUDA FFT
  • Implementing a batching rule

A batching rule is important for two reasons: it enables jacrev/jacfwd, and, while in most cases a HEALPix map fits on a single GPU, we sometimes want to batch the spherical transform.

I will be doing that next
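To illustrate the second point, once a batching rule is registered the CUDA primitive can be mapped over a leading batch axis. The following is a rough sketch only: the batch size, map contents and call signature are assumptions based on the rest of this thread, and it needs a CUDA-capable GPU.

# Rough sketch: batch of HEALPix maps pushed through the CUDA primitive via vmap.
import jax
import jax.numpy as jnp
import s2fft

jax.config.update("jax_enable_x64", True)

nside = 4
L = 2 * nside
# A hypothetical batch of three HEALPix maps (12 * nside**2 pixels each).
f_batch = jnp.zeros((3, 12 * nside**2))

fft_single = lambda f: s2fft.utils.healpix_ffts.healpix_fft_cuda(
    f=f, L=L, nside=nside, reality=False
)

# The batching rule is what lets vmap (and hence jacfwd/jacrev of functions
# built on this primitive) trace through the custom CUDA call.
ftm_batch = jax.vmap(fft_single)(f_batch)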

@ASKabalan marked this pull request as draft March 26, 2025 16:25
@ASKabalan marked this pull request as ready for review March 28, 2025 16:08
@ASKabalan (Collaborator, Author)

Hello @matt-graham @jasonmcewen @CosmoMatt

Just a quick PR to wrap up a few things:

  1. Updated the binding API to the newest FFI (see the rough sketch at the end of this comment)
  2. Added a vmap implementation of the CUDA primitive
  3. Added a transpose rule which allows jacfwd and jacrev (and consequently grad as well)
  4. Added more tests: https://github.com/astro-informatics/s2fft/blob/ASKabalan/tests/test_healpix_ffts.py#L100
  5. Removed two files (kernel helpers) which are no longer needed with the FFI API (so maybe they should be removed from the license section)
  6. Constrained nanobind to ">=2.0,<2.6" because of a regression: [BUG] Regression when using scikit-build tools and nanobind wjakob/nanobind#982

Finally, I added a cudastreamhandler header which is used to split the XLA-provided stream for the vmap lowering (this header is my own work).

There is an issue with building pyssht; I'm not sure it is caused by my changes.

I will check the failing workflows when I get the chance, but in the meantime a review is appreciated.
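For context on point 1, the FFI-based custom call pattern looks roughly like the sketch below. The target name, extension module and attribute handling are hypothetical placeholders for illustration only, not the actual s2fft bindings.

# Hypothetical sketch of a JAX FFI-style binding; "_s2fft_ext" and
# "healpix_fft_cuda_ffi" are made-up names.
import jax
import jax.numpy as jnp

# The compiled extension module would expose its handler as a capsule which
# gets registered once at import time, e.g.:
# jax.ffi.register_ffi_target(
#     "healpix_fft_cuda_ffi", _s2fft_ext.healpix_fft(), platform="CUDA"
# )

def healpix_fft_ffi(f, L: int, nside: int):
    # Output shape and dtype here are assumptions for the sketch.
    out_type = jax.ShapeDtypeStruct(f.shape, jnp.complex128)
    call = jax.ffi.ffi_call("healpix_fft_cuda_ffi", out_type)
    # Static parameters are forwarded to the C++/CUDA handler as FFI attributes.
    return call(f, L=L, nside=nside)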

@matt-graham (Collaborator) left a comment


Hi @ASKabalan, sorry for the delay in getting back to you.

This all sounds great - thanks for picking up #237 in particular and for the updates to use the newer FFI interface.

With regards to the failing workflows - this was probably due to #292, which was fixed in #293. If you merge in the latest main here, that should hopefully resolve the upstream dependency build problems that were causing the test workflows to fail.

I've added some initial review comments below. I'll have a closer look next week and try testing this out, but I don't have access to a GPU machine at the moment.

Comment on lines 150 to 151
flm_hp = samples.flm_2d_to_hp(flm, L)
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)

I think we could use s2fft.inverse(flm, L=L, reality=False, method="jax", sampling="healpix") here instead of going via healpy? The rationale is that I have a slight preference for minimising the number of additional tests that depend on healpy, as we no longer require it as a direct dependency of the package and in the long run it might be possible to also remove it as a test dependency.
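For concreteness, a sketch of what the suggested change might look like, mirroring the signal generation used in the reproduction script later in this thread; passing nside here is an assumption for HEALPix sampling.

# Sketch of generating the test map directly with s2fft instead of via healpy.
import numpy
import s2fft

nside = 4
L = 2 * nside
rng = numpy.random.default_rng(1234)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=False)

# Suggested replacement for the healpy.sphtfunc.alm2map route; nside is
# presumably also required when sampling="healpix".
f = s2fft.inverse(
    flm, L=L, nside=nside, reality=False, method="jax", sampling="healpix"
)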

@matt-graham (Collaborator)

I've tried building, installing and running this on a system with CUDA 12.6 + an NVIDIA A100, and running the HEALPix FFT tests with

pytest tests/test_healpix_ffts.py

the tests consistently hang when trying to run the first test_healpix_fft_cuda instance.

Running just the IFFT tests with

pytest tests/test_healpix_ffts.py::test_healpix_ifft_cuda

the tests for both sets of test parameters pass.

Trying to dig into this a bit, running the following locally

import healpy
import jax
import s2fft
import numpy

jax.config.update("jax_enable_x64", True)

seed = 20250416
nside = 4
L = 2 * nside
reality = False

rng = numpy.random.default_rng(seed)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
flm_hp = s2fft.sampling.s2_samples.flm_2d_to_hp(flm, L)
f = healpy.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
flm_cuda = s2fft.utils.healpix_ffts.healpix_fft_cuda(f=f, L=L, nside=nside, reality=reality).block_until_ready()

raises an error

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CUDA error: : CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered

so it looks like there is some memory addressing issue somewhere in the healpix_fft_cuda implementation?

@ASKabalan (Collaborator, Author)

Thank you

I was able to reproduce this with CUDA 12.4.1 but not locally with CUDA 12.4.

I will take a look

codecov bot commented Jun 19, 2025

Codecov Report

Attention: Patch coverage is 75.00000% with 4 lines in your changes missing coverage. Please review.

Project coverage is 96.07%. Comparing base (0de6f11) to head (fb8d0df).

Files with missing lines Patch % Lines
s2fft/utils/healpix_ffts.py 66.66% 3 Missing ⚠️
s2fft/utils/jax_primitive.py 75.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #290      +/-   ##
==========================================
- Coverage   96.55%   96.07%   -0.48%     
==========================================
  Files          32       32              
  Lines        3450     3469      +19     
==========================================
+ Hits         3331     3333       +2     
- Misses        119      136      +17     


@ASKabalan (Collaborator, Author)

@matt-graham Hey, I am picking up where I left off.
It seems that there is an error when building with Python 3.8. It doesn't seem to be coming from my code, but from a compile error when compiling sht.

I would suggest dropping Python 3.8 from the test suite since JAX no longer supports it anyway.

@matt-graham (Collaborator)

> @matt-graham Hey, I am picking up where I left off. It seems that there is an error when building with Python 3.8. It doesn't seem to be coming from my code, but from a compile error when compiling sht.

Hi @ASKabalan. Do you mean so3 rather than (py)ssht? From a quick look at the logs of the failing Actions workflow job on Python 3.8 / ubuntu-latest, it appears to be an error with building so3 (ERROR: Failed building wheel for so3). If so, this is likely the same issue as described in #308. I've opened a PR to try to fix this upstream in so3 (astro-informatics/so3#31).

> I would suggest dropping Python 3.8 from the test suite since JAX no longer supports it anyway

Yes, agreed we should drop Python 3.8 from the test matrix - we have an open pull request #305 to update to supporting only Python 3.11+, but this is partially blocked by #212, as the tests currently exit with fatal errors when running on macOS / Python 3.9+ due to an incompatibility between the OpenMP runtimes that the macOS wheels for healpy and PyTorch are built against (healpy/healpy#1012).

Add comprehensive documentation and fix dependency issues for CUDA FFT integration.

This commit introduces extensive docstrings and inline comments across the C++ and Python codebase, particularly for the CUDA FFT implementation. It also addresses a dependency issue to ensure proper installation and functionality.

Key changes include:
- No more CUDA malloc: all memory is allocated in Python by XLA.
- Added detailed docstrings to C++ header files.
- Enhanced inline comments in C++ source files to explain complex logic and algorithms.
- Relaxed the JAX version dependency, resolving installation issues.
- Refined docstrings and comments in Python files for clarity and consistency.
- Cleaned up debug print statements.
@ASKabalan marked this pull request as ready for review July 2, 2025 16:57
@ASKabalan (Collaborator, Author)

@matt-graham I fixed the issue with CUDA 12.4 and above.
I cleaned up the code and added docstrings everywhere.
There is still a small issue with macOS (since the JAX version is not the same), but I think this is mature enough to be merged.

@matt-graham (Collaborator) left a comment

Hi @ASKabalan. Thanks for the updates, this is looking great.

I have added some initial comments from a quick high level check.

I also tried running some of our benchmarks via the scripts in the benchmarks directory to check everything is running as expected. Results are below (with method="jax" corresponding to the current native JAX implementation and method="jax_cuda" using the CUDA primitive).

forward
(method: jax, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times):  0.0016s, max(run times):  0.0016s, compile time:     3.9s, peak memory: 1.1e+04B, max(abs(error)):   0.025, floating point ops: 1.2e+06, mem access: 4.3e+06B
(method: jax, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times):  0.0035s, max(run times):  0.0036s, compile time:     6.9s, peak memory: 1.1e+04B, max(abs(error)):   0.033, floating point ops: 5.1e+06, mem access: 1.9e+07B
(method: jax, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times):   0.010s, max(run times):   0.010s, compile time:     13.s, peak memory: 1.1e+04B, max(abs(error)):   0.032, floating point ops: 2.2e+07, mem access: 8.2e+07B
(method: jax, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times):   0.050s, max(run times):   0.050s, compile time:     26.s, peak memory: 1.0e+04B, max(abs(error)):  0.0095, floating point ops: 9.2e+07, mem access: 3.2e+08B
(method: jax_cuda, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times):  0.0015s, max(run times):  0.0015s, compile time:    0.73s, peak memory: 8.6e+03B, max(abs(error)):   0.036, floating point ops: 5.3e+05, mem access: 3.9e+06B
(method: jax_cuda, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times):  0.0032s, max(run times):  0.0033s, compile time:    0.87s, peak memory: 8.6e+03B, max(abs(error)):    0.67, floating point ops: 2.1e+06, mem access: 1.5e+07B
(method: jax_cuda, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times):  0.0096s, max(run times):  0.0096s, compile time:    0.92s, peak memory: 8.6e+03B, max(abs(error)):    0.61, floating point ops: 8.1e+06, mem access: 6.6e+07B
(method: jax_cuda, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: None):
    min(run times):   0.049s, max(run times):   0.049s, compile time:     1.3s, peak memory: 8.6e+03B, max(abs(error)):    0.71, floating point ops: 3.1e+07, mem access: 2.5e+08B
inverse
(method: jax, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times):  0.0021s, max(run times):  0.0021s, compile time:     5.3s, peak memory: 8.8e+03B, floating point ops: 1.2e+06, mem access: 4.5e+07B
(method: jax, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times):  0.0046s, max(run times):  0.0046s, compile time:     12.s, peak memory: 8.8e+03B, floating point ops: 5.3e+06, mem access: 3.5e+08B
(method: jax, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times):   0.011s, max(run times):   0.011s, compile time:     27.s, peak memory: 8.8e+03B, floating point ops: 2.2e+07, mem access: 2.7e+09B
(method: jax, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times):   0.056s, max(run times):   0.056s, compile time:     66.s, peak memory: 8.8e+03B, floating point ops: 9.4e+07, mem access: 2.2e+10B
(method: jax_cuda, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times):  0.0017s, max(run times):  0.0017s, compile time:    0.66s, peak memory: 8.6e+03B, floating point ops: 5.1e+05, mem access: 3.7e+06B
(method: jax_cuda, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times):  0.0037s, max(run times):  0.0038s, compile time:    0.67s, peak memory: 8.6e+03B, floating point ops: 2.0e+06, mem access: 1.5e+07B
(method: jax_cuda, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times):  0.0094s, max(run times):  0.0094s, compile time:    0.78s, peak memory: 8.6e+03B, floating point ops: 7.5e+06, mem access: 5.8e+07B
(method: jax_cuda, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False):
    min(run times):   0.052s, max(run times):   0.052s, compile time:    0.88s, peak memory: 8.6e+03B, floating point ops: 3.0e+07, mem access: 2.3e+08B

In terms of the run and compilation times this all looks great - run times are the same or slightly improved for a corresponding bandlimit L and compilation times are massively reduced (and close to constant in L over the range tested). However, there seems to be something slightly odd going on with the round-trip errors (indicated by the max(abs(error)) entries), which are significantly larger for the CUDA primitive version than for the native JAX implementation. For the HEALPix sampling scheme we would expect a non-negligible round-trip error, particularly without iterative refinement, but the size of the errors here seems too large, and in particular I wouldn't expect a significant difference in the errors compared to the native JAX version. Further, when we use iterative refinement with 3 iterations the errors seem to get larger with the CUDA primitive version rather than smaller, as they do for the native JAX version.

forward
(method: jax_cuda, L: 64, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: 3):
    min(run times):   0.012s, max(run times):   0.012s, compile time:     2.4s, peak memory: 1.1e+04B, max(abs(error)):     49., floating point ops: 3.9e+06, mem access: 2.9e+07B
(method: jax_cuda, L: 128, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: 3):
    min(run times):   0.027s, max(run times):   0.027s, compile time:     2.6s, peak memory: 1.1e+04B, max(abs(error)):     68., floating point ops: 1.5e+07, mem access: 1.2e+08B
(method: jax_cuda, L: 256, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: 3):
    min(run times):   0.074s, max(run times):   0.074s, compile time:     2.7s, peak memory: 1.1e+04B, max(abs(error)): 1.1e+02, floating point ops: 5.5e+07, mem access: 4.9e+08B
(method: jax_cuda, L: 512, L_lower: 0, sampling: healpix, spin: 0, L_to_nside_ratio: 2, reality: True, spmd: False, n_iter: 3):
    min(run times):    0.39s, max(run times):    0.39s, compile time:     3.4s, peak memory: 1.0e+04B, max(abs(error)): 1.7e+02, floating point ops: 2.1e+08, mem access: 1.9e+09B

I haven't yet figured out what is causing this. The errors seem to be larger at higher bandlimits which might explain why the tests are not picking this up, but need to investigate this in more detail.
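For reference, the kind of round-trip check behind the max(abs(error)) numbers is roughly the following. This is a sketch under assumptions about the benchmark internals, not the actual benchmark script; swapping method for "jax_cuda" would exercise the CUDA primitive.

# Hedged sketch of a HEALPix round-trip error check; the benchmark scripts
# may differ in details such as iterative refinement.
import numpy
import s2fft

nside = 32
L = 2 * nside
rng = numpy.random.default_rng(0)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=True)

f = s2fft.inverse(flm, L=L, nside=nside, sampling="healpix", method="jax", reality=True)
flm_rt = s2fft.forward(f, L=L, nside=nside, sampling="healpix", method="jax", reality=True)

print(numpy.max(numpy.abs(flm_rt - flm)))  # corresponds to max(abs(error))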

Collaborator left a comment

With this file removed we can remove the comment in the README:

s2fft/README.md

Lines 350 to 352 in d77e9cb

The file [`lib/include/kernel_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h) is adapted from
[code](https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281f3f5dbff579e8f5b00/lib/kernel_helpers.h) in [a tutorial on extending JAX](https://github.com/dfm/extending-jax) by
[Dan Foreman-Mackey](https://github.com/dfm) and licensed under a [MIT license](https://github.com/dfm/extending-jax/blob/371dca93c6405368fa8e71690afd3968d75f4bac/LICENSE).

Collaborator left a comment

With this file removed we can remove the comment in the README:

s2fft/README.md

Lines 354 to 357 in d77e9cb

The file [`lib/include/kernel_nanobind_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_nanobind_helpers.h)
is adapted from [code](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/jaxlib/kernel_nanobind_helpers.h)
by the [JAX](https://github.com/jax-ml/jax) authors
and licensed under a [Apache-2.0 license](https://github.com/jax-ml/jax/blob/3d389a7fb440c412d95a1f70ffb91d58408247d0/LICENSE).

ASKabalan and others added 2 commits July 10, 2025 07:58
Co-authored-by: Matt Graham <[email protected]>
Co-authored-by: Matt Graham <[email protected]>
@matt-graham (Collaborator) commented Jul 25, 2025

I have been trying to diagnose what is causing the numerical issues here. I have not isolated the precise cause yet, but have somewhat narrowed things down.

  • This only seems to affect the forward transform s2fft.utils.healpix_ffts.healpix_fft_cuda and not the backward transform s2fft.utils.healpix_ffts.healpix_ifft_cuda. Changing nsides_to_test in tests/test_healpix_fft.py to nside_to_test = list(range(2, 16)) + [16, 32, 64] and running pytest tests/test_healpix_ffts.py::test_healpix_ifft_cuda, the tests all pass consistently over multiple tries.
  • For s2fft.utils.healpix_ffts.healpix_fft_cuda, there are localised differences in the output compared to s2fft.utils.healpix_ffts.healpix_fft_jax for specific nside / L values, with differences seeming to become more likely for larger nside. For example, for nside = 12 the indices at which the outputs differ are
    >>> np.nonzero((abs(ftm_jax - ftm_cuda) > 1e-7))
    (array([10, 10, 10, 10, 10, 10, 10, 10, 21, 21, 21, 21, 21, 21, 21, 21]), 
     array([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]))
    
  • There seems to be some non-determinism in whether there are differences or not when running pytest tests/test_healpix_ffts.py::test_healpix_fft_cuda for given nside values, despite all input data being the same (the tests use a fixed random seed to generate data). For example, for consecutive test runs I got
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-2] PASSED                                                                                                                                                       [  5%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-3] PASSED                                                                                                                                                       [ 11%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-4] PASSED                                                                                                                                                       [ 17%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-5] FAILED                                                                                                                                                       [ 23%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-6] PASSED                                                                                                                                                       [ 29%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-7] PASSED                                                                                                                                                       [ 35%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-8] PASSED                                                                                                                                                       [ 41%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-9] PASSED                                                                                                                                                       [ 47%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-10] PASSED                                                                                                                                                      [ 52%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-11] PASSED                                                                                                                                                      [ 58%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-12] FAILED                                                                                                                                                      [ 64%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-13] PASSED                                                                                                                                                      [ 70%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-14] FAILED                                                                                                                                                      [ 76%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-15] FAILED                                                                                                                                                      [ 82%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-16] FAILED                                                                                                                                                      [ 88%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-32] FAILED                                                                                                                                                      [ 94%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-64] FAILED                                                                                                                                                      [100%]
    
    and then
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-2] PASSED                                                                                                                                                       [  5%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-3] PASSED                                                                                                                                                       [ 11%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-4] PASSED                                                                                                                                                       [ 17%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-5] PASSED                                                                                                                                                       [ 23%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-6] PASSED                                                                                                                                                       [ 29%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-7] PASSED                                                                                                                                                       [ 35%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-8] PASSED                                                                                                                                                       [ 41%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-9] PASSED                                                                                                                                                       [ 47%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-10] PASSED                                                                                                                                                      [ 52%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-11] PASSED                                                                                                                                                      [ 58%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-12] FAILED                                                                                                                                                      [ 64%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-13] FAILED                                                                                                                                                      [ 70%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-14] PASSED                                                                                                                                                      [ 76%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-15] FAILED                                                                                                                                                      [ 82%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-16] PASSED                                                                                                                                                      [ 88%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-32] FAILED                                                                                                                                                      [ 94%]
    tests/test_healpix_ffts.py::test_healpix_fft_cuda[8966433580120847635-64] FAILED                                                                                                                                                      [100%]
    
    with, for example, nside=5 failing in the first run but not the second, and nside=13 failing in the second run but not the first.

The last point in particular makes me think this is something to do with unsafe memory access. I have been looking through the code but haven't spotted anything obvious so far. Unfortunately I also don't seem to be able to get the version before changes in this PR to run - while I can build the CUDA extension module, when running the tests on a GPU the process hangs indefinitely when reaching any calls to the custom primitives.
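A sketch of how the JAX and CUDA outputs compared above can be placed side by side is given below; the healpix_fft_jax call signature is assumed from the s2fft utils module and the rest mirrors the earlier reproduction script, and it needs a CUDA-capable GPU.

# Hedged sketch combining the earlier reproduction script with the
# np.nonzero check quoted above.
import healpy
import jax
import numpy as np
import s2fft

jax.config.update("jax_enable_x64", True)

nside = 12
L = 2 * nside
rng = np.random.default_rng(20250416)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=False)
flm_hp = s2fft.sampling.s2_samples.flm_2d_to_hp(flm, L)
f = healpy.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)

ftm_jax = np.asarray(s2fft.utils.healpix_ffts.healpix_fft_jax(f, L, nside, False))
ftm_cuda = np.asarray(
    s2fft.utils.healpix_ffts.healpix_fft_cuda(f=f, L=L, nside=nside, reality=False)
)
print(np.nonzero(abs(ftm_jax - ftm_cuda) > 1e-7))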

@matt-graham (Collaborator) commented Jul 28, 2025

After a bit more investigation I suspect the non-determinism and inconsistency issues are arising due to a race condition in the application of FFT shifting in the forward transform, using the shift_normalize_kernel in s2fft_kernels.cu.

Specifically in

// Step 4a: Compute shifted position within ring
long long int shifted_o = (o + nphi / 2) % nphi;
shifted_o = shifted_o < 0 ? nphi + shifted_o : shifted_o;
long long int dest_p = r_start + shifted_o;
// printf(" -> CUDA: Applying shift: p=%lld, dest_p=%lld, shifted_o=%lld\n", p, dest_p, shifted_o);
data[dest_p] = element;

the (normalized) complex value stored in element is written to index dest_p of the data array. Within a block this is safe: each thread reads data[p] into element earlier in

complex element = data[p];

so a thread can safely write to dest_p indices corresponding to other threads within the same block, thanks to the thread synchronisation operation between the read into element and the write to data[dest_p]:

__syncthreads(); // Ensure all threads have completed normalization

If the dest_p index instead maps to a pixel index p handled by a different block, then the write to data[dest_p] may occur before the threads in that block have read data[dest_p] into element, meaning the final values in data will depend on the block execution order.

Supporting this hypothesis: if we increase the block size to 1024 in launch_shift_normalize_kernel in

int block_size = 256;

thus decreasing the number of blocks in the grid and so the scope for cross-block race conditions, the test_healpix_ffts.py::test_healpix_fft_cuda test passes consistently for all nside in list(range(2, 16)) + [16, 32, 64], but still fails sometimes if we set nside large enough - for example I got failures for nside = 256 with high likelihood. Conversely, changing the block size to be smaller, for example block_size = 32, we start getting failures even for the smallest nside values.

A simple but potentially memory-inefficient solution would be to write the shifted values out to a different array rather than updating in place.
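To make that concrete, here is a NumPy-level sketch of the out-of-place variant of the shift; it is illustrative only, and the actual fix would live in the CUDA kernel.

# Illustrative NumPy analogue of the shift in shift_normalize_kernel: writing
# into a separate output buffer removes any dependence on the order in which
# ring elements are processed, at the cost of one extra array.
import numpy as np

def shift_ring_out_of_place(ring: np.ndarray) -> np.ndarray:
    nphi = len(ring)
    out = np.empty_like(ring)
    for o in range(nphi):
        shifted_o = (o + nphi // 2) % nphi  # same index arithmetic as the kernel
        out[shifted_o] = ring[o]
    return out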

Successfully merging this pull request may close these issues.

Check autodiff and batching support for healpix_fft_cuda primitive and add if needed