diff --git a/source/op/pt/CMakeLists.txt b/source/op/pt/CMakeLists.txt index 81faa8c436..3bb34e622d 100644 --- a/source/op/pt/CMakeLists.txt +++ b/source/op/pt/CMakeLists.txt @@ -20,6 +20,16 @@ endif() find_package(MPI) if(MPI_FOUND) + include(CheckCXXSymbolExists) + set(CMAKE_REQUIRED_INCLUDES ${MPI_CXX_INCLUDE_DIRS}) + set(CMAKE_REQUIRED_LIBRARIES ${MPI_CXX_LIBRARIES}) + check_cxx_symbol_exists(MPIX_Query_cuda_support "mpi.h" CUDA_AWARE) + if(NOT CUDA_AWARE) + check_cxx_symbol_exists(MPIX_Query_cuda_support "mpi.h;mpi-ext.h" OMP_CUDA) + if(NOT OMP_CUDA) + target_compile_definitions(deepmd_op_pt PRIVATE NO_CUDA_AWARE) + endif() + endif() target_link_libraries(deepmd_op_pt PRIVATE MPI::MPI_CXX) target_compile_definitions(deepmd_op_pt PRIVATE USE_MPI) endif() diff --git a/source/op/pt/comm.cc b/source/op/pt/comm.cc index 8dce9e7081..a25dfbd542 100644 --- a/source/op/pt/comm.cc +++ b/source/op/pt/comm.cc @@ -100,7 +100,11 @@ class Border : public torch::autograd::Function { int version, subversion; MPI_Get_version(&version, &subversion); if (version >= 4) { +#ifdef NO_CUDA_AWARE + cuda_aware = 0; +#else cuda_aware = MPIX_Query_cuda_support(); +#endif } else { cuda_aware = 0; } @@ -215,7 +219,11 @@ class Border : public torch::autograd::Function { int version, subversion; MPI_Get_version(&version, &subversion); if (version >= 4) { +#ifdef NO_CUDA_AWARE + cuda_aware = 0; +#else cuda_aware = MPIX_Query_cuda_support(); +#endif } else { cuda_aware = 0; }