Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ build_gnn_cpu:

script:
- export PATH=/usr/local/sbin:/usr/sbin:/sbin:$PATH
- export PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
- echo $PATH
- git clone $CLONE_URL src
- cd src
Expand Down
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,9 @@ if(ACTS_BUILD_PLUGIN_GNN)
endif()
if(ACTS_GNN_ENABLE_TORCH)
find_package(Torch REQUIRED)
add_subdirectory(thirdparty/FRNN)
if(ACTS_GNN_ENABLE_CUDA)
add_subdirectory(thirdparty/FRNN)
endif()
endif()
if(ACTS_GNN_ENABLE_MODULEMAP)
if(ACTS_USE_SYSTEM_MODULEMAPGRAPH)
Expand Down
5 changes: 4 additions & 1 deletion Plugins/Gnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ endif()

if(ACTS_GNN_ENABLE_TORCH)
target_compile_definitions(ActsPluginGnn PUBLIC ACTS_GNN_TORCH_BACKEND)
target_link_libraries(ActsPluginGnn PRIVATE ${TORCH_LIBRARIES} frnn)
target_link_libraries(ActsPluginGnn PRIVATE ${TORCH_LIBRARIES})
if(ACTS_GNN_ENABLE_CUDA)
target_link_libraries(ActsPluginGnn PRIVATE frnn)
endif()
find_package(TorchScatter QUIET)
if(NOT TARGET TorchScatter::TorchScatter)
message(
Expand Down
3 changes: 1 addition & 2 deletions Plugins/Gnn/src/TorchEdgeClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
if (!torch::cuda::is_available()) {
ACTS_DEBUG("Running on CPU...");
} else {
if (cfg.deviceID >= 0 &&
static_cast<std::size_t>(cfg.deviceID) < torch::cuda::device_count()) {
if (cfg.deviceID >= 0 && cfg.deviceID < torch::cuda::device_count()) {
ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
device = torch::Device(torch::kCUDA, cfg.deviceID);
} else {
Expand Down
3 changes: 1 addition & 2 deletions Plugins/Gnn/src/TorchMetricLearning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ TorchMetricLearning::TorchMetricLearning(const Config &cfg,
if (!torch::cuda::is_available()) {
ACTS_DEBUG("Running on CPU...");
} else {
if (cfg.deviceID >= 0 &&
static_cast<std::size_t>(cfg.deviceID) < torch::cuda::device_count()) {
if (cfg.deviceID >= 0 && cfg.deviceID < torch::cuda::device_count()) {
ACTS_DEBUG("GPU device " << cfg.deviceID << " is being used.");
device = torch::Device(torch::kCUDA, cfg.deviceID);
} else {
Expand Down
4 changes: 3 additions & 1 deletion Plugins/Gnn/src/printCudaMemInfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#pragma once

#include "Acts/Utilities/Logger.hpp"
#include "ActsPlugins/Gnn/detail/CudaUtils.hpp"
#ifndef ACTS_GNN_CPUONLY
#include <ActsPlugins/Gnn/detail/CudaUtils.hpp>
#endif

#ifndef ACTS_GNN_CPUONLY
#include <cuda_runtime_api.h>
Expand Down
5 changes: 5 additions & 0 deletions thirdparty/FRNN/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

cmake_minimum_required(VERSION 3.25)

# Set policy to allow FetchContent_Populate for projects without CMakeLists.txt
if(POLICY CMP0169)
cmake_policy(SET CMP0169 OLD)
endif()

include(FetchContent)

message(STATUS "Building FRNN as part of the ACTS project")
Expand Down
3 changes: 1 addition & 2 deletions thirdparty/FRNN/CMakeLists.txt.in
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ add_library(frnn STATIC

target_include_directories(frnn
SYSTEM PUBLIC
${TORCH_INCLUDE_DIRS}
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/frnn/csrc>
)

target_link_libraries(frnn PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(frnn PUBLIC torch)

set_property(TARGET frnn PROPERTY CXX_STANDARD 17)
set_property(TARGET frnn PROPERTY POSITION_INDEPENDENT_CODE ON)
Expand Down
Loading