Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 153 additions & 34 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,146 @@ cmake_minimum_required(VERSION 3.22)
project(PPLXKernels
VERSION 0.0.1
DESCRIPTION "PPLX Kernels"
LANGUAGES CXX CUDA)
LANGUAGES CXX)


# Root directory of the ROCm SDK installation.
# Cache entry of type PATH, so it can be overridden with -DROCM_HOME=<dir>.
set(ROCM_HOME "/opt/rocm" CACHE PATH "ROCM SDK INSTALLATION HOME")

# The expansion is quoted so that if() always receives exactly one path
# argument, even when ROCM_HOME is empty or contains a semicolon.
if(NOT IS_DIRECTORY "${ROCM_HOME}")
  message(WARNING "ROCM_HOME ${ROCM_HOME} is not a directory")
endif()

# Export the selected SDK root to the environment of this CMake run.
# Original author's note: the ROCm CMake config files default the SDK root to
# ENV{ROCM_PATH} on Linux and to ENV{HIP_PATH} on Windows.
#
# CMAKE_SYSTEM_NAME is tested instead of the LINUX variable: LINUX is only
# defined by CMake >= 3.25, while this project declares 3.22 as its minimum,
# so `if(LINUX)` would be silently false on CMake 3.22 - 3.24.
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
  set(ENV{ROCM_PATH} "${ROCM_HOME}")
endif()

# Resolve HIP_CMAKE_PATH, the list of directories prepended to
# CMAKE_MODULE_PATH before HIP is searched for.
#
# Precedence:
#   1. a HIP_CMAKE_PATH variable that is already defined (for example passed
#      with -DHIP_CMAKE_PATH=... or stored in the cache by a previous run),
#   2. the HIP_CMAKE_PATH environment variable,
#   3. built-in defaults derived from ROCM_HOME and the rocSHMEM location.
if(NOT DEFINED HIP_CMAKE_PATH)
  if(NOT DEFINED ENV{HIP_CMAKE_PATH})
    # NOTE(yiakwy) : find_package(HIP) will first search for
    #   cmake/Modules/FindAMDDeviceLibs.cmake
    # , then
    #   /opt/rocm/lib/cmake/AMDDeviceLibs/AMDDeviceLibsConfig.cmake
    # this will add hip::host, hip::device dependencies to be linked by any
    # hip targets (ROCM >= 6.x).
    # NOTE(yiakwy) : adding ${ROCM_HOME}/lib/cmake/hip to this list conflicts
    # with 3rdparty/mscclpp, so it is only added to CMAKE_PREFIX_PATH below.

    # NOTE(yiakwy) : by default the package (presumably rocSHMEM) is installed
    # in a local directory such as ~/rocshmem.
    set(ROCSHMEM_LIB_DIR "/root/rocshmem")
    set(ROCSHMEM_CMAKE_PATH "${ROCSHMEM_LIB_DIR}/lib/cmake/rocshmem")

    set(HIP_CMAKE_PATH
      "${ROCSHMEM_CMAKE_PATH}"
      "${ROCM_HOME}/lib/cmake/AMDDeviceLibs"
      "${ROCM_HOME}/lib/cmake/amd_comgr"
      "${ROCM_HOME}/lib/cmake/hsa-runtime64"
      "${ROCM_HOME}/lib/cmake/hipcub"
      "${ROCM_HOME}/lib/cmake/rccl"
      "${ROCM_HOME}/lib/cmake/composable_kernel"
      CACHE PATH "Path to which HIP has been installed")
    message(WARNING "System variable HIP_CMAKE_PATH is nonexist, defaults to ${HIP_CMAKE_PATH}")

    # Let config-mode find_package() calls locate the ROCm, HIP and rocSHMEM
    # packages when the defaults are in use.
    set(CMAKE_PREFIX_PATH "${ROCM_HOME};${ROCM_HOME}/lib/cmake/hip;${ROCSHMEM_CMAKE_PATH};${CMAKE_PREFIX_PATH}")
  else()
    set(HIP_CMAKE_PATH "$ENV{HIP_CMAKE_PATH}" CACHE PATH "Path to which HIP has been installed")
  endif()
endif()

# Extend the module search path unconditionally. HIP_CMAKE_PATH becomes a cache
# entry on the first configure, so doing this inside the `NOT DEFINED` guard
# would skip it on every re-configure and would also ignore a value supplied
# with -DHIP_CMAKE_PATH=...
list(PREPEND CMAKE_MODULE_PATH ${HIP_CMAKE_PATH})

# Enable common warnings for everything defined in this directory and below.
# -Wall is a compiler option rather than a preprocessor definition, so it is
# added with add_compile_options().
add_compile_options(-Wall)

