diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e0d2e870a5e..70b901ef680 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -112,7 +112,6 @@ build_gnn_cpu: script: - export PATH=/usr/local/sbin:/usr/sbin:/sbin:$PATH - - export PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH - git clone $CLONE_URL src - cd src - git checkout $HEAD_SHA diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b3ec5cd082..e60be8b9276 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -428,7 +428,9 @@ if(ACTS_BUILD_PLUGIN_GNN) endif() if(ACTS_GNN_ENABLE_TORCH) find_package(Torch REQUIRED) - add_subdirectory(thirdparty/FRNN) + if(ACTS_GNN_ENABLE_CUDA) + add_subdirectory(thirdparty/FRNN) + endif() endif() if(ACTS_GNN_ENABLE_MODULEMAP) if(ACTS_USE_SYSTEM_MODULEMAPGRAPH) diff --git a/Plugins/Gnn/CMakeLists.txt b/Plugins/Gnn/CMakeLists.txt index b481f89116b..809f7988817 100644 --- a/Plugins/Gnn/CMakeLists.txt +++ b/Plugins/Gnn/CMakeLists.txt @@ -88,7 +88,10 @@ endif() if(ACTS_GNN_ENABLE_TORCH) target_compile_definitions(ActsPluginGnn PUBLIC ACTS_GNN_TORCH_BACKEND) - target_link_libraries(ActsPluginGnn PRIVATE ${TORCH_LIBRARIES} frnn) + target_link_libraries(ActsPluginGnn PRIVATE ${TORCH_LIBRARIES}) + if(ACTS_GNN_ENABLE_CUDA) + target_link_libraries(ActsPluginGnn PRIVATE frnn) + endif() find_package(TorchScatter QUIET) if(NOT TARGET TorchScatter::TorchScatter) message( diff --git a/Plugins/Gnn/src/TorchEdgeClassifier.cpp b/Plugins/Gnn/src/TorchEdgeClassifier.cpp index 2bde400aae8..76c82cf8f3b 100644 --- a/Plugins/Gnn/src/TorchEdgeClassifier.cpp +++ b/Plugins/Gnn/src/TorchEdgeClassifier.cpp @@ -37,8 +37,7 @@ TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg, if (!torch::cuda::is_available()) { ACTS_DEBUG("Running on CPU..."); } else { - if (cfg.deviceID >= 0 && - static_cast(cfg.deviceID) < torch::cuda::device_count()) { + if (cfg.deviceID >= 0 && cfg.deviceID < torch::cuda::device_count()) { ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); device = torch::Device(torch::kCUDA, cfg.deviceID); } else { diff --git a/Plugins/Gnn/src/TorchMetricLearning.cpp b/Plugins/Gnn/src/TorchMetricLearning.cpp index cd2384d0dca..7428b6eab16 100644 --- a/Plugins/Gnn/src/TorchMetricLearning.cpp +++ b/Plugins/Gnn/src/TorchMetricLearning.cpp @@ -37,8 +37,7 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg, if (!torch::cuda::is_available()) { ACTS_DEBUG("Running on CPU..."); } else { - if (cfg.deviceID >= 0 && - static_cast(cfg.deviceID) < torch::cuda::device_count()) { + if (cfg.deviceID >= 0 && cfg.deviceID < torch::cuda::device_count()) { ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used."); device = torch::Device(torch::kCUDA, cfg.deviceID); } else { diff --git a/Plugins/Gnn/src/printCudaMemInfo.hpp b/Plugins/Gnn/src/printCudaMemInfo.hpp index f1d638853d0..4e84771cff1 100644 --- a/Plugins/Gnn/src/printCudaMemInfo.hpp +++ b/Plugins/Gnn/src/printCudaMemInfo.hpp @@ -9,7 +9,9 @@ #pragma once #include "Acts/Utilities/Logger.hpp" -#include "ActsPlugins/Gnn/detail/CudaUtils.hpp" +#ifndef ACTS_GNN_CPUONLY +#include +#endif #ifndef ACTS_GNN_CPUONLY #include diff --git a/thirdparty/FRNN/CMakeLists.txt b/thirdparty/FRNN/CMakeLists.txt index d5a7018e796..985aff8300f 100644 --- a/thirdparty/FRNN/CMakeLists.txt +++ b/thirdparty/FRNN/CMakeLists.txt @@ -8,6 +8,11 @@ cmake_minimum_required(VERSION 3.25) +# Set policy to allow FetchContent_Populate for projects without CMakeLists.txt +if(POLICY CMP0169) + cmake_policy(SET CMP0169 OLD) +endif() + include(FetchContent) message(STATUS "Building FRNN as part of the ACTS project") diff --git a/thirdparty/FRNN/CMakeLists.txt.in b/thirdparty/FRNN/CMakeLists.txt.in index b321e0d0415..33d4f0f4931 100644 --- a/thirdparty/FRNN/CMakeLists.txt.in +++ b/thirdparty/FRNN/CMakeLists.txt.in @@ -29,11 +29,10 @@ add_library(frnn STATIC target_include_directories(frnn SYSTEM PUBLIC - ${TORCH_INCLUDE_DIRS} $ ) -target_link_libraries(frnn PRIVATE ${TORCH_LIBRARIES}) +target_link_libraries(frnn PUBLIC torch) set_property(TARGET frnn PROPERTY CXX_STANDARD 17) set_property(TARGET frnn PROPERTY POSITION_INDEPENDENT_CODE ON)