Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions source/op/pt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ endif()

find_package(MPI)
if(MPI_FOUND)
include(CheckCXXSymbolExists)
set(CMAKE_REQUIRED_INCLUDES ${MPI_CXX_INCLUDE_DIRS})
set(CMAKE_REQUIRED_LIBRARIES ${MPI_CXX_LIBRARIES})
check_cxx_symbol_exists(MPIX_Query_cuda_support "mpi.h" CUDA_AWARE)
Comment thread
CaRoLZhangxy marked this conversation as resolved.
Comment thread
CaRoLZhangxy marked this conversation as resolved.
if(NOT CUDA_AWARE)
check_cxx_symbol_exists(MPIX_Query_cuda_support "mpi.h;mpi-ext.h" OMP_CUDA)
if(NOT OMP_CUDA)
target_compile_definitions(deepmd_op_pt PRIVATE NO_CUDA_AWARE)
endif()
endif()
target_link_libraries(deepmd_op_pt PRIVATE MPI::MPI_CXX)
target_compile_definitions(deepmd_op_pt PRIVATE USE_MPI)
endif()
Expand Down
8 changes: 8 additions & 0 deletions source/op/pt/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ class Border : public torch::autograd::Function<Border> {
int version, subversion;
MPI_Get_version(&version, &subversion);
if (version >= 4) {
#ifdef NO_CUDA_AWARE
cuda_aware = 0;
#else
cuda_aware = MPIX_Query_cuda_support();
#endif
} else {
cuda_aware = 0;
}
Expand Down Expand Up @@ -215,7 +219,11 @@ class Border : public torch::autograd::Function<Border> {
int version, subversion;
MPI_Get_version(&version, &subversion);
if (version >= 4) {
#ifdef NO_CUDA_AWARE
cuda_aware = 0;
#else
cuda_aware = MPIX_Query_cuda_support();
#endif
} else {
cuda_aware = 0;
}
Expand Down