# --- ROCm / HIP backend (optional) -------------------------------------------
# When HIP is found: pick the GPU architecture, enable the HIP language, define
# USE_ROCM=1 for all sources, and load the ROCm build tools and rocSHMEM.
find_package(HIP QUIET)
if(HIP_FOUND)
  message(STATUS "Found HIP: " ${HIP_VERSION})

  # Respect an architecture list chosen by the user; otherwise take the first
  # `gfx*` match printed by rocminfo from the selected ROCm installation.
  if(NOT DEFINED CMAKE_HIP_ARCHITECTURES)
    execute_process(
      COMMAND bash -c "\"${ROCM_HOME}/bin/rocminfo\" | grep -o -m1 'gfx.*'"
      OUTPUT_VARIABLE pplx_detected_hip_arch
      OUTPUT_STRIP_TRAILING_WHITESPACE)
    if(NOT "${pplx_detected_hip_arch}" STREQUAL "")
      set(CMAKE_HIP_ARCHITECTURES "${pplx_detected_hip_arch}")
    else()
      # Leave the variable unset instead of assigning an empty string, so that
      # enable_language(HIP) can fall back to its own default.
      message(WARNING "Could not detect a gfx architecture with ${ROCM_HOME}/bin/rocminfo; set CMAKE_HIP_ARCHITECTURES explicitly if the default is not suitable.")
    endif()
  endif()

  message(STATUS "CMAKE_HIP_ARCHITECTURES : ${CMAKE_HIP_ARCHITECTURES}")

  enable_language(HIP)

  # Lets C++/HIP sources select the ROCm code path with `#ifdef USE_ROCM`.
  add_compile_definitions(USE_ROCM=1)

  # NOTE(review): project() has already enabled CXX at this point, so
  # CMAKE_CXX_COMPILER is normally defined and this lookup is skipped.
  if(NOT DEFINED CMAKE_CXX_COMPILER)
    find_program(CMAKE_CXX_COMPILER hipcc PATHS "${ROCM_HOME}" PATH_SUFFIXES bin)
  endif()

  # NOTE (yiakwy) : modern way to include ROCM tools
  find_package(ROCmCMakeBuildTools PATHS "${ROCM_HOME}")
  include(ROCMCreatePackage)
  include(ROCMInstallTargets)
  include(ROCMCheckTargetIds)

  # NOTE (yiakwy) : include rocSHMEM unless a parent project already provides
  # the roc::rocshmem target. The search is rooted at /root.
  if(NOT TARGET roc::rocshmem)
    find_package(rocshmem REQUIRED PATHS /root)
  endif()

else()
  message(WARNING "Could not find HIP. Ensure that ROCM SDK is either installed in /opt/rocm or the variable HIP_CMAKE_PATH is set to point to the right location.")
endif()


# --- CUDA backend (optional) --------------------------------------------------
# When a CUDA toolkit is found: configure the target architecture, enable the
# CUDA language, and require the CUDAToolkit and NVSHMEM packages.
find_package(CUDA QUIET)
if(CUDA_FOUND)
  message(STATUS "FOUND CUDA: " ${CUDA_TOOLKIT_ROOT_DIR})

  # Compute capability of the first GPU reported by nvidia-smi (for example
  # "9.0"). Nothing in this block consumes it or CUDA_SUPPORTED_ARCHS; both
  # stay visible to the rest of the file and to sub-directories.
  execute_process(
    COMMAND bash -c "/usr/bin/nvidia-smi --query-gpu=compute_cap --format=csv,noheader | grep -o -m1 '[0-9.]*'"
    OUTPUT_VARIABLE CUDA_ARCHITECTURES
    OUTPUT_STRIP_TRAILING_WHITESPACE)
  set(CUDA_SUPPORTED_ARCHS "9.0")

  # These must be set before the CUDA language is enabled and before any target
  # is created: CMAKE_CUDA_SEPARABLE_COMPILATION initialises the
  # CUDA_SEPARABLE_COMPILATION property of targets created afterwards.
  set(CMAKE_CUDA_ARCHITECTURES 90a CACHE STRING "CUDA architecture to target")
  set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)

  # The CUDA language has to be enabled here for .cu sources to be compiled
  # when project() does not list it. The call is harmless if it is already
  # enabled.
  enable_language(CUDA)

  find_package(CUDAToolkit REQUIRED)
  find_package(NVSHMEM REQUIRED)
else()
  message(WARNING "Could not find CUDA.")
endif()

# Abort configuration when neither GPU backend was detected.
# FATAL_ERROR is the message() mode that stops processing; the bare word
# `FATAL` is not a mode keyword, so it would only be printed as part of an
# ordinary notice and configuration would continue.
if(NOT HIP_FOUND AND NOT CUDA_FOUND)
  message(FATAL_ERROR "ROCM/CUDA SDK must be supported")
