Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@ repos:
hooks:
- id: ruff
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
hooks:
- id: clang-format
files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$'
exclude: '^third_party/|/pybind11/'
name: clang-format
95 changes: 60 additions & 35 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ set(CMAKE_CUDA_STANDARD 17)

# Set default build type to Release
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
set(CMAKE_BUILD_TYPE
Release
CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release"
"MinSizeRel" "RelWithDebInfo")
endif()

# Check for CUDA
Expand All @@ -23,35 +26,48 @@ if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA compiler found: ${CMAKE_CUDA_COMPILER}")

if(NOT SKBUILD)
message(FATAL_ERROR "Building standalone project directly without pip install is not supported"
"Please use pip install to build the project")
message(
FATAL_ERROR
"Building standalone project directly without pip install is not supported"
"Please use pip install to build the project")
else()
find_package(CUDAToolkit REQUIRED)

find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
# Add the executable
find_package(
Python 3.8 REQUIRED
COMPONENTS Interpreter Development.Module
OPTIONAL_COMPONENTS Development.SABIModule)
execute_process(
COMMAND "${Python_EXECUTABLE}" "-c"
"from jax import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE XLA_DIR)
message(STATUS "XLA include directory: ${XLA_DIR}")

# Detect the installed nanobind package and import it into CMake
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)

nanobind_add_module(_s2fft STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft.cu
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_callbacks.cu
${CMAKE_CURRENT_LIST_DIR}/lib/src/plan_cache.cc
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_kernels.cu
)
find_package(nanobind CONFIG REQUIRED)

nanobind_add_module(
_s2fft
STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft.cu
${CMAKE_CURRENT_LIST_DIR}/lib/src/plan_cache.cc
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_kernels.cu)

