| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "ashihmin_d_scatter_trans_from_one_to_all/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <type_traits> | ||
| 7 | #include <vector> | ||
| 8 | |||
| 9 | #include "ashihmin_d_scatter_trans_from_one_to_all/common/include/common.hpp" | ||
| 10 | |||
| 11 | namespace ashihmin_d_scatter_trans_from_one_to_all { | ||
| 12 | |||
| 13 | template <typename T> | ||
| 14 |
1/2✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
|
72 | AshihminDScatterTransFromOneToAllMPI<T>::AshihminDScatterTransFromOneToAllMPI(const InType &in) { |
| 15 | this->SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 16 | this->GetInput() = in; | ||
| 17 | this->GetOutput().clear(); | ||
| 18 | 72 | } | |
| 19 | |||
| 20 | template <typename T> | ||
| 21 | 72 | bool AshihminDScatterTransFromOneToAllMPI<T>::ValidationImpl() { | |
| 22 | const auto &params = this->GetInput(); | ||
| 23 |
3/6✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 36 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 36 times.
|
72 | return (params.elements_per_process > 0) && (params.root >= 0) && (this->GetOutput().empty()); |
| 24 | } | ||
| 25 | |||
| 26 | template <typename T> | ||
| 27 | 72 | bool AshihminDScatterTransFromOneToAllMPI<T>::PreProcessingImpl() { | |
| 28 | 72 | return true; | |
| 29 | } | ||
| 30 | |||
| 31 | namespace { | ||
| 32 | |||
| 33 | template <typename T> | ||
| 34 | MPI_Datatype GetMPIDataType() { | ||
| 35 | if (std::is_same_v<T, int>) { | ||
| 36 | return MPI_INT; | ||
| 37 | } | ||
| 38 | if (std::is_same_v<T, float>) { | ||
| 39 | return MPI_FLOAT; | ||
| 40 | } | ||
| 41 | if (std::is_same_v<T, double>) { | ||
| 42 | return MPI_DOUBLE; | ||
| 43 | } | ||
| 44 | return MPI_DATATYPE_NULL; | ||
| 45 | } | ||
| 46 | |||
| 47 | inline int VirtualToRealRank(int virtual_rank, int root, int size) { | ||
| 48 | 36 | return (virtual_rank + root) % size; | |
| 49 | } | ||
| 50 | |||
| 51 | template <typename T> | ||
| 52 | 36 | void SendBlock(const ScatterParams &params, const std::vector<T> &local_data, int dest_virtual, int dest_real, | |
| 53 | int elements_per_proc, int rank, int root, MPI_Datatype mpi_type) { | ||
| 54 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
36 | if (rank == root) { |
| 55 | 36 | const int offset = dest_virtual * elements_per_proc; | |
| 56 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
36 | if (offset + elements_per_proc <= static_cast<int>(params.data.size())) { |
| 57 | 36 | MPI_Send(params.data.data() + offset, elements_per_proc, mpi_type, dest_real, 0, MPI_COMM_WORLD); | |
| 58 | } | ||
| 59 | } else { | ||
| 60 | ✗ | MPI_Send(local_data.data(), elements_per_proc, mpi_type, dest_real, 0, MPI_COMM_WORLD); | |
| 61 | } | ||
| 62 | 36 | } | |
| 63 | |||
| 64 | template <typename T> | ||
| 65 | void ReceiveBlock(std::vector<T> &local_data, int src_real, int elements_per_proc, MPI_Datatype mpi_type) { | ||
| 66 |
1/6✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
|
18 | MPI_Recv(local_data.data(), elements_per_proc, mpi_type, src_real, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
| 67 | 18 | } | |
| 68 | |||
| 69 | } // namespace | ||
| 70 | |||
| 71 | template <typename T> | ||
| 72 | 72 | bool AshihminDScatterTransFromOneToAllMPI<T>::RunImpl() { | |
| 73 | const auto &params = this->GetInput(); | ||
| 74 | |||
| 75 | 72 | int rank = 0; | |
| 76 | 72 | int size = 0; | |
| 77 | 72 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 78 | 72 | MPI_Comm_size(MPI_COMM_WORLD, &size); | |
| 79 | |||
| 80 | MPI_Datatype mpi_type = GetMPIDataType<T>(); | ||
| 81 | 72 | const int elements_per_proc = params.elements_per_process; | |
| 82 | 72 | const int root = params.root % size; | |
| 83 | |||
| 84 | 72 | const int virtual_rank = (rank - root + size) % size; | |
| 85 | |||
| 86 | 72 | std::vector<T> local_data(elements_per_proc); | |
| 87 | |||
| 88 | int mask = 1; | ||
| 89 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 18 times.
|
108 | while (mask < size) { |
| 90 |
2/2✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
|
72 | if ((virtual_rank & mask) == 0) { |
| 91 | 36 | const int dest_virtual = virtual_rank | mask; | |
| 92 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
36 | if (dest_virtual < size) { |
| 93 | const int dest_real = VirtualToRealRank(dest_virtual, root, size); | ||
| 94 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
36 | SendBlock(params, local_data, dest_virtual, dest_real, elements_per_proc, rank, root, mpi_type); |
| 95 | } | ||
| 96 | } else { | ||
| 97 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
36 | const int src_virtual = virtual_rank & ~mask; |
| 98 | const int src_real = VirtualToRealRank(src_virtual, root, size); | ||
| 99 | ReceiveBlock(local_data, src_real, elements_per_proc, mpi_type); | ||
| 100 | break; | ||
| 101 | } | ||
| 102 | 36 | mask <<= 1; | |
| 103 | } | ||
| 104 | |||
| 105 |
2/2✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
|
72 | if (rank == root) { |
| 106 | 36 | const int offset = virtual_rank * elements_per_proc; | |
| 107 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
36 | if (offset + elements_per_proc <= static_cast<int>(params.data.size())) { |
| 108 | 36 | std::copy(params.data.begin() + offset, params.data.begin() + offset + elements_per_proc, local_data.begin()); | |
| 109 | } | ||
| 110 | } | ||
| 111 | |||
| 112 |
1/2✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
|
72 | this->GetOutput() = local_data; |
| 113 | 72 | return true; | |
| 114 | } | ||
| 115 | |||
| 116 | template <typename T> | ||
| 117 | 36 | bool AshihminDScatterTransFromOneToAllMPI<T>::PostProcessingImpl() { | |
| 118 | 36 | return !this->GetOutput().empty(); | |
| 119 | } | ||
| 120 | |||
| 121 | template class AshihminDScatterTransFromOneToAllMPI<int>; | ||
| 122 | template class AshihminDScatterTransFromOneToAllMPI<float>; | ||
| 123 | template class AshihminDScatterTransFromOneToAllMPI<double>; | ||
| 124 | |||
| 125 | } // namespace ashihmin_d_scatter_trans_from_one_to_all | ||
| 126 |