| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "smyshlaev_a_mat_mul/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <cstddef> | ||
| 7 | #include <vector> | ||
| 8 | |||
| 9 | #include "smyshlaev_a_mat_mul/common/include/common.hpp" | ||
| 10 | |||
| 11 | namespace smyshlaev_a_mat_mul { | ||
| 12 | |||
| 13 | namespace { | ||
| 14 | |||
// Splits `total_len` items into `proc_count` contiguous chunks whose sizes differ by at most
// one; the first `total_len % proc_count` chunks receive the extra item.
// On return counts[i] is the size of chunk i and offsets[i] is its starting index.
// Both output vectors are resized to `proc_count`, so callers need not pre-size them.
// A non-positive `proc_count` yields empty outputs instead of dividing by zero.
void CalculateDistribution(int total_len, int proc_count, std::vector<int> &counts, std::vector<int> &offsets) {
  if (proc_count <= 0) {
    counts.clear();
    offsets.clear();
    return;
  }
  counts.resize(static_cast<std::size_t>(proc_count));
  offsets.resize(static_cast<std::size_t>(proc_count));

  const int chunk = total_len / proc_count;
  const int remainder = total_len % proc_count;
  int offset = 0;
  for (int i = 0; i < proc_count; ++i) {
    counts[i] = chunk + (i < remainder ? 1 : 0);
    offsets[i] = offset;
    offset += counts[i];
  }
}
| 25 | } // namespace | ||
| 26 | |||
| 27 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | SmyshlaevAMatMulMPI::SmyshlaevAMatMulMPI(const InType &in) { |
| 28 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 29 | |||
| 30 | 24 | int rank = 0; | |
| 31 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
| 32 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 33 | GetInput() = in; | ||
| 34 | } | ||
| 35 | 24 | } | |
| 36 | |||
| 37 | 24 | bool SmyshlaevAMatMulMPI::ValidationImpl() { | |
| 38 | 24 | int rank = 0; | |
| 39 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 40 | 24 | int error_flag = 0; | |
| 41 | |||
| 42 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 43 | const auto &num_rows_a = std::get<0>(GetInput()); | ||
| 44 | const auto &mat_a = std::get<1>(GetInput()); | ||
| 45 | const auto &num_rows_b = std::get<2>(GetInput()); | ||
| 46 | const auto &mat_b = std::get<3>(GetInput()); | ||
| 47 | |||
| 48 |
3/6✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
|
12 | bool is_invalid = (num_rows_a <= 0 || num_rows_b <= 0) || (mat_a.empty() || mat_b.empty()) || |
| 49 |
3/6✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 12 times.
|
24 | (mat_a.size() % num_rows_a != 0) || (mat_b.size() % num_rows_b != 0); |
| 50 | |||
| 51 | if (is_invalid) { | ||
| 52 | ✗ | error_flag = 1; | |
| 53 | } else { | ||
| 54 | 12 | const auto &num_cols_a = static_cast<int>(mat_a.size()) / num_rows_a; | |
| 55 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if (num_cols_a != num_rows_b) { |
| 56 | ✗ | error_flag = 1; | |
| 57 | } | ||
| 58 | } | ||
| 59 | } | ||
| 60 | 24 | MPI_Bcast(&error_flag, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 61 | 24 | return (error_flag == 0); | |
| 62 | } | ||
| 63 | |||
| 64 | 24 | bool SmyshlaevAMatMulMPI::PreProcessingImpl() { | |
| 65 | 24 | int rank = 0; | |
| 66 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 67 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 68 | const auto &num_rows_b = std::get<2>(GetInput()); | ||
| 69 | const auto &mat_b = std::get<3>(GetInput()); | ||
| 70 | 12 | const auto num_cols_b = static_cast<int>(mat_b.size()) / num_rows_b; | |
| 71 | |||
| 72 | 12 | mat_b_transposed_.resize(mat_b.size()); | |
| 73 | |||
| 74 |
2/2✓ Branch 0 taken 81 times.
✓ Branch 1 taken 12 times.
|
93 | for (int i = 0; i < num_rows_b; ++i) { |
| 75 |
2/2✓ Branch 0 taken 1220 times.
✓ Branch 1 taken 81 times.
|
1301 | for (int j = 0; j < num_cols_b; ++j) { |
| 76 | 1220 | mat_b_transposed_[(j * num_rows_b) + i] = mat_b[(i * num_cols_b) + j]; | |
| 77 | } | ||
| 78 | } | ||
| 79 | } | ||
| 80 | 24 | return true; | |
| 81 | } | ||
| 82 | 6 | bool SmyshlaevAMatMulMPI::RunSequential() { | |
| 83 | 6 | int rank = 0; | |
| 84 | 6 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 85 | 6 | std::vector<double> result; | |
| 86 | 6 | int num_elms = 0; | |
| 87 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank == 0) { |
| 88 | const auto &num_rows_a = std::get<0>(GetInput()); | ||
| 89 | const auto &mat_a = std::get<1>(GetInput()); | ||
| 90 | const auto &num_rows_b = std::get<2>(GetInput()); | ||
| 91 | const auto &mat_b = std::get<3>(GetInput()); | ||
| 92 | 3 | const int num_cols_b = static_cast<int>(mat_b.size()) / num_rows_b; | |
| 93 | const int num_cols_a = num_rows_b; | ||
| 94 | |||
| 95 | 3 | num_elms = num_rows_a * num_cols_b; | |
| 96 |
1/4✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
3 | result.resize(num_elms, 0.0); |
| 97 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 3 times.
|
15 | for (int i = 0; i < num_rows_a; ++i) { |
| 98 |
2/2✓ Branch 0 taken 16 times.
✓ Branch 1 taken 12 times.
|
28 | for (int j = 0; j < num_cols_b; ++j) { |
| 99 | double sum = 0.0; | ||
| 100 |
2/2✓ Branch 0 taken 126 times.
✓ Branch 1 taken 16 times.
|
142 | for (int k = 0; k < num_cols_a; ++k) { |
| 101 | 126 | sum += mat_a[(i * num_cols_a) + k] * mat_b_transposed_[(j * num_cols_a) + k]; | |
| 102 | } | ||
| 103 | 16 | result[(i * num_cols_b) + j] = sum; | |
| 104 | } | ||
| 105 | } | ||
| 106 | } | ||
| 107 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Bcast(&num_elms, 1, MPI_INT, 0, MPI_COMM_WORLD); |
| 108 | |||
| 109 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank != 0) { |
| 110 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | result.resize(num_elms); |
| 111 | } | ||
| 112 | |||
| 113 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Bcast(result.data(), num_elms, MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 114 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | GetOutput() = result; |
| 115 | |||
| 116 | 6 | return true; | |
| 117 | } | ||
| 118 | |||
| 119 | 24 | void SmyshlaevAMatMulMPI::BroadcastDimensions(int &rows_a, int &cols_a, int &cols_b) { | |
| 120 | 24 | int rank = 0; | |
| 121 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 122 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 123 | 12 | rows_a = std::get<0>(GetInput()); | |
| 124 | const auto &mat_a = std::get<1>(GetInput()); | ||
| 125 | 12 | cols_a = static_cast<int>(mat_a.size()) / rows_a; | |
| 126 | const auto &num_rows_b = std::get<2>(GetInput()); | ||
| 127 | const auto &mat_b = std::get<3>(GetInput()); | ||
| 128 | 12 | cols_b = static_cast<int>(mat_b.size()) / num_rows_b; | |
| 129 | } | ||
| 130 | 24 | MPI_Bcast(&rows_a, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 131 | 24 | MPI_Bcast(&cols_a, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 132 | 24 | MPI_Bcast(&cols_b, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 133 | 24 | } | |
| 134 | |||
| 135 | 18 | void SmyshlaevAMatMulMPI::RingShiftAlgorithm(int rank, int size, int my_rows_a, int num_cols_a, int num_cols_b, | |
| 136 | const std::vector<int> &counts_b, const std::vector<int> &disps_b, | ||
| 137 | std::vector<double> &local_a, std::vector<double> &local_b, | ||
| 138 | std::vector<double> &local_c) { | ||
| 139 | 18 | int max_elem_count_b = 0; | |
| 140 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 18 times.
|
54 | for (int c : counts_b) { |
| 141 | 36 | max_elem_count_b = std::max(max_elem_count_b, c); | |
| 142 | } | ||
| 143 | 18 | std::vector<double> local_b_next(max_elem_count_b); | |
| 144 | |||
| 145 | 18 | int left_neighbor = (rank - 1 + size) % size; | |
| 146 | 18 | int right_neighbor = (rank + 1) % size; | |
| 147 | |||
| 148 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 18 times.
|
54 | for (int step = 0; step < size; ++step) { |
| 149 | 36 | int b_owner_rank = (rank - step + size) % size; | |
| 150 | 36 | int current_cols_b_count = counts_b[b_owner_rank] / num_cols_a; | |
| 151 | 36 | int global_col_shift = disps_b[b_owner_rank] / num_cols_a; | |
| 152 | |||
| 153 |
2/2✓ Branch 0 taken 130 times.
✓ Branch 1 taken 36 times.
|
166 | for (int i = 0; i < my_rows_a; ++i) { |
| 154 |
2/2✓ Branch 0 taken 1208 times.
✓ Branch 1 taken 130 times.
|
1338 | for (int j = 0; j < current_cols_b_count; ++j) { |
| 155 | double sum = 0.0; | ||
| 156 |
2/2✓ Branch 0 taken 33984 times.
✓ Branch 1 taken 1208 times.
|
35192 | for (int k = 0; k < num_cols_a; ++k) { |
| 157 | 33984 | sum += local_a[(i * num_cols_a) + k] * local_b[(j * num_cols_a) + k]; | |
| 158 | } | ||
| 159 | 1208 | local_c[(i * num_cols_b) + (global_col_shift + j)] = sum; | |
| 160 | } | ||
| 161 | } | ||
| 162 | |||
| 163 | int send_count = counts_b[b_owner_rank]; | ||
| 164 | 36 | int recv_owner_rank = (rank - (step + 1) + size) % size; | |
| 165 |
1/2✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
|
36 | int recv_count = counts_b[recv_owner_rank]; |
| 166 | |||
| 167 |
1/2✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
|
36 | MPI_Sendrecv(local_b.data(), send_count, MPI_DOUBLE, right_neighbor, 0, local_b_next.data(), recv_count, MPI_DOUBLE, |
| 168 | left_neighbor, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); | ||
| 169 | |||
| 170 | 36 | local_b.assign(local_b_next.begin(), local_b_next.begin() + recv_count); | |
| 171 | } | ||
| 172 | 18 | } | |
| 173 | |||
| 174 | 18 | void SmyshlaevAMatMulMPI::GatherAndBroadcastResults(int rank, int size, int rows_a, int cols_a, int cols_b, | |
| 175 | const std::vector<int> &counts_a, | ||
| 176 | const std::vector<double> &local_c) { | ||
| 177 | 18 | std::vector<double> final_res; | |
| 178 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<int> recvcounts_c(size); |
| 179 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<int> disps_c(size); |
| 180 | |||
| 181 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | if (rank == 0) { |
| 182 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | final_res.resize(static_cast<size_t>(rows_a) * cols_b); |
| 183 | int offset = 0; | ||
| 184 |
2/2✓ Branch 0 taken 18 times.
✓ Branch 1 taken 9 times.
|
27 | for (int i = 0; i < size; ++i) { |
| 185 | 18 | int r_rows = counts_a[i] / cols_a; | |
| 186 | 18 | recvcounts_c[i] = r_rows * cols_b; | |
| 187 | 18 | disps_c[i] = offset; | |
| 188 | 18 | offset += recvcounts_c[i]; | |
| 189 | } | ||
| 190 | } | ||
| 191 | |||
| 192 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | MPI_Gatherv(local_c.data(), static_cast<int>(local_c.size()), MPI_DOUBLE, final_res.data(), recvcounts_c.data(), |
| 193 | disps_c.data(), MPI_DOUBLE, 0, MPI_COMM_WORLD); | ||
| 194 | |||
| 195 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | if (rank != 0) { |
| 196 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | final_res.resize(static_cast<size_t>(rows_a) * cols_b); |
| 197 | } | ||
| 198 | |||
| 199 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | MPI_Bcast(final_res.data(), rows_a * cols_b, MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 200 | GetOutput() = final_res; | ||
| 201 | 18 | } | |
| 202 | |||
| 203 | 24 | bool SmyshlaevAMatMulMPI::RunImpl() { | |
| 204 | 24 | int rank = 0; | |
| 205 | 24 | int size = 0; | |
| 206 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 207 | 24 | MPI_Comm_size(MPI_COMM_WORLD, &size); | |
| 208 | |||
| 209 | 24 | int num_rows_a = 0; | |
| 210 | 24 | int num_cols_a = 0; | |
| 211 | 24 | int num_cols_b = 0; | |
| 212 | |||
| 213 | 24 | BroadcastDimensions(num_rows_a, num_cols_a, num_cols_b); | |
| 214 | |||
| 215 |
4/4✓ Branch 0 taken 20 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 2 times.
✓ Branch 3 taken 18 times.
|
24 | if (size > num_rows_a || size > num_cols_b) { |
| 216 | 6 | return RunSequential(); | |
| 217 | } | ||
| 218 | |||
| 219 | 18 | std::vector<int> sendcounts_a(size); | |
| 220 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<int> offsets_a(size); |
| 221 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<int> sendcounts_b(size); |
| 222 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<int> offsets_b(size); |
| 223 | |||
| 224 | 18 | CalculateDistribution(num_rows_a, size, sendcounts_a, offsets_a); | |
| 225 | |||
| 226 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 18 times.
|
54 | for (int i = 0; i < size; ++i) { |
| 227 | 36 | sendcounts_a[i] *= num_cols_a; | |
| 228 | 36 | offsets_a[i] *= num_cols_a; | |
| 229 | } | ||
| 230 | |||
| 231 | 18 | CalculateDistribution(num_cols_b, size, sendcounts_b, offsets_b); | |
| 232 | |||
| 233 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 18 times.
|
54 | for (int i = 0; i < size; ++i) { |
| 234 | 36 | sendcounts_b[i] *= num_cols_a; | |
| 235 | 36 | offsets_b[i] *= num_cols_a; | |
| 236 | } | ||
| 237 | |||
| 238 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | int my_rows_a = sendcounts_a[rank] / num_cols_a; |
| 239 | |||
| 240 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<double> local_a(sendcounts_a[rank]); |
| 241 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<double> local_b(sendcounts_b[rank]); |
| 242 | |||
| 243 | 18 | int max_elem_count_b = 0; | |
| 244 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 18 times.
|
54 | for (int c : sendcounts_b) { |
| 245 | 36 | max_elem_count_b = std::max(max_elem_count_b, c); | |
| 246 | } | ||
| 247 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<double> local_b_next(max_elem_count_b); |
| 248 | |||
| 249 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | const double *sendbuf_a = (rank == 0) ? std::get<1>(GetInput()).data() : nullptr; |
| 250 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | MPI_Scatterv(sendbuf_a, sendcounts_a.data(), offsets_a.data(), MPI_DOUBLE, local_a.data(), sendcounts_a[rank], |
| 251 | MPI_DOUBLE, 0, MPI_COMM_WORLD); | ||
| 252 | |||
| 253 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | const double *sendbuf_b = (rank == 0) ? mat_b_transposed_.data() : nullptr; |
| 254 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | MPI_Scatterv(sendbuf_b, sendcounts_b.data(), offsets_b.data(), MPI_DOUBLE, local_b.data(), sendcounts_b[rank], |
| 255 | MPI_DOUBLE, 0, MPI_COMM_WORLD); | ||
| 256 | |||
| 257 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
18 | std::vector<double> local_c(static_cast<size_t>(my_rows_a) * num_cols_b, 0.0); |
| 258 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | RingShiftAlgorithm(rank, size, my_rows_a, num_cols_a, num_cols_b, sendcounts_b, offsets_b, local_a, local_b, local_c); |
| 259 | |||
| 260 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | GatherAndBroadcastResults(rank, size, num_rows_a, num_cols_a, num_cols_b, sendcounts_a, local_c); |
| 261 | return true; | ||
| 262 | } | ||
| 263 | |||
| 264 | 24 | bool SmyshlaevAMatMulMPI::PostProcessingImpl() { | |
| 265 | 24 | return true; | |
| 266 | } | ||
| 267 | |||
| 268 | } // namespace smyshlaev_a_mat_mul | ||
| 269 |