endif()


# === Configuration options ===
option(WITH_TESTS "Build tests" OFF)
option(WITH_BENCHMARKS "Build benchmarks" OFF)
set(CMAKE_CUDA_ARCHITECTURES 90a CACHE STRING "CUDA architecture to target")

# === CMake configuration ===
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INCLUDE_CURRENT_DIR ON)

set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS_DEBUG "-g -ggdb -O0")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE BOOL "" FORCE )

# === Dependencies ===
include(FetchContent)
find_package(CUDAToolkit REQUIRED) # Modern replacement for find_package(CUDA)
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

include(cmake/py_helper.cmake)
append_torch_cmake_prefix_path()
find_package(Torch REQUIRED)
find_package(NVSHMEM REQUIRED)

find_package(MPI REQUIRED)

if(WITH_TESTS)
enable_testing()
find_package(MPI REQUIRED)
find_package(PkgConfig REQUIRED)
pkg_check_modules(NCCL nccl)

if (HIP_FOUND)
pkg_check_modules(RCCL rccl)
else()
pkg_check_modules(NCCL nccl)
endif()
endif()

# Create imported target for PyTorch
Expand All @@ -44,33 +158,38 @@ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
add_compile_definitions(Py_LIMITED_API=0x03090000)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

# set_cuda_compile_options(<target>)
#
# Adds the CUDA-specific compile options `--threads=32` and `-O3` to <target>.
#   target - name of an existing target.
# The options are PRIVATE (not propagated to consumers) and are wrapped in a
# COMPILE_LANGUAGE generator expression, so they only apply to sources that
# are compiled as CUDA.
function(set_cuda_compile_options target)
  # The generator expression is quoted and its flags are separated by `;` so
  # that it is passed as a single, well-formed argument.
  target_compile_options(${target} PRIVATE
    "$<$<COMPILE_LANGUAGE:CUDA>:--threads=32;-O3>")
endfunction()

# === Library targets ===
if (HIP_FOUND)
include(cmake/rocm_helper.cmake)
endif()
include(cmake/py_helper.cmake)
add_subdirectory(all_to_all)
add_subdirectory(core)

# Main shared library
add_library(pplx_kernels SHARED
bindings/all_to_all_ops.cpp
bindings/bindings.cpp
)
target_link_libraries(pplx_kernels PUBLIC
all_to_all_internode_lib
all_to_all_intranode_lib
core_lib
torch::py_limited
Python::Module
CUDA::cuda_driver
CUDA::cudart
nvshmem::nvshmem_host
nvshmem::nvshmem_device
)
set_target_properties(pplx_kernels PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels
CUDA_SEPARABLE_COMPILATION ON
)
# NOTE (yiakwy) : TODO
# add_library(pplx_kernels SHARED
# bindings/all_to_all_ops.cpp
# bindings/bindings.cpp
# )

# target_link_libraries(pplx_kernels PUBLIC
# all_to_all_internode_lib
# all_to_all_intranode_lib
# core_lib
# torch::py_limited
# Python::Module

# # CUDA::cuda_driver
# # CUDA::cudart

# roc::rocshmem

# # nvshmem::nvshmem_host
# # nvshmem::nvshmem_device
# )

# set_target_properties(pplx_kernels PROPERTIES
# LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels
# CUDA_SEPARABLE_COMPILATION ON
# )
124 changes: 84 additions & 40 deletions csrc/all_to_all/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,61 +1,105 @@
# All-to-All library

