| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "nalitov_d_broadcast/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <cstring> | ||
| 7 | #include <stdexcept> | ||
| 8 | #include <variant> | ||
| 9 | #include <vector> | ||
| 10 | |||
| 11 | #include "nalitov_d_broadcast/common/include/common.hpp" | ||
| 12 | |||
| 13 | namespace nalitov_d_broadcast { | ||
| 14 | |||
| 15 | namespace { | ||
| 16 | |||
| 17 | constexpr int kBroadcastTag = 0; | ||
| 18 | |||
| 19 | int ValidateBroadcastArgs(const void *buffer, int count, int root, int comm_size) { | ||
| 20 | 94 | if (count < 0) { | |
| 21 | return MPI_ERR_COUNT; | ||
| 22 | } | ||
| 23 | |||
| 24 |
1/2✓ Branch 0 taken 94 times.
✗ Branch 1 not taken.
|
94 | if (comm_size <= 0) { |
| 25 | return MPI_ERR_COMM; | ||
| 26 | } | ||
| 27 | |||
| 28 |
1/2✓ Branch 0 taken 94 times.
✗ Branch 1 not taken.
|
94 | if (root < 0 || root >= comm_size) { |
| 29 | return MPI_ERR_ROOT; | ||
| 30 | } | ||
| 31 | |||
| 32 |
1/2✓ Branch 0 taken 94 times.
✗ Branch 1 not taken.
|
94 | if (count == 0) { |
| 33 | return MPI_SUCCESS; | ||
| 34 | } | ||
| 35 | |||
| 36 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 94 times.
|
94 | if (buffer == nullptr) { |
| 37 | return MPI_ERR_BUFFER; | ||
| 38 | } | ||
| 39 | |||
| 40 | return MPI_SUCCESS; | ||
| 41 | } | ||
| 42 | |||
| 43 | 94 | int TreeBroadcastStep(void *buffer, int count, MPI_Datatype datatype, int root, int virtual_rank, int mask, | |
| 44 | int comm_size, MPI_Comm comm) { | ||
| 45 |
2/2✓ Branch 0 taken 47 times.
✓ Branch 1 taken 47 times.
|
94 | if (virtual_rank < mask) { |
| 46 | 47 | const int dest_virtual = virtual_rank + mask; | |
| 47 |
1/2✓ Branch 0 taken 47 times.
✗ Branch 1 not taken.
|
47 | if (dest_virtual >= comm_size) { |
| 48 | return MPI_SUCCESS; | ||
| 49 | } | ||
| 50 | |||
| 51 | 47 | const int dest_rank = (dest_virtual + root) % comm_size; | |
| 52 | 47 | return MPI_Send(buffer, count, datatype, dest_rank, kBroadcastTag, comm); | |
| 53 | } | ||
| 54 | |||
| 55 |
1/2✓ Branch 0 taken 47 times.
✗ Branch 1 not taken.
|
47 | if (virtual_rank < (mask << 1)) { |
| 56 | 47 | const int src_virtual = virtual_rank - mask; | |
| 57 | 47 | const int src_rank = (src_virtual + root) % comm_size; | |
| 58 | 47 | return MPI_Recv(buffer, count, datatype, src_rank, kBroadcastTag, comm, MPI_STATUS_IGNORE); | |
| 59 | } | ||
| 60 | |||
| 61 | return MPI_SUCCESS; | ||
| 62 | } | ||
| 63 | |||
| 64 | 94 | int TreeBroadcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm) { | |
| 65 | 94 | int comm_size = 0; | |
| 66 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 94 times.
|
94 | if (MPI_Comm_size(comm, &comm_size) != MPI_SUCCESS) { |
| 67 | return MPI_ERR_COMM; | ||
| 68 | } | ||
| 69 | |||
| 70 | 94 | int my_rank = 0; | |
| 71 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 94 times.
|
94 | if (MPI_Comm_rank(comm, &my_rank) != MPI_SUCCESS) { |
| 72 | return MPI_ERR_COMM; | ||
| 73 | } | ||
| 74 | |||
| 75 |
1/2✓ Branch 0 taken 94 times.
✗ Branch 1 not taken.
|
94 | const int validation = ValidateBroadcastArgs(buffer, count, root, comm_size); |
| 76 | if (validation != MPI_SUCCESS) { | ||
| 77 | ✗ | return validation; | |
| 78 | } | ||
| 79 | |||
| 80 | 94 | int type_size = 0; | |
| 81 |
2/4✓ Branch 1 taken 94 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 94 times.
✗ Branch 4 not taken.
|
94 | if (MPI_Type_size(datatype, &type_size) != MPI_SUCCESS || type_size <= 0) { |
| 82 | return MPI_ERR_TYPE; | ||
| 83 | } | ||
| 84 | |||
| 85 | 94 | const int virtual_rank = (my_rank - root + comm_size) % comm_size; | |
| 86 | |||
| 87 |
2/2✓ Branch 0 taken 94 times.
✓ Branch 1 taken 94 times.
|
188 | for (int mask = 1; mask < comm_size; mask <<= 1) { |
| 88 | 94 | const int status = TreeBroadcastStep(buffer, count, datatype, root, virtual_rank, mask, comm_size, comm); | |
| 89 | |||
| 90 |
1/2✓ Branch 0 taken 94 times.
✗ Branch 1 not taken.
|
94 | if (status != MPI_SUCCESS) { |
| 91 | return status; | ||
| 92 | } | ||
| 93 | } | ||
| 94 | |||
| 95 | return MPI_SUCCESS; | ||
| 96 | } | ||
| 97 | |||
| 98 | template <typename T> | ||
| 99 | bool BroadcastScalar(T *value, MPI_Datatype datatype, int root_proc, MPI_Comm comm) { | ||
| 100 | return NalitovDBroadcast(static_cast<void *>(value), 1, datatype, root_proc, comm) == MPI_SUCCESS; | ||
| 101 | } | ||
| 102 | |||
| 103 | } // namespace | ||
| 104 | |||
| 105 | ✗ | int NalitovDBroadcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm) { | |
| 106 |
1/2✓ Branch 7 taken 32 times.
✗ Branch 8 not taken.
|
94 | return TreeBroadcast(buffer, count, datatype, root, comm); |
| 107 | } | ||
| 108 | |||
| 109 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | NalitovDBroadcastMPI::NalitovDBroadcastMPI(const InType &in) { |
| 110 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 111 | GetInput() = in; | ||
| 112 | |||
| 113 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 22 times.
|
32 | if (std::holds_alternative<std::vector<int>>(in.data)) { |
| 114 | const auto &src_vec = std::get<std::vector<int>>(in.data); | ||
| 115 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
20 | GetOutput() = InTypeVariant{std::vector<int>(src_vec.size(), 0)}; |
| 116 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 14 times.
|
22 | } else if (std::holds_alternative<std::vector<float>>(in.data)) { |
| 117 | const auto &src_vec = std::get<std::vector<float>>(in.data); | ||
| 118 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
16 | GetOutput() = InTypeVariant{std::vector<float>(src_vec.size(), 0.0F)}; |
| 119 |
1/2✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
|
14 | } else if (std::holds_alternative<std::vector<double>>(in.data)) { |
| 120 | const auto &src_vec = std::get<std::vector<double>>(in.data); | ||
| 121 |
1/2✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
|
28 | GetOutput() = InTypeVariant{std::vector<double>(src_vec.size(), 0.0)}; |
| 122 | } else { | ||
| 123 | ✗ | throw std::runtime_error("Unsupported data type"); | |
| 124 | } | ||
| 125 | 32 | } | |
| 126 | |||
| 127 | 32 | bool NalitovDBroadcastMPI::ValidationImpl() { | |
| 128 | 32 | int init_flag = 0; | |
| 129 | 32 | MPI_Initialized(&init_flag); | |
| 130 | |||
| 131 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
|
32 | if (init_flag == 0) { |
| 132 | return false; | ||
| 133 | } | ||
| 134 | |||
| 135 | 32 | int comm_size = 0; | |
| 136 | 32 | MPI_Comm_size(MPI_COMM_WORLD, &comm_size); | |
| 137 | |||
| 138 | const auto &input_data = GetInput(); | ||
| 139 |
2/2✓ Branch 0 taken 14 times.
✓ Branch 1 taken 8 times.
|
22 | const bool valid_type = std::holds_alternative<std::vector<int>>(input_data.data) || |
| 140 |
3/4✓ Branch 0 taken 22 times.
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 14 times.
|
46 | std::holds_alternative<std::vector<float>>(input_data.data) || |
| 141 | std::holds_alternative<std::vector<double>>(input_data.data); | ||
| 142 | |||
| 143 | if (!valid_type) { | ||
| 144 | return false; | ||
| 145 | } | ||
| 146 | |||
| 147 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 32 times.
|
32 | if (input_data.root < 0 || input_data.root >= comm_size) { |
| 148 | return false; | ||
| 149 | } | ||
| 150 | |||
| 151 | return true; | ||
| 152 | } | ||
| 153 | |||
| 154 | 32 | bool NalitovDBroadcastMPI::PreProcessingImpl() { | |
| 155 | 32 | return true; | |
| 156 | } | ||
| 157 | |||
| 158 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | bool NalitovDBroadcastMPI::RunImpl() { |
| 159 | try { | ||
| 160 | const auto &input_data = GetInput(); | ||
| 161 | 32 | int proc_rank = 0; | |
| 162 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | MPI_Comm_rank(MPI_COMM_WORLD, &proc_rank); |
| 163 | |||
| 164 | 32 | int root_proc = 0; | |
| 165 |
2/2✓ Branch 0 taken 16 times.
✓ Branch 1 taken 16 times.
|
32 | if (proc_rank == 0) { |
| 166 | 16 | root_proc = input_data.root; | |
| 167 | } | ||
| 168 | |||
| 169 |
1/2✓ Branch 0 taken 32 times.
✗ Branch 1 not taken.
|
32 | if (!BroadcastScalar(&root_proc, MPI_INT, 0, MPI_COMM_WORLD)) { |
| 170 | return false; | ||
| 171 | } | ||
| 172 | |||
| 173 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 22 times.
|
32 | if (std::holds_alternative<std::vector<int>>(input_data.data)) { |
| 174 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | return ProcessVector<int>(input_data, proc_rank, root_proc, MPI_INT); |
| 175 | } | ||
| 176 | |||
| 177 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 14 times.
|
22 | if (std::holds_alternative<std::vector<float>>(input_data.data)) { |
| 178 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | return ProcessVector<float>(input_data, proc_rank, root_proc, MPI_FLOAT); |
| 179 | } | ||
| 180 | |||
| 181 |
1/2✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
|
14 | if (std::holds_alternative<std::vector<double>>(input_data.data)) { |
| 182 |
1/2✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
|
14 | return ProcessVector<double>(input_data, proc_rank, root_proc, MPI_DOUBLE); |
| 183 | } | ||
| 184 | |||
| 185 | return false; | ||
| 186 | ✗ | } catch (...) { | |
| 187 | return false; | ||
| 188 | ✗ | } | |
| 189 | } | ||
| 190 | |||
| 191 | template <typename T> | ||
| 192 | 64 | bool NalitovDBroadcastMPI::ProcessVector(const InType &input_data, int proc_rank, int root_proc, | |
| 193 | MPI_Datatype mpi_dtype) { | ||
| 194 | 64 | int elem_count = 0; | |
| 195 | const bool is_root = (proc_rank == root_proc); | ||
| 196 | |||
| 197 |
2/2✓ Branch 0 taken 16 times.
✓ Branch 1 taken 16 times.
|
64 | if (is_root) { |
| 198 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
|
32 | if (!std::holds_alternative<std::vector<T>>(input_data.data)) { |
| 199 | return false; | ||
| 200 | } | ||
| 201 | 32 | elem_count = static_cast<int>(std::get<std::vector<T>>(input_data.data).size()); | |
| 202 | } | ||
| 203 | |||
| 204 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
|
64 | if (!BroadcastScalar(&elem_count, MPI_INT, root_proc, MPI_COMM_WORLD)) { |
| 205 | return false; | ||
| 206 | } | ||
| 207 | |||
| 208 | auto &output_result = GetOutput(); | ||
| 209 | |||
| 210 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
|
64 | if (!std::holds_alternative<std::vector<T>>(output_result)) { |
| 211 | ✗ | output_result = InTypeVariant{std::vector<T>()}; | |
| 212 | } | ||
| 213 | |||
| 214 | auto &dest_buffer = std::get<std::vector<T>>(output_result); | ||
| 215 | |||
| 216 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 30 times.
|
64 | if (elem_count == 0) { |
| 217 | dest_buffer.clear(); | ||
| 218 | 4 | return true; | |
| 219 | } | ||
| 220 | |||
| 221 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 30 times.
|
60 | if (static_cast<int>(dest_buffer.size()) != elem_count) { |
| 222 | ✗ | dest_buffer.resize(elem_count); | |
| 223 | } | ||
| 224 | |||
| 225 |
2/2✓ Branch 0 taken 15 times.
✓ Branch 1 taken 15 times.
|
60 | if (is_root) { |
| 226 | const auto &src_buffer = std::get<std::vector<T>>(input_data.data); | ||
| 227 | std::ranges::copy(src_buffer, dest_buffer.begin()); | ||
| 228 | } | ||
| 229 | |||
| 230 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 30 times.
|
60 | if (NalitovDBroadcast(dest_buffer.data(), elem_count, mpi_dtype, root_proc, MPI_COMM_WORLD) != MPI_SUCCESS) { |
| 231 | return false; | ||
| 232 | } | ||
| 233 | |||
| 234 | return true; | ||
| 235 | } | ||
| 236 | |||
| 237 | 32 | bool NalitovDBroadcastMPI::PostProcessingImpl() { | |
| 238 | 32 | return true; | |
| 239 | } | ||
| 240 | |||
| 241 | template bool NalitovDBroadcastMPI::ProcessVector<int>(const InType &input, int rank, int root, MPI_Datatype mpi_type); | ||
| 242 | template bool NalitovDBroadcastMPI::ProcessVector<float>(const InType &input, int rank, int root, | ||
| 243 | MPI_Datatype mpi_type); | ||
| 244 | template bool NalitovDBroadcastMPI::ProcessVector<double>(const InType &input, int rank, int root, | ||
| 245 | MPI_Datatype mpi_type); | ||
| 246 | |||
| 247 | } // namespace nalitov_d_broadcast | ||
| 248 |