target_link_libraries(_s2fft PRIVATE CUDA::cudart_static CUDA::cufft_static CUDA::culibos)
target_include_directories(_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include)
set_target_properties(_s2fft PROPERTIES
LINKER_LANGUAGE CUDA
CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_CUDA_ARCHITECTURES "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
target_include_directories(
_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include ${XLA_DIR} ${CUDAToolkit_INCLUDE_DIRS})
set_target_properties(_s2fft PROPERTIES LINKER_LANGUAGE CUDA
CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -rdc=true")
set(CMAKE_CUDA_ARCHITECTURES
"70;80;89"
CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
message(STATUS "CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
set_target_properties(_s2fft PROPERTIES CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}")
set_target_properties(_s2fft PROPERTIES CUDA_ARCHITECTURES
"${CMAKE_CUDA_ARCHITECTURES}")

install(TARGETS _s2fft LIBRARY DESTINATION s2fft_lib)
endif()
Expand All @@ -60,26 +76,35 @@ else()
if(SKBUILD)
message(WARNING "CUDA compiler not found, building without CUDA support")

find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)

# Add the executable
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
COMMAND "${Python_EXECUTABLE}" "-c"
"from jax import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE XLA_DIR)
message(STATUS "XLA include directory: ${XLA_DIR}")

nanobind_add_module(_s2fft STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc
)
# Detect the installed nanobind package and import it into CMake
find_package(nanobind CONFIG REQUIRED)

nanobind_add_module(_s2fft STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc)

target_compile_definitions(_s2fft PRIVATE NO_CUDA_COMPILER)
target_include_directories(_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include)
target_include_directories(
_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include ${XLA_DIR})

install(TARGETS _s2fft LIBRARY DESTINATION s2fft_lib)

else()
message(FATAL_ERROR "Building standalone project directly without pip install is not supported"
"Please use pip install to build the project")
message(
FATAL_ERROR
"Building standalone project directly without pip install is not supported"
"Please use pip install to build the project")
endif()
endif()


165 changes: 165 additions & 0 deletions lib/include/cudastreamhandler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@

/**
* @file cudastreamhandler.hpp
* @brief Singleton class for managing CUDA streams and events.
*
* This header provides a singleton implementation that encapsulates the creation,
* management, and cleanup of CUDA streams and events. It offers functions to fork
* streams, add new streams, and synchronize (join) streams with a given dependency.
*
* Usage example:
* @code
* #include "cudastreamhandler.hpp"
*
* int main() {
* // Create a handler instance
* CudaStreamHandler handler;
*
* // Fork 4 streams dependent on a given stream 'stream_main'
* handler.Fork(stream_main, 4);
*
* // Do work on the forked streams...
*
* // Join the streams back to 'stream_main'
* handler.join(stream_main);
*
* return 0;
* }
* @endcode
*
* Author: Wassim KABALAN
*/

#ifndef CUDASTREAMHANDLER_HPP
#define CUDASTREAMHANDLER_HPP

#include <algorithm>
#include <atomic>
#include <cuda_runtime.h>
#include <stdexcept>
#include <thread>
#include <vector>

// Singleton class managing CUDA streams and events
class CudaStreamHandlerImpl {
public:
static CudaStreamHandlerImpl &instance() {
static CudaStreamHandlerImpl instance;
return instance;
}

void AddStreams(int numStreams) {
if (numStreams > m_streams.size()) {
int streamsToAdd = numStreams - m_streams.size();
m_streams.resize(numStreams);
std::generate(m_streams.end() - streamsToAdd, m_streams.end(), []() {
cudaStream_t stream;
cudaStreamCreate(&stream);
return stream;
});
}
}

void join(cudaStream_t finalStream) {
std::for_each(m_streams.begin(), m_streams.end(), [this, finalStream](cudaStream_t stream) {
cudaEvent_t event;
cudaEventCreate(&event);
cudaEventRecord(event, stream);
cudaStreamWaitEvent(finalStream, event, 0);
m_events.push_back(event);
});

if (!cleanup_thread.joinable()) {
stop_thread.store(false);
cleanup_thread = std::thread([this]() { this->AsyncEventCleanup(); });
}
}

// Fork function to add streams and set dependency on a given stream
void Fork(cudaStream_t dependentStream, int N) {
AddStreams(N); // Add N streams

// Set dependency on the provided stream
std::for_each(m_streams.end() - N, m_streams.end(), [this, dependentStream](cudaStream_t stream) {
cudaEvent_t event;
cudaEventCreate(&event);
cudaEventRecord(event, dependentStream);
cudaStreamWaitEvent(stream, event, 0); // Set the stream to wait on the event
m_events.push_back(event);
});
}

auto getIterator() { return StreamIterator(m_streams.begin(), m_streams.end()); }

~CudaStreamHandlerImpl() {
stop_thread.store(true);
if (cleanup_thread.joinable()) {
cleanup_thread.join();
}

std::for_each(m_streams.begin(), m_streams.end(), cudaStreamDestroy);
std::for_each(m_events.begin(), m_events.end(), cudaEventDestroy);
}

// Custom Iterator class to iterate over streams
class StreamIterator {
public:
StreamIterator(std::vector<cudaStream_t>::iterator begin, std::vector<cudaStream_t>::iterator end)
: current(begin), end(end) {}

cudaStream_t next() {
if (current == end) {
throw std::out_of_range("No more streams.");
}
return *current++;
}

bool hasNext() const { return current != end; }

private:
std::vector<cudaStream_t>::iterator current;
std::vector<cudaStream_t>::iterator end;
};

private:
CudaStreamHandlerImpl() : stop_thread(false) {}
CudaStreamHandlerImpl(const CudaStreamHandlerImpl &) = delete;
CudaStreamHandlerImpl &operator=(const CudaStreamHandlerImpl &) = delete;

void AsyncEventCleanup() {
while (!stop_thread.load()) {
std::for_each(m_events.begin(), m_events.end(), [this](cudaEvent_t &event) {
if (cudaEventQuery(event) == cudaSuccess) {
cudaEventDestroy(event);
event = nullptr;
}
});
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}

std::vector<cudaStream_t> m_streams;
std::vector<cudaEvent_t> m_events;
std::thread cleanup_thread;
std::atomic<bool> stop_thread;
};

// Public class for encapsulating the singleton operations
class CudaStreamHandler {
public:
CudaStreamHandler() = default;
~CudaStreamHandler() = default;

void AddStreams(int numStreams) { CudaStreamHandlerImpl::instance().AddStreams(numStreams); }

void join(cudaStream_t finalStream) { CudaStreamHandlerImpl::instance().join(finalStream); }

void Fork(cudaStream_t cudastream, int N) { CudaStreamHandlerImpl::instance().Fork(cudastream, N); }

// Get the custom iterator for CUDA streams
CudaStreamHandlerImpl::StreamIterator getIterator() {
return CudaStreamHandlerImpl::instance().getIterator();
}
};

#endif // CUDASTREAMHANDLER_HPP
76 changes: 0 additions & 76 deletions lib/include/kernel_helpers.h
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this file removed we can remove the comment in README

s2fft/README.md

Lines 350 to 352 in d77e9cb

The file [`lib/include/kernel_helpers.h`](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h) is adapted from
[code](https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281f3f5dbff579e8f5b00/lib/kernel_helpers.h) in [a tutorial on extending JAX](https://github.com/dfm/extending-jax) by
[Dan Foreman-Mackey](https://github.com/dfm) and licensed under a [MIT license](https://github.com/dfm/extending-jax/blob/371dca93c6405368fa8e71690afd3968d75f4bac/LICENSE).

This file was deleted.

Loading
Loading