add_library(all_to_all_common STATIC
set(all_to_all_srcs
all_to_all.cpp
)

target_link_libraries(all_to_all_common PUBLIC
CUDA::cudart
)

add_library(all_to_all_intranode_lib STATIC
set(all_to_all_intranode_lib_srcs
intranode_combine.cu
intranode_dispatch.cu
intranode.cpp
)
target_link_libraries(all_to_all_intranode_lib PUBLIC
all_to_all_common
CUDA::cudart
)
target_link_libraries(all_to_all_intranode_lib INTERFACE
nvshmem::nvshmem_host
)
target_include_directories(all_to_all_intranode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
set_cuda_compile_options(all_to_all_intranode_lib)

add_library(all_to_all_internode_lib STATIC
set(all_to_all_internode_lib_srcs
internode_combine.cu
internode_dispatch.cu
internode.cpp
)
target_link_libraries(all_to_all_internode_lib PUBLIC
all_to_all_common
CUDA::cudart
)
target_link_libraries(all_to_all_internode_lib INTERFACE
nvshmem::nvshmem_host
)
target_include_directories(all_to_all_internode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
set_cuda_compile_options(all_to_all_internode_lib)

if(WITH_TESTS)
# All-to-All test
add_executable(test_all_to_all
test_all_to_all.cpp
if (HIP_FOUND)
rocshmem_add_library(all_to_all_common
"${all_to_all_srcs}"
)
target_link_libraries(test_all_to_all PUBLIC
all_to_all_intranode_lib
all_to_all_internode_lib
core_lib

# add all_to_all_intranode_lib
rocshmem_add_library(all_to_all_intranode_lib
"${all_to_all_intranode_lib_srcs}"
DEPS
all_to_all_common
)

rocshmem_add_library(all_to_all_internode_lib
"${all_to_all_internode_lib_srcs}"
DEPS
all_to_all_common
)
else()
add_library(all_to_all_common STATIC
${all_to_all_srcs}
)

target_link_libraries(all_to_all_common PUBLIC
CUDA::cudart
CUDA::cuda_driver
MPI::MPI_CXX
)

# add all_to_all_intranode_lib
add_library(all_to_all_intranode_lib STATIC
"${all_to_all_intranode_lib_srcs}"
)
target_link_libraries(all_to_all_intranode_lib PUBLIC
all_to_all_common
CUDA::cudart
)
target_link_libraries(all_to_all_intranode_lib INTERFACE
nvshmem::nvshmem_host
)
set_cuda_compile_options(test_all_to_all)
add_test(NAME AllToAllTest
COMMAND ${MPIEXEC_EXECUTABLE} -np 4 $<TARGET_FILE:test_all_to_all>)
set_tests_properties(AllToAllTest PROPERTIES ENVIRONMENT "NVSHMEM_REMOTE_TRANSPORT=None")
target_include_directories(all_to_all_intranode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
set_cuda_compile_options(all_to_all_intranode_lib)

# add all_to_all_internode_lib
add_library(all_to_all_internode_lib STATIC
"${all_to_all_internode_lib_srcs}"
)
target_link_libraries(all_to_all_internode_lib PUBLIC
all_to_all_common
CUDA::cudart
)
target_link_libraries(all_to_all_internode_lib INTERFACE
nvshmem::nvshmem_host
)
target_include_directories(all_to_all_internode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
set_cuda_compile_options(all_to_all_internode_lib)
endif()

# All-to-All test executable.
# Built when WITH_TESTS is ON. With the `OR HIP_FOUND` clause it is also built
# on every ROCm configuration, whatever the value of WITH_TESTS.
if(WITH_TESTS OR HIP_FOUND)

  if(HIP_FOUND)
    # ROCm build. rocshmem_add_executable() is a project helper defined outside
    # this file (presumably cmake/rocm_helper.cmake - TODO confirm); it is
    # given the source list and the libraries to link after DEPS.
    # No CTest entry is registered for this variant.
    rocshmem_add_executable(test_all_to_all
      "internode.h;intranode.h;test_all_to_all.cpp"
      DEPS
        all_to_all_intranode_lib
        # all_to_all_internode_lib
        core_lib
        roctx64
    )
  else()
    # CUDA build.
    add_executable(test_all_to_all
      test_all_to_all.cpp
    )
    # PRIVATE: nothing links against an executable, so its dependencies do not
    # need to be propagated.
    target_link_libraries(test_all_to_all PRIVATE
      all_to_all_intranode_lib
      all_to_all_internode_lib
      core_lib
      CUDA::cudart
      CUDA::cuda_driver
      MPI::MPI_CXX
      nvshmem::nvshmem_host
    )
    set_cuda_compile_options(test_all_to_all)

    # Launch the test on 4 MPI ranks. The launcher flags come from FindMPI
    # instead of a hard-coded `-np`, so launchers with a different
    # process-count flag keep working.
    add_test(NAME AllToAllTest
      COMMAND ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} 4 ${MPIEXEC_PREFLAGS}
              $<TARGET_FILE:test_all_to_all> ${MPIEXEC_POSTFLAGS})
    set_tests_properties(AllToAllTest PROPERTIES ENVIRONMENT "NVSHMEM_REMOTE_TRANSPORT=None")
  endif()
endif()

if (WITH_BENCHMARKS)
Expand Down
4 changes: 4 additions & 0 deletions csrc/all_to_all/internode.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#ifdef USE_ROCM
#include "core/hip_dist_defs.h"

#else
#include <nvshmem.h>
#endif

#include <cassert>
#include <cstdint>
Expand Down
5 changes: 5 additions & 0 deletions csrc/all_to_all/internode.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

#include <cstddef>
#include <cstdint>

#ifdef USE_ROCM
#include "core/hip_cuda_dtype_defs.h"
#else
#include <cuda_bf16.h>
#endif // USE_ROCM

#include "all_to_all/all_to_all.h"
#include "core/buffer.h"
Expand Down
Loading
Loading