| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "potashnik_m_matrix_mult_complex/all/include/ops_all.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <array> | ||
| 7 | #include <cstddef> | ||
| 8 | #include <cstdint> | ||
| 9 | #include <map> | ||
| 10 | #include <thread> | ||
| 11 | #include <utility> | ||
| 12 | #include <vector> | ||
| 13 | |||
| 14 | #include "potashnik_m_matrix_mult_complex/common/include/common.hpp" | ||
| 15 | |||
| 16 | namespace potashnik_m_matrix_mult_complex { | ||
| 17 | |||
| 18 | namespace { | ||
| 19 | |||
| 20 | using Key = std::pair<size_t, size_t>; | ||
| 21 | using LocalMap = std::map<Key, Complex>; | ||
| 22 | |||
| 23 | 20 | void BroadcastMatrix(CCSMatrix &matrix) { | |
| 24 | 20 | int rank = 0; | |
| 25 | 20 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 26 | |||
| 27 | 20 | std::array<int, 3> meta{}; | |
| 28 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 29 | 10 | meta[0] = static_cast<int>(matrix.height); | |
| 30 | 10 | meta[1] = static_cast<int>(matrix.width); | |
| 31 | 10 | meta[2] = static_cast<int>(matrix.val.size()); | |
| 32 | } | ||
| 33 | 20 | MPI_Bcast(meta.data(), 3, MPI_INT, 0, MPI_COMM_WORLD); | |
| 34 | 20 | matrix.height = static_cast<size_t>(meta[0]); | |
| 35 | 20 | matrix.width = static_cast<size_t>(meta[1]); | |
| 36 | 20 | auto count = static_cast<size_t>(meta[2]); | |
| 37 | |||
| 38 | 20 | matrix.val.resize(count); | |
| 39 | 20 | matrix.row_ind.resize(count); | |
| 40 | 20 | matrix.col_ptr.resize(count); | |
| 41 | |||
| 42 | 20 | std::vector<int> tmp_rows(count); | |
| 43 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<int> tmp_cols(count); |
| 44 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<double> re(count); |
| 45 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<double> im(count); |
| 46 | |||
| 47 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 48 |
2/2✓ Branch 0 taken 421 times.
✓ Branch 1 taken 10 times.
|
431 | for (size_t i = 0; i < count; ++i) { |
| 49 | 421 | tmp_rows[i] = static_cast<int>(matrix.row_ind[i]); | |
| 50 | 421 | tmp_cols[i] = static_cast<int>(matrix.col_ptr[i]); | |
| 51 | 421 | re[i] = matrix.val[i].real; | |
| 52 | 421 | im[i] = matrix.val[i].imaginary; | |
| 53 | } | ||
| 54 | } | ||
| 55 | |||
| 56 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Bcast(tmp_rows.data(), static_cast<int>(count), MPI_INT, 0, MPI_COMM_WORLD); |
| 57 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Bcast(tmp_cols.data(), static_cast<int>(count), MPI_INT, 0, MPI_COMM_WORLD); |
| 58 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Bcast(re.data(), static_cast<int>(count), MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 59 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Bcast(im.data(), static_cast<int>(count), MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 60 | |||
| 61 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank != 0) { |
| 62 |
2/2✓ Branch 0 taken 421 times.
✓ Branch 1 taken 10 times.
|
431 | for (size_t i = 0; i < count; ++i) { |
| 63 | 421 | matrix.row_ind[i] = static_cast<size_t>(tmp_rows[i]); | |
| 64 | 421 | matrix.col_ptr[i] = static_cast<size_t>(tmp_cols[i]); | |
| 65 | 421 | matrix.val[i] = Complex(re[i], im[i]); | |
| 66 | } | ||
| 67 | } | ||
| 68 | 20 | } | |
| 69 | |||
| 70 | 20 | void ScatterMatrixLeft(int rank, int world_size, size_t total, const CCSMatrix &matrix, std::vector<int> &sendcounts, | |
| 71 | std::vector<int> &displs, std::vector<size_t> &local_rows, std::vector<size_t> &local_cols, | ||
| 72 | std::vector<double> &local_re, std::vector<double> &local_im) { | ||
| 73 | 20 | int blocksize = static_cast<int>(total) / world_size; | |
| 74 | 20 | int leftover = static_cast<int>(total) % world_size; | |
| 75 | int offset = 0; | ||
| 76 |
2/2✓ Branch 0 taken 40 times.
✓ Branch 1 taken 20 times.
|
60 | for (int proc = 0; proc < world_size; ++proc) { |
| 77 |
2/2✓ Branch 0 taken 38 times.
✓ Branch 1 taken 2 times.
|
78 | sendcounts[proc] = blocksize + (proc < leftover ? 1 : 0); |
| 78 | 40 | displs[proc] = offset; | |
| 79 | 40 | offset += sendcounts[proc]; | |
| 80 | } | ||
| 81 | |||
| 82 | 20 | int local_count = sendcounts[rank]; | |
| 83 | 20 | local_rows.resize(local_count); | |
| 84 | 20 | local_cols.resize(local_count); | |
| 85 | 20 | local_re.resize(local_count); | |
| 86 | 20 | local_im.resize(local_count); | |
| 87 | |||
| 88 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 89 |
2/2✓ Branch 0 taken 211 times.
✓ Branch 1 taken 10 times.
|
221 | for (int i = 0; i < local_count; ++i) { |
| 90 | 211 | local_rows[i] = matrix.row_ind[i]; | |
| 91 | 211 | local_cols[i] = matrix.col_ptr[i]; | |
| 92 | 211 | local_re[i] = matrix.val[i].real; | |
| 93 | 211 | local_im[i] = matrix.val[i].imaginary; | |
| 94 | } | ||
| 95 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | for (int proc = 1; proc < world_size; ++proc) { |
| 96 | 10 | int cnt = sendcounts[proc]; | |
| 97 | 10 | int dsp = displs[proc]; | |
| 98 | 10 | std::vector<int> rows_send(cnt); | |
| 99 |
1/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
10 | std::vector<int> cols_send(cnt); |
| 100 |
1/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
10 | std::vector<double> re_buf(cnt); |
| 101 |
1/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
10 | std::vector<double> im_buf(cnt); |
| 102 |
2/2✓ Branch 0 taken 210 times.
✓ Branch 1 taken 10 times.
|
220 | for (int i = 0; i < cnt; ++i) { |
| 103 | 210 | rows_send[i] = static_cast<int>(matrix.row_ind[dsp + i]); | |
| 104 | 210 | cols_send[i] = static_cast<int>(matrix.col_ptr[dsp + i]); | |
| 105 | 210 | re_buf[i] = matrix.val[dsp + i].real; | |
| 106 | 210 | im_buf[i] = matrix.val[dsp + i].imaginary; | |
| 107 | } | ||
| 108 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | MPI_Send(rows_send.data(), cnt, MPI_INT, proc, 0, MPI_COMM_WORLD); |
| 109 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | MPI_Send(cols_send.data(), cnt, MPI_INT, proc, 1, MPI_COMM_WORLD); |
| 110 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | MPI_Send(re_buf.data(), cnt, MPI_DOUBLE, proc, 2, MPI_COMM_WORLD); |
| 111 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | MPI_Send(im_buf.data(), cnt, MPI_DOUBLE, proc, 3, MPI_COMM_WORLD); |
| 112 | } | ||
| 113 | } else { | ||
| 114 | 10 | std::vector<int> rows_recv(local_count); | |
| 115 |
2/6✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
10 | std::vector<int> cols_recv(local_count); |
| 116 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | MPI_Recv(rows_recv.data(), local_count, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
| 117 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | MPI_Recv(cols_recv.data(), local_count, MPI_INT, 0, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
| 118 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | MPI_Recv(local_re.data(), local_count, MPI_DOUBLE, 0, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
| 119 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | MPI_Recv(local_im.data(), local_count, MPI_DOUBLE, 0, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
| 120 |
2/2✓ Branch 0 taken 210 times.
✓ Branch 1 taken 10 times.
|
220 | for (int i = 0; i < local_count; ++i) { |
| 121 | 210 | local_rows[i] = static_cast<size_t>(rows_recv[i]); | |
| 122 | 210 | local_cols[i] = static_cast<size_t>(cols_recv[i]); | |
| 123 | } | ||
| 124 | } | ||
| 125 | 20 | } | |
| 126 | |||
| 127 | 20 | void GatherResult(int rank, int world_size, const std::vector<size_t> &rows, const std::vector<size_t> &cols, | |
| 128 | const std::vector<double> &re_vals, const std::vector<double> &im_vals, size_t height_left, | ||
| 129 | size_t width_right, CCSMatrix &output) { | ||
| 130 | 20 | int local_count = static_cast<int>(rows.size()); | |
| 131 |
1/2✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
|
20 | std::vector<int> all_counts(world_size); |
| 132 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Gather(&local_count, 1, MPI_INT, all_counts.data(), 1, MPI_INT, 0, MPI_COMM_WORLD); |
| 133 | |||
| 134 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<int> displs(world_size, 0); |
| 135 | int total_count = 0; | ||
| 136 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 137 |
2/2✓ Branch 0 taken 20 times.
✓ Branch 1 taken 10 times.
|
30 | for (int i = 0; i < world_size; ++i) { |
| 138 | 20 | displs[i] = total_count; | |
| 139 | 20 | total_count += all_counts[i]; | |
| 140 | } | ||
| 141 | } | ||
| 142 | |||
| 143 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<int> all_rows(total_count); |
| 144 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<int> all_cols(total_count); |
| 145 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<double> all_re(total_count); |
| 146 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<double> all_im(total_count); |
| 147 | |||
| 148 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<int> rows_int(local_count); |
| 149 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<int> cols_int(local_count); |
| 150 |
2/2✓ Branch 0 taken 858 times.
✓ Branch 1 taken 20 times.
|
878 | for (int i = 0; i < local_count; ++i) { |
| 151 | 858 | rows_int[i] = static_cast<int>(rows[i]); | |
| 152 | 858 | cols_int[i] = static_cast<int>(cols[i]); | |
| 153 | } | ||
| 154 | |||
| 155 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Gatherv(rows_int.data(), local_count, MPI_INT, all_rows.data(), all_counts.data(), displs.data(), MPI_INT, 0, |
| 156 | MPI_COMM_WORLD); | ||
| 157 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Gatherv(cols_int.data(), local_count, MPI_INT, all_cols.data(), all_counts.data(), displs.data(), MPI_INT, 0, |
| 158 | MPI_COMM_WORLD); | ||
| 159 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Gatherv(re_vals.data(), local_count, MPI_DOUBLE, all_re.data(), all_counts.data(), displs.data(), MPI_DOUBLE, 0, |
| 160 | MPI_COMM_WORLD); | ||
| 161 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Gatherv(im_vals.data(), local_count, MPI_DOUBLE, all_im.data(), all_counts.data(), displs.data(), MPI_DOUBLE, 0, |
| 162 | MPI_COMM_WORLD); | ||
| 163 | |||
| 164 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 165 | std::map<Key, Complex> buffer; | ||
| 166 |
2/2✓ Branch 0 taken 858 times.
✓ Branch 1 taken 10 times.
|
868 | for (int i = 0; i < total_count; ++i) { |
| 167 |
1/2✓ Branch 1 taken 858 times.
✗ Branch 2 not taken.
|
858 | buffer[{static_cast<size_t>(all_rows[i]), static_cast<size_t>(all_cols[i])}] += Complex(all_re[i], all_im[i]); |
| 168 | } | ||
| 169 | 10 | output.height = height_left; | |
| 170 | 10 | output.width = width_right; | |
| 171 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | output.val.reserve(buffer.size()); |
| 172 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | output.row_ind.reserve(buffer.size()); |
| 173 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | output.col_ptr.reserve(buffer.size()); |
| 174 |
2/2✓ Branch 0 taken 429 times.
✓ Branch 1 taken 10 times.
|
439 | for (const auto &[key, value] : buffer) { |
| 175 |
1/2✓ Branch 0 taken 429 times.
✗ Branch 1 not taken.
|
429 | output.row_ind.push_back(key.first); |
| 176 |
1/2✓ Branch 0 taken 429 times.
✗ Branch 1 not taken.
|
429 | output.col_ptr.push_back(key.second); |
| 177 | output.val.push_back(value); | ||
| 178 | } | ||
| 179 | } | ||
| 180 | 20 | } | |
| 181 | |||
| 182 | 80 | void ProcessChunk(size_t begin, size_t end, const CCSMatrix &matrix_right, const std::vector<Complex> &val_left, | |
| 183 | const std::vector<size_t> &row_ind_left, const std::vector<size_t> &col_ptr_left, | ||
| 184 | LocalMap &local_buffer) { | ||
| 185 | const auto &val_right = matrix_right.val; | ||
| 186 | const auto &row_ind_right = matrix_right.row_ind; | ||
| 187 | const auto &col_ptr_right = matrix_right.col_ptr; | ||
| 188 | |||
| 189 |
2/2✓ Branch 0 taken 421 times.
✓ Branch 1 taken 80 times.
|
501 | for (size_t i = begin; i < end; ++i) { |
| 190 | 421 | size_t row_left = row_ind_left[i]; | |
| 191 | 421 | size_t col_left = col_ptr_left[i]; | |
| 192 | 421 | Complex left_val = val_left[i]; | |
| 193 | |||
| 194 |
2/2✓ Branch 0 taken 24088 times.
✓ Branch 1 taken 421 times.
|
24509 | for (size_t j = 0; j < matrix_right.Count(); ++j) { |
| 195 | 24088 | size_t row_right = row_ind_right[j]; | |
| 196 | 24088 | size_t col_right = col_ptr_right[j]; | |
| 197 | 24088 | Complex right_val = val_right[j]; | |
| 198 | |||
| 199 |
2/2✓ Branch 0 taken 2984 times.
✓ Branch 1 taken 21104 times.
|
24088 | if (col_left == row_right) { |
| 200 | 2984 | local_buffer[{row_left, col_right}] += left_val * right_val; | |
| 201 | } | ||
| 202 | } | ||
| 203 | } | ||
| 204 | 80 | } | |
| 205 | |||
| 206 | 20 | LocalMap ComputeLocalChunk(const CCSMatrix &matrix_left, const CCSMatrix &matrix_right) { | |
| 207 | 20 | const auto &val_left = matrix_left.val; | |
| 208 | 20 | const auto &row_ind_left = matrix_left.row_ind; | |
| 209 | 20 | const auto &col_ptr_left = matrix_left.col_ptr; | |
| 210 | |||
| 211 | 20 | size_t left_count = matrix_left.Count(); | |
| 212 | 20 | size_t num_threads = std::thread::hardware_concurrency(); | |
| 213 | if (num_threads == 0) { | ||
| 214 | num_threads = 1; | ||
| 215 | } | ||
| 216 | |||
| 217 | 20 | std::vector<LocalMap> local_buffers(num_threads); | |
| 218 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | std::vector<std::thread> threads(num_threads); |
| 219 | 20 | size_t chunk = (left_count + num_threads - 1) / num_threads; | |
| 220 | |||
| 221 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 20 times.
|
100 | for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { |
| 222 | 80 | size_t begin = thread_idx * chunk; | |
| 223 |
1/2✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
|
80 | size_t end = std::min(begin + chunk, left_count); |
| 224 | 80 | threads[thread_idx] = std::thread([&, thread_idx, begin, end]() { | |
| 225 | 80 | ProcessChunk(begin, end, matrix_right, val_left, row_ind_left, col_ptr_left, local_buffers[thread_idx]); | |
| 226 |
1/2✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
|
80 | }); |
| 227 | } | ||
| 228 | |||
| 229 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 20 times.
|
100 | for (auto &th : threads) { |
| 230 |
1/2✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
|
80 | th.join(); |
| 231 | } | ||
| 232 | |||
| 233 | LocalMap result; | ||
| 234 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 20 times.
|
100 | for (const auto &local : local_buffers) { |
| 235 |
2/2✓ Branch 0 taken 2733 times.
✓ Branch 1 taken 80 times.
|
2813 | for (const auto &[key, value] : local) { |
| 236 |
1/2✓ Branch 1 taken 2733 times.
✗ Branch 2 not taken.
|
2733 | result[key] += value; |
| 237 | } | ||
| 238 | } | ||
| 239 | 20 | return result; | |
| 240 | 20 | } | |
| 241 | |||
| 242 | } // namespace | ||
| 243 | |||
| 244 |
1/2✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
|
20 | PotashnikMMatrixMultComplexALL::PotashnikMMatrixMultComplexALL(const InType &in) { |
| 245 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 246 | 20 | int rank = 0; | |
| 247 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
| 248 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 249 | GetInput() = in; | ||
| 250 | } | ||
| 251 | 20 | } | |
| 252 | |||
| 253 | 20 | bool PotashnikMMatrixMultComplexALL::ValidationImpl() { | |
| 254 | 20 | int rank = 0; | |
| 255 | 20 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 256 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank != 0) { |
| 257 | return true; | ||
| 258 | } | ||
| 259 | const auto &matrix_left = std::get<0>(GetInput()); | ||
| 260 | const auto &matrix_right = std::get<1>(GetInput()); | ||
| 261 | 10 | return matrix_left.width == matrix_right.height; | |
| 262 | } | ||
| 263 | |||
| 264 | 20 | bool PotashnikMMatrixMultComplexALL::PreProcessingImpl() { | |
| 265 | 20 | return true; | |
| 266 | } | ||
| 267 | |||
| 268 | 20 | bool PotashnikMMatrixMultComplexALL::RunImpl() { | |
| 269 | 20 | int rank = 0; | |
| 270 | 20 | int world_size = 1; | |
| 271 | 20 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 272 | 20 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); | |
| 273 | |||
| 274 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | CCSMatrix matrix_right = (rank == 0) ? std::get<1>(GetInput()) : CCSMatrix{}; |
| 275 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | BroadcastMatrix(matrix_right); |
| 276 | |||
| 277 | 20 | std::array<uint64_t, 3> meta{}; | |
| 278 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 279 | const auto &ml = std::get<0>(GetInput()); | ||
| 280 | 10 | meta[0] = static_cast<uint64_t>(ml.Count()); | |
| 281 | 10 | meta[1] = static_cast<uint64_t>(ml.height); | |
| 282 | 10 | meta[2] = static_cast<uint64_t>(ml.width); | |
| 283 | } | ||
| 284 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Bcast(meta.data(), 3, MPI_UINT64_T, 0, MPI_COMM_WORLD); |
| 285 | 20 | auto total = static_cast<size_t>(meta[0]); | |
| 286 | 20 | auto height_left = static_cast<size_t>(meta[1]); | |
| 287 | |||
| 288 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | std::vector<int> sendcounts(world_size); |
| 289 |
1/4✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
20 | std::vector<int> displs(world_size); |
| 290 | 20 | std::vector<size_t> local_rows; | |
| 291 | 20 | std::vector<size_t> local_cols; | |
| 292 | 20 | std::vector<double> local_re; | |
| 293 | 20 | std::vector<double> local_im; | |
| 294 | |||
| 295 | 20 | const CCSMatrix empty{}; | |
| 296 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | const CCSMatrix &matrix_left_ref = (rank == 0) ? std::get<0>(GetInput()) : empty; |
| 297 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | ScatterMatrixLeft(rank, world_size, total, matrix_left_ref, sendcounts, displs, local_rows, local_cols, local_re, |
| 298 | local_im); | ||
| 299 | |||
| 300 | 20 | CCSMatrix local_left; | |
| 301 | 20 | local_left.height = height_left; | |
| 302 | 20 | local_left.width = matrix_right.height; | |
| 303 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | local_left.row_ind = local_rows; |
| 304 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | local_left.col_ptr = local_cols; |
| 305 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | local_left.val.resize(local_rows.size()); |
| 306 |
2/2✓ Branch 0 taken 421 times.
✓ Branch 1 taken 20 times.
|
441 | for (size_t i = 0; i < local_rows.size(); ++i) { |
| 307 | 421 | local_left.val[i] = Complex(local_re[i], local_im[i]); | |
| 308 | } | ||
| 309 | |||
| 310 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | auto local_result = ComputeLocalChunk(local_left, matrix_right); |
| 311 | |||
| 312 | 20 | std::vector<size_t> res_rows; | |
| 313 | 20 | std::vector<size_t> res_cols; | |
| 314 | 20 | std::vector<double> res_re; | |
| 315 | 20 | std::vector<double> res_im; | |
| 316 |
2/2✓ Branch 0 taken 858 times.
✓ Branch 1 taken 20 times.
|
878 | for (const auto &[key, value] : local_result) { |
| 317 |
2/2✓ Branch 0 taken 726 times.
✓ Branch 1 taken 132 times.
|
858 | res_rows.push_back(key.first); |
| 318 |
2/2✓ Branch 0 taken 726 times.
✓ Branch 1 taken 132 times.
|
858 | res_cols.push_back(key.second); |
| 319 |
2/2✓ Branch 0 taken 726 times.
✓ Branch 1 taken 132 times.
|
858 | res_re.push_back(value.real); |
| 320 |
2/2✓ Branch 0 taken 726 times.
✓ Branch 1 taken 132 times.
|
858 | res_im.push_back(value.imaginary); |
| 321 | } | ||
| 322 | |||
| 323 | 20 | CCSMatrix output; | |
| 324 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | GatherResult(rank, world_size, res_rows, res_cols, res_re, res_im, height_left, matrix_right.width, output); |
| 325 | |||
| 326 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 327 | 10 | GetOutput() = std::move(output); | |
| 328 | } | ||
| 329 | 20 | return true; | |
| 330 | 60 | } | |
| 331 | |||
| 332 | 20 | bool PotashnikMMatrixMultComplexALL::PostProcessingImpl() { | |
| 333 | 20 | return true; | |
| 334 | } | ||
| 335 | |||
| 336 | } // namespace potashnik_m_matrix_mult_complex | ||
| 337 |