diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b43a0e47..b1d2ce93 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -13,6 +13,7 @@ set(PCMS_HEADERS
   pcms/field_communicator.h
   pcms/field_communicator2.h
   pcms/field_evaluation_methods.h
+  pcms/global_communicator.h
   pcms/memory_spaces.h
   pcms/types.h
   pcms/array_mask.h
@@ -51,11 +52,13 @@ if(PCMS_ENABLE_OMEGA_H)
   list(
     APPEND
     PCMS_HEADERS
-    pcms/adapter/omega_h/omega_h_field.h
+    #pcms/adapter/omega_h/omega_h_field.h
     pcms/transfer_field.h
     pcms/transfer_field2.h
     pcms/uniform_grid.h
     pcms/point_search.h)
+  install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/pcms/adapter/omega_h/omega_h_field.h
+          DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/pcms/adapter/omega_h)
 endif()
 find_package(Kokkos REQUIRED)
diff --git a/src/pcms/coupler.h b/src/pcms/coupler.h
index 020b0cab..f1d89c8d 100644
--- a/src/pcms/coupler.h
+++ b/src/pcms/coupler.h
@@ -1,5 +1,6 @@
 #ifndef PCMS_COUPLER_H
 #define PCMS_COUPLER_H
+#include "pcms/global_communicator.h"
 #include "pcms/common.h"
 #include "pcms/field_communicator.h"
 #include "pcms/adapter/omega_h/omega_h_field.h"
@@ -112,7 +113,32 @@ class CoupledField
 private:
   std::unique_ptr coupled_field_;
 };
+template <typename T>
+class GlobalDataInterface // typed handle for global (non-field) data exchange
+{
+ public:
+  GlobalDataInterface(const std::string& name, MPI_Comm mpi_comm, redev::Channel& channel)
+    : mpi_comm_(mpi_comm), type_info_(typeid(T)),
+      comm_(name, mpi_comm_, channel) // init order now matches declaration order
+  {
+    PCMS_FUNCTION_TIMER;
+  }
+  void Send(T* msg, std::string VarName, size_t msg_size, Mode mode = Mode::Synchronous)
+  {
+    PCMS_FUNCTION_TIMER;
+    comm_.Send(msg, VarName, msg_size, mode);
+  }
+  std::vector<T> Receive(std::string VarName, size_t msg_size, Mode mode = Mode::Synchronous)
+  {
+    PCMS_FUNCTION_TIMER;
+    return comm_.Receive(VarName, msg_size, mode);
+  }
+private:
+  MPI_Comm mpi_comm_;
+  const std::type_info& type_info_;
+  GlobalCommunicator<T> comm_;
+};
 class Application
 {
  public:
@@ -142,6 +168,12 @@ class Application
     }
     return &(it->second);
   }
+  template <typename T>
+ 
 std::unique_ptr<GlobalDataInterface<T>> Add_GDI(std::string name, MPI_Comm mpi_comm)
