@@ -65,14 +65,12 @@ constexpr bool is_double_v = is_double<T>::value;
65
65
* @param input Input buffer containing HEALPix pixel-space data.
66
66
* @param output Output buffer to store the FTM result.
67
67
* @param workspace Output buffer for temporary workspace memory.
68
- * @param callback_params Output buffer for callback parameters.
69
68
* @param descriptor Descriptor containing transform parameters.
70
69
* @return ffi::Error indicating success or failure.
71
70
*/
72
71
template <ffi::DataType T>
73
72
ffi::Error healpix_forward (cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
74
73
ffi::Result<ffi::Buffer<T>> workspace,
75
- ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params,
76
74
s2fftDescriptor descriptor) {
77
75
// Step 1: Determine the complex type based on the XLA data type.
78
76
using fft_complex_type = fft_complex_t <T>;
@@ -82,10 +80,9 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
82
80
if (dim_in.size () == 2 ) {
83
81
// Step 2a: Batched case.
84
82
int batch_count = dim_in[0 ];
85
- // Step 2b: Compute offsets for input, output, and callback parameters for each batch.
83
+ // Step 2b: Compute offsets for input and output for each batch.
86
84
int64_t input_offset = descriptor.nside * descriptor.nside * 12 ;
87
85
int64_t output_offset = (4 * descriptor.nside - 1 ) * (2 * descriptor.harmonic_band_limit );
88
- int64_t params_offset = 2 * (descriptor.nside - 1 ) + 1 ;
89
86
90
87
// Step 2c: Fork CUDA streams for parallel processing of batches.
91
88
CudaStreamHandler handler;
@@ -99,16 +96,13 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
99
96
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
100
97
PlanCache::GetInstance ().GetS2FFTExec (descriptor, executor);
101
98
102
- // Step 2f: Calculate device pointers for the current batch's data, output, workspace, and
103
- // callback parameters.
99
+ // Step 2f: Calculate device pointers for the current batch's data, output, and workspace.
104
100
fft_complex_type* data_c =
105
101
reinterpret_cast <fft_complex_type*>(input.typed_data () + i * input_offset);
106
102
fft_complex_type* out_c =
107
103
reinterpret_cast <fft_complex_type*>(output->typed_data () + i * output_offset);
108
104
fft_complex_type* workspace_c =
109
105
reinterpret_cast <fft_complex_type*>(workspace->typed_data () + i * executor->m_work_size );
110
- int64* callback_params_c =
111
- reinterpret_cast <int64*>(callback_params->typed_data () + i * params_offset);
112
106
113
107
// Step 2g: Launch the forward transform on this sub-stream.
114
108
executor->Forward (descriptor, sub_stream, data_c, workspace_c);
@@ -121,11 +115,10 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
121
115
return ffi::Error::Success ();
122
116
} else {
123
117
// Step 2j: Non-batched case.
124
- // Step 2k: Get device pointers for data, output, workspace, and callback parameters .
118
+ // Step 2k: Get device pointers for data, output, and workspace .
125
119
fft_complex_type* data_c = reinterpret_cast <fft_complex_type*>(input.typed_data ());
126
120
fft_complex_type* out_c = reinterpret_cast <fft_complex_type*>(output->typed_data ());
127
121
fft_complex_type* workspace_c = reinterpret_cast <fft_complex_type*>(workspace->typed_data ());
128
- int64* callback_params_c = reinterpret_cast <int64*>(callback_params->typed_data ());
129
122
130
123
// Step 2l: Get or create an s2fftExec instance from the PlanCache.
131
124
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
@@ -152,14 +145,12 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
152
145
* @param input Input buffer containing FTM data.
153
146
* @param output Output buffer to store HEALPix pixel-space data.
154
147
* @param workspace Output buffer for temporary workspace memory.
155
- * @param callback_params Output buffer for callback parameters.
156
148
* @param descriptor Descriptor containing transform parameters.
157
149
* @return ffi::Error indicating success or failure.
158
150
*/
159
151
template <ffi::DataType T>
160
152
ffi::Error healpix_backward (cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
161
153
ffi::Result<ffi::Buffer<T>> workspace,
162
- ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params,
163
154
s2fftDescriptor descriptor) {
164
155
// Step 1: Determine the complex type based on the XLA data type.
165
156
using fft_complex_type = fft_complex_t <T>;
@@ -189,16 +180,13 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
189
180
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
190
181
PlanCache::GetInstance ().GetS2FFTExec (descriptor, executor);
191
182
192
- // Step 2f: Calculate device pointers for the current batch's data, output, workspace, and
193
- // callback parameters.
183
+ // Step 2f: Calculate device pointers for the current batch's data, output, and workspace.
194
184
fft_complex_type* data_c =
195
185
reinterpret_cast <fft_complex_type*>(input.typed_data () + i * input_offset);
196
186
fft_complex_type* out_c =
197
187
reinterpret_cast <fft_complex_type*>(output->typed_data () + i * output_offset);
198
188
fft_complex_type* workspace_c =
199
189
reinterpret_cast <fft_complex_type*>(workspace->typed_data () + i * executor->m_work_size );
200
- int64* callback_params_c =
201
- reinterpret_cast <int64*>(callback_params->typed_data () + i * sizeof (int64) * 2 );
202
190
203
191
// Step 2g: Launch spectral folding kernel.
204
192
s2fftKernels::launch_spectral_folding (data_c, out_c, descriptor.nside ,
@@ -215,11 +203,10 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
215
203
// Assertions to ensure correct input/output dimensions for non-batched operations.
216
204
assert (dim_in.size () == 2 );
217
205
assert (dim_out.size () == 1 );
218
- // Step 2k: Get device pointers for data, output, workspace, and callback parameters .
206
+ // Step 2k: Get device pointers for data, output, and workspace .
219
207
fft_complex_type* data_c = reinterpret_cast <fft_complex_type*>(input.typed_data ());
220
208
fft_complex_type* out_c = reinterpret_cast <fft_complex_type*>(output->typed_data ());
221
209
fft_complex_type* workspace_c = reinterpret_cast <fft_complex_type*>(workspace->typed_data ());
222
- int64* callback_params_c = reinterpret_cast <int64*>(callback_params->typed_data ());
223
210
224
211
// Step 2l: Get or create an s2fftExec instance from the PlanCache.
225
212
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
@@ -310,24 +297,22 @@ s2fftDescriptor build_descriptor(int64_t nside, int64_t harmonic_band_limit, boo
310
297
* @param input Input buffer.
311
298
* @param output Output buffer.
312
299
* @param workspace Output buffer for temporary workspace memory.
313
- * @param callback_params Output buffer for callback parameters.
314
300
* @return ffi::Error indicating success or failure.
315
301
*/
316
302
template <ffi::DataType T>
317
303
ffi::Error healpix_fft_cuda (cudaStream_t stream, int64_t nside, int64_t harmonic_band_limit, bool reality,
318
304
bool forward, bool normalize, bool adjoint, ffi::Buffer<T> input,
319
- ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace,
320
- ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params) {
305
+ ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace) {
321
306
// Step 1: Build the s2fftDescriptor based on the input parameters.
322
307
size_t work_size = 0 ; // Variable to hold the workspace size
323
308
s2fftDescriptor descriptor = build_descriptor<T>(nside, harmonic_band_limit, reality, forward, normalize,
324
309
adjoint, true , work_size);
325
310
326
311
// Step 2: Dispatch to either forward or backward transform based on the 'forward' flag.
327
312
if (forward) {
328
- return healpix_forward<T>(stream, input, output, workspace, callback_params, descriptor);
313
+ return healpix_forward<T>(stream, input, output, workspace, descriptor);
329
314
} else {
330
- return healpix_backward<T>(stream, input, output, workspace, callback_params, descriptor);
315
+ return healpix_backward<T>(stream, input, output, workspace, descriptor);
331
316
}
332
317
}
333
318
@@ -348,8 +333,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C64, healpix_fft_cuda<ffi::DataTy
348
333
.Attr<bool>(" adjoint" )
349
334
.Arg<ffi::Buffer<ffi::DataType::C64>>()
350
335
.Ret<ffi::Buffer<ffi::DataType::C64>>()
351
- .Ret<ffi::Buffer<ffi::DataType::C64>>()
352
- .Ret<ffi::Buffer<ffi::DataType::S64>>());
336
+ .Ret<ffi::Buffer<ffi::DataType::C64>>());
353
337
354
338
XLA_FFI_DEFINE_HANDLER_SYMBOL (healpix_fft_cuda_C128, healpix_fft_cuda<ffi::DataType::C128>,
355
339
ffi::Ffi::Bind ()
@@ -362,8 +346,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C128, healpix_fft_cuda<ffi::DataT
362
346
.Attr<bool>(" adjoint" )
363
347
.Arg<ffi::Buffer<ffi::DataType::C128>>()
364
348
.Ret<ffi::Buffer<ffi::DataType::C128>>()
365
- .Ret<ffi::Buffer<ffi::DataType::C128>>()
366
- .Ret<ffi::Buffer<ffi::DataType::S64>>());
349
+ .Ret<ffi::Buffer<ffi::DataType::C128>>());
367
350
368
351
/* *
369
352
* @brief Encapsulates an FFI handler into a nanobind capsule.
0 commit comments