| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "nikolaev_d_gather/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <cstddef> | ||
| 7 | #include <cstring> | ||
| 8 | #include <utility> | ||
| 9 | #include <vector> | ||
| 10 | |||
| 11 | #include "nikolaev_d_gather/common/include/common.hpp" | ||
| 12 | |||
| 13 | namespace nikolaev_d_gather { | ||
| 14 | |||
| 15 | namespace { | ||
| 16 | int GetTypeSize(MPI_Datatype datatype) { | ||
| 17 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
|
6 | if (datatype == MPI_INT) { |
| 18 | return sizeof(int); | ||
| 19 | } | ||
| 20 |
4/4✓ Branch 0 taken 2 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 2 times.
|
8 | if (datatype == MPI_FLOAT) { |
| 21 | return sizeof(float); | ||
| 22 | } | ||
| 23 |
2/4✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
4 | if (datatype == MPI_DOUBLE) { |
| 24 | 2 | return sizeof(double); | |
| 25 | } | ||
| 26 | return 0; | ||
| 27 | } | ||
| 28 | } // namespace | ||
| 29 | |||
| 30 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | NikolaevDGatherMPI::NikolaevDGatherMPI(const InType &in) { |
| 31 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 32 | GetInput() = in; | ||
| 33 | 6 | } | |
| 34 | |||
| 35 | 6 | bool NikolaevDGatherMPI::ValidationImpl() { | |
| 36 | const auto &input = GetInput(); | ||
| 37 | |||
| 38 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | if (input.count <= 0) { |
| 39 | return false; | ||
| 40 | } | ||
| 41 | |||
| 42 |
5/6✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 2 times.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
|
6 | if (input.datatype != MPI_INT && input.datatype != MPI_FLOAT && input.datatype != MPI_DOUBLE) { |
| 43 | return false; | ||
| 44 | } | ||
| 45 | |||
| 46 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | if (input.root < 0) { |
| 47 | return false; | ||
| 48 | } | ||
| 49 | |||
| 50 | const int type_size = GetTypeSize(input.datatype); | ||
| 51 | if (type_size <= 0) { | ||
| 52 | return false; | ||
| 53 | } | ||
| 54 | |||
| 55 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | const size_t expected_size = static_cast<size_t>(input.count) * static_cast<size_t>(type_size); |
| 56 | |||
| 57 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | if (input.data.size() != expected_size) { |
| 58 | return false; | ||
| 59 | } | ||
| 60 | |||
| 61 | 6 | int size = 1; | |
| 62 | 6 | MPI_Comm_size(MPI_COMM_WORLD, &size); | |
| 63 | |||
| 64 | 6 | return input.root < size; | |
| 65 | } | ||
| 66 | |||
| 67 | 6 | bool NikolaevDGatherMPI::PreProcessingImpl() { | |
| 68 | 6 | return true; | |
| 69 | } | ||
| 70 | |||
| 71 | namespace { | ||
| 72 | |||
| 73 | bool CheckGatherArgs(int sendcount, int recvcount, MPI_Datatype sendtype, MPI_Datatype recvtype, int rank, int root, | ||
| 74 | const void *recvbuf) { | ||
| 75 | 6 | if (sendcount != recvcount) { | |
| 76 | return false; | ||
| 77 | } | ||
| 78 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | if (sendtype != recvtype) { |
| 79 | return false; | ||
| 80 | } | ||
| 81 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | if (rank == root && recvbuf == nullptr) { |
| 82 | return false; | ||
| 83 | } | ||
| 84 | return true; | ||
| 85 | } | ||
| 86 | |||
| 87 | int GetBlockSize(int sendcount, MPI_Datatype datatype) { | ||
| 88 | 6 | int type_size = 0; | |
| 89 | 6 | MPI_Type_size(datatype, &type_size); | |
| 90 | 6 | return sendcount * type_size; | |
| 91 | } | ||
| 92 | |||
// Seeds the per-rank accumulation buffers with this rank's own contribution.
//
// After the call `ranks` holds exactly one entry (`rank`) and `data` holds
// exactly one block of `block_sz` bytes copied from `sendbuf`.
//
// Defensive behaviour:
//   * a non-positive `block_sz` is treated as an empty block instead of being
//     cast to a huge unsigned size;
//   * a null `sendbuf` is never dereferenced - the block is left zero-filled so
//     that `data` still has one full block per entry in `ranks`.
void InitLocalData(const void *sendbuf, int block_sz, int rank, std::vector<char> &data, std::vector<int> &ranks) {
  const std::size_t block_len = block_sz > 0 ? static_cast<std::size_t>(block_sz) : 0;

  // assign() discards any previous content and zero-fills the new block.
  data.assign(block_len, char{0});
  ranks.clear();
  ranks.push_back(rank);

  if (sendbuf == nullptr || block_len == 0) {
    return;
  }

  const auto *send_ptr = static_cast<const char *>(sendbuf);
  std::copy_n(send_ptr, block_len, data.begin());
}
| 101 | |||
| 102 | 3 | void ReceiveFromChild(int sender_rank, int block_sz, MPI_Comm comm, std::vector<char> &data, std::vector<int> &ranks) { | |
| 103 | 3 | int num_ranks = 0; | |
| 104 | 3 | MPI_Recv(&num_ranks, 1, MPI_INT, sender_rank, 100, comm, MPI_STATUS_IGNORE); | |
| 105 | |||
| 106 | 3 | const size_t recv_size = static_cast<size_t>(num_ranks) * static_cast<size_t>(block_sz); | |
| 107 | |||
| 108 | 3 | std::vector<char> recv_data(recv_size); | |
| 109 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | MPI_Recv(recv_data.data(), static_cast<int>(recv_size), MPI_BYTE, sender_rank, 101, comm, MPI_STATUS_IGNORE); |
| 110 | |||
| 111 |
1/4✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
3 | std::vector<int> recv_ranks(static_cast<size_t>(num_ranks)); |
| 112 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | MPI_Recv(recv_ranks.data(), num_ranks, MPI_INT, sender_rank, 102, comm, MPI_STATUS_IGNORE); |
| 113 | |||
| 114 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | data.insert(data.end(), recv_data.begin(), recv_data.end()); |
| 115 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | ranks.insert(ranks.end(), recv_ranks.begin(), recv_ranks.end()); |
| 116 | 3 | } | |
| 117 | |||
| 118 | 3 | void SendToParent(int receiver_rank, int block_sz, MPI_Comm comm, const std::vector<char> &data, | |
| 119 | const std::vector<int> &ranks) { | ||
| 120 | 3 | const int num_ranks = static_cast<int>(ranks.size()); | |
| 121 | |||
| 122 | 3 | MPI_Send(&num_ranks, 1, MPI_INT, receiver_rank, 100, comm); | |
| 123 | 3 | MPI_Send(data.data(), num_ranks * block_sz, MPI_BYTE, receiver_rank, 101, comm); | |
| 124 | 3 | MPI_Send(ranks.data(), num_ranks, MPI_INT, receiver_rank, 102, comm); | |
| 125 | 3 | } | |
| 126 | |||
// Writes the gathered blocks into `recvbuf` ordered by originating rank.
//
// `data` holds blocks of `block_sz` bytes in arrival order and `ranks[i]`
// names the rank that produced block i. The output is `size` consecutive
// blocks; block r of the output is the block labelled with rank r.
//
// Guarantees:
//   * the whole output (size * block_sz bytes) is zeroed first, so the slot of
//     any rank that is missing from `ranks` reads as zeros;
//   * labels outside [0, size) are ignored;
//   * only complete blocks actually present in `data` are read, so a `ranks`
//     list longer than the payload cannot cause an out-of-bounds read;
//   * a null `recvbuf` or a non-positive `size` / `block_sz` is a no-op.
//
// `recvbuf` must provide at least size * block_sz bytes and must not overlap
// `data`, because the output is written in place without a staging buffer.
void AssembleAtRoot(const std::vector<char> &data, const std::vector<int> &ranks, int size, int block_sz,
                    void *recvbuf) {
  if (recvbuf == nullptr || size <= 0 || block_sz <= 0) {
    return;
  }

  auto *out = static_cast<char *>(recvbuf);
  const auto block_len = static_cast<std::size_t>(block_sz);
  const auto slot_count = static_cast<std::size_t>(size);

  std::fill_n(out, slot_count * block_len, char{0});

  // Never trust `ranks` to match the payload: clamp to the blocks really there.
  const std::size_t available_blocks = std::min(ranks.size(), data.size() / block_len);
  for (std::size_t i = 0; i < available_blocks; ++i) {
    const int r = ranks[i];
    if (r < 0 || r >= size) {
      continue;
    }
    std::copy_n(data.data() + (i * block_len), block_len, out + (static_cast<std::size_t>(r) * block_len));
  }
}
| 147 | |||
| 148 | 6 | int TreeGatherImpl(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, | |
| 149 | MPI_Datatype recvtype, int root, MPI_Comm comm) { | ||
| 150 | 6 | int rank = 0; | |
| 151 | 6 | int size = 1; | |
| 152 | 6 | MPI_Comm_rank(comm, &rank); | |
| 153 | 6 | MPI_Comm_size(comm, &size); | |
| 154 | |||
| 155 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | if (!CheckGatherArgs(sendcount, recvcount, sendtype, recvtype, rank, root, recvbuf)) { |
| 156 | return MPI_ERR_ARG; | ||
| 157 | } | ||
| 158 | |||
| 159 | const int block_sz = GetBlockSize(sendcount, sendtype); | ||
| 160 | 6 | const int rel_rank = (rank - root + size) % size; | |
| 161 | |||
| 162 | 6 | std::vector<char> current_data; | |
| 163 | 6 | std::vector<int> current_ranks; | |
| 164 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | InitLocalData(sendbuf, block_sz, rank, current_data, current_ranks); |
| 165 | |||
| 166 | int step = 1; | ||
| 167 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
|
9 | while (step < size) { |
| 168 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rel_rank % (2 * step) == 0) { |
| 169 | 3 | const int sender_rel = rel_rank + step; | |
| 170 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | if (sender_rel < size) { |
| 171 | 3 | const int sender_rank = (sender_rel + root) % size; | |
| 172 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | ReceiveFromChild(sender_rank, block_sz, comm, current_data, current_ranks); |
| 173 | } | ||
| 174 | } else { | ||
| 175 | 3 | const int receiver_rel = rel_rank - step; | |
| 176 | 3 | const int receiver_rank = (receiver_rel + root) % size; | |
| 177 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | SendToParent(receiver_rank, block_sz, comm, current_data, current_ranks); |
| 178 | break; | ||
| 179 | } | ||
| 180 | step *= 2; | ||
| 181 | } | ||
| 182 | |||
| 183 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank == root) { |
| 184 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | AssembleAtRoot(current_data, current_ranks, size, block_sz, recvbuf); |
| 185 | } | ||
| 186 | |||
| 187 | return MPI_SUCCESS; | ||
| 188 | } | ||
| 189 | } // namespace | ||
| 190 | |||
| 191 | 6 | bool NikolaevDGatherMPI::RunImpl() { | |
| 192 | 6 | int rank = 0; | |
| 193 | 6 | int size = 1; | |
| 194 | 6 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 195 | 6 | MPI_Comm_size(MPI_COMM_WORLD, &size); | |
| 196 | |||
| 197 | const auto &input = GetInput(); | ||
| 198 | |||
| 199 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
|
6 | const int type_size = GetTypeSize(input.datatype); |
| 200 | |||
| 201 | 6 | std::vector<char> recv_buffer; | |
| 202 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank == input.root) { |
| 203 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | recv_buffer.resize(static_cast<size_t>(input.count) * static_cast<size_t>(size) * static_cast<size_t>(type_size)); |
| 204 | } | ||
| 205 | |||
| 206 | int result = | ||
| 207 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | TreeGatherImpl(input.data.data(), input.count, input.datatype, rank == input.root ? recv_buffer.data() : nullptr, |
| 208 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | input.count, input.datatype, input.root, MPI_COMM_WORLD); |
| 209 | |||
| 210 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | if (result != MPI_SUCCESS) { |
| 211 | return false; | ||
| 212 | } | ||
| 213 | |||
| 214 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank == input.root) { |
| 215 | GetOutput() = std::move(recv_buffer); | ||
| 216 | } else { | ||
| 217 | 3 | GetOutput() = std::vector<char>(); | |
| 218 | } | ||
| 219 | |||
| 220 | return true; | ||
| 221 | } | ||
| 222 | |||
| 223 | 6 | bool NikolaevDGatherMPI::PostProcessingImpl() { | |
| 224 | 6 | return true; | |
| 225 | } | ||
| 226 | |||
| 227 | } // namespace nikolaev_d_gather | ||
| 228 |