+  {
+    PCMS_FUNCTION_TIMER;
+    return std::make_unique<GlobalDataInterface<T>>(name, mpi_comm, channel_); // Use the existing application channel
+  }
   void SendField(const std::string& name, Mode mode = Mode::Synchronous)
   {
     PCMS_FUNCTION_TIMER;
diff --git a/src/pcms/global_communicator.h b/src/pcms/global_communicator.h
new file mode 100644
index 00000000..2021de01
--- /dev/null
+++ b/src/pcms/global_communicator.h
@@ -0,0 +1,49 @@
+#ifndef PCMS_GLOBAL_COMMUNICATOR_H
+#define PCMS_GLOBAL_COMMUNICATOR_H
+
+#include <redev.h> // NOTE(review): include target reconstructed; original was stripped in this patch
+
+namespace pcms
+{
+using redev::Mode;
+template <typename T>
+struct GlobalCommunicator // exchanges flat arrays of T over a redev channel
+{
+  using value_type = T;
+ public:
+  GlobalCommunicator(std::string name, MPI_Comm mpi_comm, redev::Channel& channel)
+    : mpi_comm(mpi_comm),
+      channel_(channel),
+      name_(std::move(name))
+  {
+    PCMS_FUNCTION_TIMER;
+    comm_ = channel_.CreateComm<T>(name_, mpi_comm, redev::CommType::Global);
+  }
+  GlobalCommunicator(const GlobalCommunicator&) = delete;
+  GlobalCommunicator& operator=(const GlobalCommunicator&) = delete;
+  GlobalCommunicator(GlobalCommunicator&&) = default;
+  GlobalCommunicator& operator=(GlobalCommunicator&&) = default; // NOTE: effectively deleted (reference member)
+
+  void Send(T* msg, std::string VarName, size_t msg_size, Mode mode = Mode::Synchronous)
+  {
+    PCMS_FUNCTION_TIMER;
+    PCMS_ALWAYS_ASSERT(channel_.InSendCommunicationPhase());
+    comm_.SetCommParams(VarName, msg_size);
+    comm_.Send(msg, mode);
+  }
+  std::vector<T> Receive(std::string VarName, size_t msg_size, Mode mode = Mode::Synchronous)
+  {
+    PCMS_FUNCTION_TIMER;
+    PCMS_ALWAYS_ASSERT(channel_.InReceiveCommunicationPhase());
+    comm_.SetCommParams(VarName, msg_size);
+    auto data = comm_.Recv(mode);
+    return data;
+  }
+ private:
+  MPI_Comm mpi_comm; // NOTE(review): stored but unused after ctor; name lacks trailing '_'
+  redev::Channel& channel_;
+  std::string name_;
+  redev::BidirectionalComm<T> comm_;
+};
+} // namespace pcms
+#endif // PCMS_GLOBAL_COMMUNICATOR_H
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 574abaf4..cd929998 100644
--- 
a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -96,7 +96,37 @@ if(PCMS_ENABLE_OMEGA_H)
     ${d3d16p}
     ignored)
 endif()
-
+add_exe(test_GDI)
+tri_mpi_test(
+  TESTNAME
+  test_GDI
+  TIMEOUT
+  20
+  NAME1
+  app
+  EXE1
+  ./test_GDI
+  PROCS1
+  1
+  ARGS1
+  1
+  NAME2
+  rdv
+  EXE2
+  ./test_GDI
+  PROCS2
+  1
+  ARGS2
+  -1
+  NAME3
+  app
+  EXE3
+  ./test_GDI
+  PROCS3
+  1
+  ARGS3
+  0
+  )
 set(d3d8p ${PCMS_TEST_DATA_DIR}/d3d/d3d-full_9k_sfc_p8.osh/)
 add_exe(test_twoClientOverlap)
 if(HOST_NPROC GREATER_EQUAL 28)
