diff --git a/CMakeLists.txt b/CMakeLists.txt index d54bff9..68b8b74 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -274,18 +274,6 @@ if (CUDECOMP_BUILD_FORTRAN) target_link_libraries(cudecomp_fort PUBLIC MPI::MPI_Fortran) - # Test for MPI_Comm_f2c/c2f - try_compile( - TEST_F2C_RESULT - ${CMAKE_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}/cmake/test_mpi_f2c.f90 - LINK_LIBRARIES MPI::MPI_Fortran - ) - if (NOT TEST_F2C_RESULT) - message(STATUS "Could not link MPI_Comm_f2c in Fortran module. Setting -DMPICH flag during module compilation.") - target_compile_definitions(cudecomp_fort PRIVATE MPICH) - endif() - install( TARGETS cudecomp_fort LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib diff --git a/cmake/test_mpi_f2c.f90 b/cmake/test_mpi_f2c.f90 deleted file mode 100644 index c1d55b0..0000000 --- a/cmake/test_mpi_f2c.f90 +++ /dev/null @@ -1,48 +0,0 @@ -! SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -! SPDX-License-Identifier: Apache-2.0 -! -! Licensed under the Apache License, Version 2.0 (the "License"); -! you may not use this file except in compliance with the License. -! You may obtain a copy of the License at -! -! http://www.apache.org/licenses/LICENSE-2.0 -! -! Unless required by applicable law or agreed to in writing, software -! distributed under the License is distributed on an "AS IS" BASIS, -! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -! See the License for the specific language governing permissions and -! limitations under the License. - -module test_f2c - use iso_c_binding - implicit none - - type, bind(c) :: MPI_C_Comm - integer(c_int64_t) :: comm - end type MPI_C_Comm - - type, bind(c) :: MPI_F_Comm - integer(c_int) :: comm - end type MPI_F_Comm - - interface - function MPI_Comm_f2c(fcomm) bind(C,name='MPI_Comm_f2c') result(res) - import - type(MPI_F_Comm), value :: fcomm - type(MPI_C_Comm) :: res - end function MPI_Comm_f2c - end interface -end module - -program main - use mpi - use test_f2c - implicit none - - type(MPI_F_Comm) :: fcomm - type(MPI_C_Comm) :: ccomm - - fcomm%comm = MPI_COMM_WORLD - - ccomm = MPI_Comm_f2c(fcomm) -end program diff --git a/docs/api/c_api.rst b/docs/api/c_api.rst index f5b7220..4fdb343 100644 --- a/docs/api/c_api.rst +++ b/docs/api/c_api.rst @@ -118,6 +118,14 @@ ____________ ------ +.. _cudecompInit_F-ref: + +cudecompInit_F +______________ +.. doxygenfunction:: cudecompInit_F + +------ + .. _cudecompFinalize-ref: cudecompFinalize diff --git a/include/cudecomp.h b/include/cudecomp.h index cc6071f..189f634 100644 --- a/include/cudecomp.h +++ b/include/cudecomp.h @@ -212,6 +212,16 @@ typedef struct { */ cudecompResult_t cudecompInit(cudecompHandle_t* handle, MPI_Comm mpi_comm); +/** + * @brief Initializes the cuDecomp library from an existing MPI communicator + * + * @param[out] handle A pointer to an uninitialized cudecompHandle_t + * @param[in] mpi_comm_f MPI communicator, in Fortran integer format, containing ranks to use with cuDecomp + * + * @return CUDECOMP_RESULT_SUCCESS on success or error code on failure. + */ +cudecompResult_t cudecompInit_F(cudecompHandle_t* handle, MPI_Fint mpi_comm_f); + /** * @brief Finalizes the cuDecomp library and frees associated resources * diff --git a/src/cudecomp.cc b/src/cudecomp.cc index 5129d27..dd82575 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -550,6 +550,18 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) { return CUDECOMP_RESULT_SUCCESS; } +cudecompResult_t cudecompInit_F(cudecompHandle_t* handle_in, MPI_Fint mpi_comm_f) { + using namespace cudecomp; + try { + MPI_Comm mpi_comm = MPI_Comm_f2c(mpi_comm_f); + cudecompInit(handle_in, mpi_comm); + } catch (const cudecomp::BaseException& e) { + std::cerr << e.what(); + return e.getResult(); + } + return CUDECOMP_RESULT_SUCCESS; +} + cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDesc_t* grid_desc_in, cudecompGridDescConfig_t* config, const cudecompGridDescAutotuneOptions_t* options) { diff --git a/src/cudecomp_m.cuf b/src/cudecomp_m.cuf index 3046100..8d207b6 100644 --- a/src/cudecomp_m.cuf +++ b/src/cudecomp_m.cuf @@ -71,21 +71,6 @@ module cudecomp ! types - ! MPI-related types -#ifndef MPICH - type, bind(c) :: MPI_C_Comm - integer(c_int64_t) :: comm - end type MPI_C_Comm -#else - type, bind(c) :: MPI_C_Comm - integer(c_int) :: comm - end type MPI_C_Comm -#endif - - type, bind(c) :: MPI_F_Comm - integer(c_int) :: comm - end type MPI_F_Comm - ! Opaque handle to cuDecomp handle type, bind(c) :: cudecompHandle type(c_ptr) :: member @@ -155,37 +140,20 @@ module cudecomp ! interfaces - ! MPI_Comm conversion functions -#ifndef MPICH - interface - function MPI_Comm_f2c(fcomm) bind(C,name='MPI_Comm_f2c') result(res) - import - type(MPI_F_comm), value :: fcomm - type(MPI_C_comm) :: res - end function MPI_Comm_f2c - - function MPI_Comm_c2f(ccomm) bind(C,name='MPI_Comm_c2f') result(res) - import - type(MPI_C_Comm), value :: ccomm - type(MPI_F_Comm) :: res - end function MPI_Comm_c2f - end interface -#endif - ! cuDecomp initialization/finalization functions - ! generic interface that takes either integer or type(MPI_F_Comm) communicator arguments + ! generic interface that takes either integer or type(MPI_Comm) communicator arguments interface cudecompInit - module procedure cudecompInit_MPI_F, cudecompInit_MPI_F08, cudecompInitType + module procedure cudecompInit_MPI_F, cudecompInit_MPI_F08 end interface cudecompInit interface - function cudecompInitC(handle, mpi_comm) bind(C, name="cudecompInit") result(res) + function cudecompInit_FC(handle, mpi_comm) bind(C, name="cudecompInit_F") result(res) import type(cudecompHandle) :: handle - type(MPI_C_Comm), value :: mpi_comm ! conversion to MPI_C_Comm done in module procedures cudecompInit*() + integer, value :: mpi_comm integer(c_int) :: res - end function cudecompInitC + end function cudecompInit_FC end interface interface @@ -523,10 +491,7 @@ contains integer :: comm integer(c_int) :: res - type(MPI_F_Comm) :: fComm - - fComm%comm = comm - res = cudecompInitType(handle, fComm) + res = cudecompInit_FC(handle, comm) end function cudecompInit_MPI_F function cudecompInit_MPI_F08(handle, comm) result(res) @@ -538,28 +503,9 @@ contains type(MPI_Comm) :: comm integer(c_int) :: res - type(MPI_F_Comm) :: fComm - - fComm%comm = comm%MPI_VAL - res = cudecompInitType(handle, fComm) + res = cudecompInit_FC(handle, comm%MPI_VAL) end function cudecompInit_MPI_F08 - function cudecompInitType(handle, fComm) result(res) - implicit none - type(cudecompHandle) :: handle - type(MPI_F_Comm) :: fComm - integer(c_int) :: res - - type(MPI_C_Comm) :: cComm -#ifndef MPICH - cComm = MPI_Comm_f2c(fComm) -#else - cComm= fComm -#endif - - res = cudecompInitC(handle, cComm) - end function cudecompInitType - ! cudecompGridDesc creation/manipulation functions function cudecompGridDescAutotuneOptionsSetDefaults(options) result(res) type(cudecompGridDescAutotuneOptions) :: options