Skip to content

Commit 9775bba

Browse files
committed
remvove callback params workspace
1 parent 00b169c commit 9775bba

File tree

2 files changed

+18
-41
lines changed

2 files changed

+18
-41
lines changed

lib/src/extensions.cc

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,12 @@ constexpr bool is_double_v = is_double<T>::value;
6565
* @param input Input buffer containing HEALPix pixel-space data.
6666
* @param output Output buffer to store the FTM result.
6767
* @param workspace Output buffer for temporary workspace memory.
68-
* @param callback_params Output buffer for callback parameters.
6968
* @param descriptor Descriptor containing transform parameters.
7069
* @return ffi::Error indicating success or failure.
7170
*/
7271
template <ffi::DataType T>
7372
ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
7473
ffi::Result<ffi::Buffer<T>> workspace,
75-
ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params,
7674
s2fftDescriptor descriptor) {
7775
// Step 1: Determine the complex type based on the XLA data type.
7876
using fft_complex_type = fft_complex_t<T>;
@@ -82,10 +80,9 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
8280
if (dim_in.size() == 2) {
8381
// Step 2a: Batched case.
8482
int batch_count = dim_in[0];
85-
// Step 2b: Compute offsets for input, output, and callback parameters for each batch.
83+
// Step 2b: Compute offsets for input and output for each batch.
8684
int64_t input_offset = descriptor.nside * descriptor.nside * 12;
8785
int64_t output_offset = (4 * descriptor.nside - 1) * (2 * descriptor.harmonic_band_limit);
88-
int64_t params_offset = 2 * (descriptor.nside - 1) + 1;
8986

9087
// Step 2c: Fork CUDA streams for parallel processing of batches.
9188
CudaStreamHandler handler;
@@ -99,16 +96,13 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
9996
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
10097
PlanCache::GetInstance().GetS2FFTExec(descriptor, executor);
10198

102-
// Step 2f: Calculate device pointers for the current batch's data, output, workspace, and
103-
// callback parameters.
99+
// Step 2f: Calculate device pointers for the current batch's data, output, and workspace.
104100
fft_complex_type* data_c =
105101
reinterpret_cast<fft_complex_type*>(input.typed_data() + i * input_offset);
106102
fft_complex_type* out_c =
107103
reinterpret_cast<fft_complex_type*>(output->typed_data() + i * output_offset);
108104
fft_complex_type* workspace_c =
109105
reinterpret_cast<fft_complex_type*>(workspace->typed_data() + i * executor->m_work_size);
110-
int64* callback_params_c =
111-
reinterpret_cast<int64*>(callback_params->typed_data() + i * params_offset);
112106

113107
// Step 2g: Launch the forward transform on this sub-stream.
114108
executor->Forward(descriptor, sub_stream, data_c, workspace_c);
@@ -121,11 +115,10 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
121115
return ffi::Error::Success();
122116
} else {
123117
// Step 2j: Non-batched case.
124-
// Step 2k: Get device pointers for data, output, workspace, and callback parameters.
118+
// Step 2k: Get device pointers for data, output, and workspace.
125119
fft_complex_type* data_c = reinterpret_cast<fft_complex_type*>(input.typed_data());
126120
fft_complex_type* out_c = reinterpret_cast<fft_complex_type*>(output->typed_data());
127121
fft_complex_type* workspace_c = reinterpret_cast<fft_complex_type*>(workspace->typed_data());
128-
int64* callback_params_c = reinterpret_cast<int64*>(callback_params->typed_data());
129122

130123
// Step 2l: Get or create an s2fftExec instance from the PlanCache.
131124
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
@@ -152,14 +145,12 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
152145
* @param input Input buffer containing FTM data.
153146
* @param output Output buffer to store HEALPix pixel-space data.
154147
* @param workspace Output buffer for temporary workspace memory.
155-
* @param callback_params Output buffer for callback parameters.
156148
* @param descriptor Descriptor containing transform parameters.
157149
* @return ffi::Error indicating success or failure.
158150
*/
159151
template <ffi::DataType T>
160152
ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
161153
ffi::Result<ffi::Buffer<T>> workspace,
162-
ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params,
163154
s2fftDescriptor descriptor) {
164155
// Step 1: Determine the complex type based on the XLA data type.
165156
using fft_complex_type = fft_complex_t<T>;
@@ -189,16 +180,13 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
189180
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
190181
PlanCache::GetInstance().GetS2FFTExec(descriptor, executor);
191182

192-
// Step 2f: Calculate device pointers for the current batch's data, output, workspace, and
193-
// callback parameters.
183+
// Step 2f: Calculate device pointers for the current batch's data, output, and workspace.
194184
fft_complex_type* data_c =
195185
reinterpret_cast<fft_complex_type*>(input.typed_data() + i * input_offset);
196186
fft_complex_type* out_c =
197187
reinterpret_cast<fft_complex_type*>(output->typed_data() + i * output_offset);
198188
fft_complex_type* workspace_c =
199189
reinterpret_cast<fft_complex_type*>(workspace->typed_data() + i * executor->m_work_size);
200-
int64* callback_params_c =
201-
reinterpret_cast<int64*>(callback_params->typed_data() + i * sizeof(int64) * 2);
202190

203191
// Step 2g: Launch spectral folding kernel.
204192
s2fftKernels::launch_spectral_folding(data_c, out_c, descriptor.nside,
@@ -215,11 +203,10 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
215203
// Assertions to ensure correct input/output dimensions for non-batched operations.
216204
assert(dim_in.size() == 2);
217205
assert(dim_out.size() == 1);
218-
// Step 2k: Get device pointers for data, output, workspace, and callback parameters.
206+
// Step 2k: Get device pointers for data, output, and workspace.
219207
fft_complex_type* data_c = reinterpret_cast<fft_complex_type*>(input.typed_data());
220208
fft_complex_type* out_c = reinterpret_cast<fft_complex_type*>(output->typed_data());
221209
fft_complex_type* workspace_c = reinterpret_cast<fft_complex_type*>(workspace->typed_data());
222-
int64* callback_params_c = reinterpret_cast<int64*>(callback_params->typed_data());
223210

224211
// Step 2l: Get or create an s2fftExec instance from the PlanCache.
225212
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
@@ -310,24 +297,22 @@ s2fftDescriptor build_descriptor(int64_t nside, int64_t harmonic_band_limit, boo
310297
* @param input Input buffer.
311298
* @param output Output buffer.
312299
* @param workspace Output buffer for temporary workspace memory.
313-
* @param callback_params Output buffer for callback parameters.
314300
* @return ffi::Error indicating success or failure.
315301
*/
316302
template <ffi::DataType T>
317303
ffi::Error healpix_fft_cuda(cudaStream_t stream, int64_t nside, int64_t harmonic_band_limit, bool reality,
318304
bool forward, bool normalize, bool adjoint, ffi::Buffer<T> input,
319-
ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace,
320-
ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params) {
305+
ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace) {
321306
// Step 1: Build the s2fftDescriptor based on the input parameters.
322307
size_t work_size = 0; // Variable to hold the workspace size
323308
s2fftDescriptor descriptor = build_descriptor<T>(nside, harmonic_band_limit, reality, forward, normalize,
324309
adjoint, true, work_size);
325310

326311
// Step 2: Dispatch to either forward or backward transform based on the 'forward' flag.
327312
if (forward) {
328-
return healpix_forward<T>(stream, input, output, workspace, callback_params, descriptor);
313+
return healpix_forward<T>(stream, input, output, workspace, descriptor);
329314
} else {
330-
return healpix_backward<T>(stream, input, output, workspace, callback_params, descriptor);
315+
return healpix_backward<T>(stream, input, output, workspace, descriptor);
331316
}
332317
}
333318

@@ -348,8 +333,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C64, healpix_fft_cuda<ffi::DataTy
348333
.Attr<bool>("adjoint")
349334
.Arg<ffi::Buffer<ffi::DataType::C64>>()
350335
.Ret<ffi::Buffer<ffi::DataType::C64>>()
351-
.Ret<ffi::Buffer<ffi::DataType::C64>>()
352-
.Ret<ffi::Buffer<ffi::DataType::S64>>());
336+
.Ret<ffi::Buffer<ffi::DataType::C64>>());
353337

354338
XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C128, healpix_fft_cuda<ffi::DataType::C128>,
355339
ffi::Ffi::Bind()
@@ -362,8 +346,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C128, healpix_fft_cuda<ffi::DataT
362346
.Attr<bool>("adjoint")
363347
.Arg<ffi::Buffer<ffi::DataType::C128>>()
364348
.Ret<ffi::Buffer<ffi::DataType::C128>>()
365-
.Ret<ffi::Buffer<ffi::DataType::C128>>()
366-
.Ret<ffi::Buffer<ffi::DataType::S64>>());
349+
.Ret<ffi::Buffer<ffi::DataType::C128>>());
367350