diff --git a/test/test_GDI.cpp b/test/test_GDI.cpp
new file mode 100644
index 00000000..39744fd2
--- /dev/null
+++ b/test/test_GDI.cpp
@@ -0,0 +1,130 @@
+#include <cassert>  // NOTE(review): the five include targets here were stripped
+#include <cmath>    // in this patch; reconstructed from usage -- verify.
+#include <iostream>
+#include <numeric>
+#include <pcms/coupler.h>
+#include "test_support.h"
+#include "pcms/adapter/omega_h/omega_h_field.h"
+
+static constexpr bool done = true;
+static constexpr int COMM_ROUNDS = 1;
+
+void xgc_delta_f(MPI_Comm comm)
+{
+  pcms::Coupler coupler("proxy_couple", comm, false, {});
+  pcms::Application* app = coupler.AddApplication("proxy_couple_xgc_delta_f");
+
+  const auto GDI = app->Add_GDI<int>("global_comm", comm);
+  auto mean = std::vector<int>(1);
+  mean[0] = 16;
+  do {
+    for (int i = 0; i < COMM_ROUNDS; ++i) {
+      app->BeginSendPhase();
+      GDI->Send(mean.data(), "mean", mean.size());
+      app->EndSendPhase();
+      printf("delta Sent mean:%d\n", mean[0]);
+      app->BeginReceivePhase();
+      mean = GDI->Receive("mean", mean.size());
+      app->EndReceivePhase();
+      mean[0] = mean[0]/2;
+    }
+  } while (!done);
+  printf("final Mean = %d\n", mean[0]);
+  assert(std::fabs(mean[0] - 1.0) < 1e-12);
+  printf("GDI test successful.\n");
+}
+void xgc_total_f(MPI_Comm comm)
+{
+  pcms::Coupler coupler("proxy_couple", comm, false, {});
+  pcms::Application* app = coupler.AddApplication("proxy_couple_xgc_total_f");
+
+  auto GDI = app->Add_GDI<int>("global_comm", comm);
+  auto mean = std::vector<int>(1);
+  do {
+    for (int i = 0; i < COMM_ROUNDS; ++i) {
+      app->BeginReceivePhase();
+      mean = GDI->Receive("mean", mean.size());
+      app->EndReceivePhase();
+ 
     printf("total Received mean:%d\n", mean[0]);
+      mean[0] = mean[0]/2;
+      app->BeginSendPhase();
+      GDI->Send(mean.data(), "mean", mean.size());
+      app->EndSendPhase();
+      printf("total Sent mean:%d\n", mean[0]);
+    }
+  } while (!done);
+}
+void xgc_coupler(MPI_Comm comm)
+{
+  // Define Partition
+  redev::LO dim = 3;
+  redev::LOs ranks(1);
+  std::iota(ranks.begin(), ranks.end(), 0);
+  redev::Reals cuts = {0};
+  auto partition = redev::Partition{redev::RCBPtn{dim, ranks, cuts}};
+
+  pcms::Coupler cpl("proxy_couple", comm, true,
+                    partition);
+  auto* total_f = cpl.AddApplication("proxy_couple_xgc_total_f");
+  auto* delta_f = cpl.AddApplication("proxy_couple_xgc_delta_f");
+
+  auto GDI_total = total_f->Add_GDI<int>("global_comm", comm);
+  auto GDI_delta = delta_f->Add_GDI<int>("global_comm", comm);
+  auto mean = std::vector<int>(1);
+  do {
+    for (int i = 0; i < COMM_ROUNDS; ++i) {
+      delta_f->BeginReceivePhase();
+      mean = GDI_delta->Receive("mean", 1);
+      delta_f->EndReceivePhase();
+      printf("delta Received mean:%d\n", mean[0]);
+      mean[0] = mean[0]/2;
+      const auto msg_size = mean.size();
+      total_f->BeginSendPhase();
+      GDI_total->Send(mean.data(), "mean", msg_size);
+      total_f->EndSendPhase();
+      printf("total sent mean:%d\n", mean[0]);
+      total_f->BeginReceivePhase();
+      mean = GDI_total->Receive("mean", msg_size);
+      total_f->EndReceivePhase();
+      printf("total Received mean:%d\n", mean[0]); // was mislabeled "delta Received"
+      mean[0] = mean[0]/2;
+      delta_f->BeginSendPhase();
+      GDI_delta->Send(mean.data(), "mean", msg_size);
+      delta_f->EndSendPhase();
+      printf("delta sent mean:%d\n", mean[0]);
+    }
+  } while (!done);
+}
+
+int main(int argc, char** argv)
+{
+  MPI_Init(&argc, &argv); // MPI init
+
+  OMEGA_H_CHECK(argc == 2);
+  const auto clientId = atoi(argv[1]);
+  REDEV_ALWAYS_ASSERT(clientId >= -1 && clientId <= 1);
+
+  int color;
+  if (clientId == -1)
+    color = 0; // coupler
+  else if (clientId == 0)
+    color = 1; // client A
+  else if (clientId == 1)
+    color = 2; // client B
+  else
+    color = MPI_UNDEFINED;
+
+  MPI_Comm subcomm;
+ 
 MPI_Comm_split(MPI_COMM_WORLD, color, 0, &subcomm);
+
+  switch (clientId) {
+    case -1: xgc_coupler(subcomm); break;
+    case 0: xgc_delta_f(subcomm); break;
+    case 1: xgc_total_f(subcomm); break;
+    default:
+      std::cerr << "Unhandled client id (should be -1, 0, or 1)\n";
+      exit(EXIT_FAILURE);
+  }
+  MPI_Finalize();
+  return 0;
+}