368351
/**
369352
* @brief Encapsulates an FFI handler into a nanobind capsule.

s2fft/utils/healpix_ffts.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -609,11 +609,7 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
609609
worksize //= 8 # 8 bytes per C64 element
610610
workspace_shape = (worksize,)
611611
workspace_dtype = np.complex64
612-
# Step 3: Calculate shape for callback parameters.
613-
nb_params = 2 * (nside - 1) + 1
614-
params_shape = (nb_params,)
615-
616-
# Step 4: Define output shapes based on FFT type.
612+
# Step 3: Define output shapes based on FFT type.
617613
healpix_size = (nside**2 * 12,)
618614
ftm_size = (4 * nside - 1, 2 * L)
619615
if fft_type == "forward":
@@ -628,17 +624,15 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
628624
else:
629625
raise ValueError(f"fft_type {fft_type} not recognised.")
630626

631-
# Step 5: Create ShapedArray objects for output, workspace, and callback parameters.
627+
# Step 4: Create ShapedArray objects for output and workspace.
632628
workspace_aval = ShapedArray(
633629
shape=batch_shape + workspace_shape, dtype=workspace_dtype
634630
)
635-
params_eval = ShapedArray(shape=batch_shape + params_shape, dtype=np.int64)
636631

637-
# Step 6: Return the ShapedArray objects.
632+
# Step 5: Return the ShapedArray objects.
638633
return (
639634
f.update(shape=out_shape, dtype=f.dtype),
640635
workspace_aval,
641-
params_eval,
642636
)
643637

644638

@@ -674,7 +668,7 @@ def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm, adj
674668
raise MissingCUDASupport()
675669

676670
# Step 2: Get the abstract evaluation results for the outputs.
677-
(aval_out, _, _) = ctx.avals_out
671+
(aval_out, _) = ctx.avals_out
678672

679673
# Step 3: Get lowering information (double precision, forward/backward, normalize).
680674
is_double, forward, normalize = _get_lowering_info(fft_type, norm, aval_out.dtype)
@@ -839,8 +833,8 @@ def healpix_fft_cuda(
839833
"""
840834
# Step 1: Promote input data to complex dtype if necessary.
841835
(f,) = promote_dtypes_complex(f)
842-
# Step 2: Bind the input to the CUDA primitive. It returns multiple outputs (out, workspace, callback_params).
843-
out, _, _ = _healpix_fft_cuda_primitive.bind(
836+
# Step 2: Bind the input to the CUDA primitive. It returns multiple outputs (out, workspace).
837+
out, _ = _healpix_fft_cuda_primitive.bind(
844838
f,
845839
L=L,
846840
nside=nside,
@@ -879,8 +873,8 @@ def healpix_ifft_cuda(
879873
"""
880874
# Step 1: Promote input data to complex dtype if necessary.
881875
(ftm,) = promote_dtypes_complex(ftm)
882-
# Step 2: Bind the input to the CUDA primitive. It returns multiple outputs (out, workspace, callback_params).
883-
out, _, _ = _healpix_fft_cuda_primitive.bind(
876+
# Step 2: Bind the input to the CUDA primitive. It returns multiple outputs (out, workspace).
877+
out, _ = _healpix_fft_cuda_primitive.bind(
884878
ftm,
885879
L=L,
886880
nside=nside,

0 commit comments

Comments
 (0)