| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "klimovich_v_crs_complex_mat_mul/all/include/ops_all.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <array> | ||
| 7 | #include <cmath> | ||
| 8 | #include <cstddef> | ||
| 9 | #include <utility> | ||
| 10 | #include <vector> | ||
| 11 | |||
| 12 | #include "klimovich_v_crs_complex_mat_mul/common/include/common.hpp" | ||
| 13 | |||
| 14 | namespace klimovich_v_crs_complex_mat_mul { | ||
| 15 | namespace { | ||
| 16 | |||
| 17 | struct RowStage { | ||
| 18 | std::vector<int> cols; | ||
| 19 | std::vector<Cplx> vals; | ||
| 20 | }; | ||
| 21 | |||
| 22 |
2/2✓ Branch 0 taken 34 times.
✓ Branch 1 taken 23 times.
|
57 | void GustavsonRow(const CrsMatrix &lhs, const CrsMatrix &rhs, int row, std::vector<Cplx> &spa, |
| 23 | std::vector<int> &touched_by_row, std::vector<int> &touched_cols, RowStage &stage) { | ||
| 24 | touched_cols.clear(); | ||
| 25 | |||
| 26 |
2/2✓ Branch 0 taken 84 times.
✓ Branch 1 taken 57 times.
|
141 | for (int lp = lhs.row_offsets[row]; lp < lhs.row_offsets[row + 1]; ++lp) { |
| 27 | 84 | const int k = lhs.col_indices[lp]; | |
| 28 | 84 | const Cplx a_ik = lhs.data[lp]; | |
| 29 |
2/2✓ Branch 0 taken 151 times.
✓ Branch 1 taken 84 times.
|
235 | for (int rq = rhs.row_offsets[k]; rq < rhs.row_offsets[k + 1]; ++rq) { |
| 30 |
2/2✓ Branch 0 taken 84 times.
✓ Branch 1 taken 67 times.
|
151 | const int j = rhs.col_indices[rq]; |
| 31 |
2/2✓ Branch 0 taken 84 times.
✓ Branch 1 taken 67 times.
|
151 | if (touched_by_row[j] != row) { |
| 32 |
1/2✓ Branch 0 taken 84 times.
✗ Branch 1 not taken.
|
84 | touched_by_row[j] = row; |
| 33 | touched_cols.push_back(j); | ||
| 34 | 84 | spa[j] = a_ik * rhs.data[rq]; | |
| 35 | } else { | ||
| 36 | spa[j] += a_ik * rhs.data[rq]; | ||
| 37 | } | ||
| 38 | } | ||
| 39 | } | ||
| 40 | |||
| 41 | std::ranges::sort(touched_cols); | ||
| 42 | |||
| 43 | stage.cols.clear(); | ||
| 44 | stage.vals.clear(); | ||
| 45 | 57 | stage.cols.reserve(touched_cols.size()); | |
| 46 | 57 | stage.vals.reserve(touched_cols.size()); | |
| 47 | |||
| 48 |
2/2✓ Branch 0 taken 84 times.
✓ Branch 1 taken 57 times.
|
141 | for (const int j : touched_cols) { |
| 49 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 80 times.
|
84 | const Cplx v = spa[j]; |
| 50 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 80 times.
|
84 | spa[j] = Cplx(0.0, 0.0); |
| 51 |
3/4✓ Branch 0 taken 4 times.
✓ Branch 1 taken 80 times.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
84 | if (std::abs(v.real()) > kZeroDropTol || std::abs(v.imag()) > kZeroDropTol) { |
| 52 | stage.cols.push_back(j); | ||
| 53 | stage.vals.push_back(v); | ||
| 54 | } | ||
| 55 | } | ||
| 56 | 57 | } | |
| 57 | |||
// Computes the half-open row interval [begin, end) assigned to `rank` when
// `total_rows` rows are split into contiguous blocks across `world_size` ranks.
// Every rank gets total_rows / world_size rows and the first
// total_rows % world_size ranks get one extra row.
//
// total_rows - number of rows to distribute.
// world_size - number of ranks sharing the rows.
// rank       - rank whose interval is requested.
// begin, end - outputs; receive the interval bounds.
//
// A non-positive `world_size` or `total_rows` yields the empty interval [0, 0):
// this avoids a division by zero and never reports an interval with end < begin.
void RowRange(int total_rows, int world_size, int rank, int &begin, int &end) {
  if (world_size <= 0 || total_rows <= 0) {
    begin = 0;
    end = 0;
    return;
  }
  const int base = total_rows / world_size;
  const int extra = total_rows % world_size;
  // Ranks below `extra` each hold one additional row, which shifts later starts.
  begin = (rank * base) + std::min(rank, extra);
  end = begin + base + (rank < extra ? 1 : 0);
}
| 64 | |||
| 65 | // Helpers to reduce cognitive complexity of RunImpl | ||
| 66 | 20 | void FillRowsPerProc(int lhs_n_rows, int world_size, int rank, std::vector<int> &rows_per_proc, | |
| 67 | std::vector<int> &rows_displs) { | ||
| 68 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 69 | 10 | rows_per_proc.assign(static_cast<std::size_t>(world_size), 0); | |
| 70 | 10 | rows_displs.assign(static_cast<std::size_t>(world_size), 0); | |
| 71 |
2/2✓ Branch 0 taken 20 times.
✓ Branch 1 taken 10 times.
|
30 | for (int proc = 0; proc < world_size; ++proc) { |
| 72 | int b = 0; | ||
| 73 | int e = 0; | ||
| 74 | RowRange(lhs_n_rows, world_size, proc, b, e); | ||
| 75 | 20 | rows_per_proc[proc] = e - b; | |
| 76 | 20 | rows_displs[proc] = b; | |
| 77 | } | ||
| 78 | } | ||
| 79 | 20 | } | |
| 80 | |||
| 81 | 20 | int GatherPayloadCountsAndDispls(int local_payload, int world_size, int rank, std::vector<int> &payload_counts, | |
| 82 | std::vector<int> &payload_displs) { | ||
| 83 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 84 | 10 | payload_counts.assign(static_cast<std::size_t>(world_size), 0); | |
| 85 | 10 | payload_displs.assign(static_cast<std::size_t>(world_size), 0); | |
| 86 | } | ||
| 87 | 20 | MPI_Gather(&local_payload, 1, MPI_INT, rank == 0 ? payload_counts.data() : nullptr, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 88 | |||
| 89 | int total_payload = 0; | ||
| 90 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 91 |
2/2✓ Branch 0 taken 20 times.
✓ Branch 1 taken 10 times.
|
30 | for (int proc = 0; proc < world_size; ++proc) { |
| 92 | 20 | payload_displs[proc] = total_payload; | |
| 93 | 20 | total_payload += payload_counts[proc]; | |
| 94 | } | ||
| 95 | } | ||
| 96 | 20 | return total_payload; | |
| 97 | } | ||
| 98 | |||
// On rank 0, fills `payload_counts_d` / `payload_displs_d` with the first
// `world_size` entries of `payload_counts` / `payload_displs`, each multiplied
// by two. On other ranks the output vectors are left untouched.
void BuildPayloadCountsD(const std::vector<int> &payload_counts, const std::vector<int> &payload_displs, int world_size,
                         int rank, std::vector<int> &payload_counts_d, std::vector<int> &payload_displs_d) {
  if (rank != 0) {
    return;
  }

  const auto n_procs = static_cast<std::size_t>(world_size);
  payload_counts_d.assign(n_procs, 0);
  payload_displs_d.assign(n_procs, 0);

  for (std::size_t idx = 0; idx < n_procs; ++idx) {
    payload_counts_d[idx] = payload_counts[idx] * 2;
    payload_displs_d[idx] = payload_displs[idx] * 2;
  }
}
| 110 | |||
| 111 | } // namespace | ||
| 112 | |||
| 113 | 40 | void KlimovichVCrsComplexMatMulAll::BroadcastOperand(CrsMatrix &m, int root) { | |
| 114 | 40 | int rank = 0; | |
| 115 | 40 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 116 | |||
| 117 | 40 | std::array<int, 3> meta{0, 0, 0}; | |
| 118 |
2/2✓ Branch 0 taken 20 times.
✓ Branch 1 taken 20 times.
|
40 | if (rank == root) { |
| 119 | 20 | meta[0] = m.n_rows; | |
| 120 | 20 | meta[1] = m.n_cols; | |
| 121 | 20 | meta[2] = static_cast<int>(m.data.size()); | |
| 122 | } | ||
| 123 | 40 | MPI_Bcast(meta.data(), 3, MPI_INT, root, MPI_COMM_WORLD); | |
| 124 | |||
| 125 |
2/2✓ Branch 0 taken 20 times.
✓ Branch 1 taken 20 times.
|
40 | if (rank != root) { |
| 126 | 20 | m.n_rows = meta[0]; | |
| 127 | 20 | m.n_cols = meta[1]; | |
| 128 | 20 | m.row_offsets.assign(static_cast<std::size_t>(meta[0]) + 1, 0); | |
| 129 | 20 | m.col_indices.assign(static_cast<std::size_t>(meta[2]), 0); | |
| 130 | 20 | m.data.assign(static_cast<std::size_t>(meta[2]), Cplx(0.0, 0.0)); | |
| 131 | } | ||
| 132 | |||
| 133 |
1/2✓ Branch 0 taken 40 times.
✗ Branch 1 not taken.
|
40 | if (meta[0] > 0) { |
| 134 | 40 | MPI_Bcast(m.row_offsets.data(), meta[0] + 1, MPI_INT, root, MPI_COMM_WORLD); | |
| 135 | } | ||
| 136 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 4 times.
|
40 | if (meta[2] > 0) { |
| 137 | 36 | MPI_Bcast(m.col_indices.data(), meta[2], MPI_INT, root, MPI_COMM_WORLD); | |
| 138 | 36 | MPI_Bcast(m.data.data(), meta[2] * 2, MPI_DOUBLE, root, MPI_COMM_WORLD); | |
| 139 | } | ||
| 140 | 40 | } | |
| 141 | |||
| 142 | 20 | void KlimovichVCrsComplexMatMulAll::ComputeLocalRows(const CrsMatrix &lhs, const CrsMatrix &rhs, int row_begin, | |
| 143 | int row_end, std::vector<int> &local_nnz_per_row, | ||
| 144 | std::vector<int> &local_cols, std::vector<Cplx> &local_vals) { | ||
| 145 | 20 | const int local_rows = row_end - row_begin; | |
| 146 | 20 | std::vector<RowStage> stages(static_cast<std::size_t>(local_rows)); | |
| 147 | |||
| 148 | 20 | #pragma omp parallel default(none) shared(lhs, rhs, stages, row_begin, row_end) | |
| 149 | { | ||
| 150 | std::vector<Cplx> spa(static_cast<std::size_t>(rhs.n_cols)); | ||
| 151 | std::vector<int> touched_by_row(static_cast<std::size_t>(rhs.n_cols), -1); | ||
| 152 | std::vector<int> touched_cols; | ||
| 153 | touched_cols.reserve(static_cast<std::size_t>(rhs.n_cols)); | ||
| 154 | |||
| 155 | #pragma omp for schedule(dynamic, 16) | ||
| 156 | for (int i = row_begin; i < row_end; ++i) { | ||
| 157 | GustavsonRow(lhs, rhs, i, spa, touched_by_row, touched_cols, stages[i - row_begin]); | ||
| 158 | } | ||
| 159 | } | ||
| 160 | |||
| 161 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | local_nnz_per_row.assign(static_cast<std::size_t>(local_rows), 0); |
| 162 | std::size_t total = 0; | ||
| 163 |
2/2✓ Branch 0 taken 57 times.
✓ Branch 1 taken 20 times.
|
77 | for (int i = 0; i < local_rows; ++i) { |
| 164 | 57 | local_nnz_per_row[i] = static_cast<int>(stages[i].cols.size()); | |
| 165 | 57 | total += stages[i].cols.size(); | |
| 166 | } | ||
| 167 | |||
| 168 | local_cols.clear(); | ||
| 169 | local_vals.clear(); | ||
| 170 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | local_cols.reserve(total); |
| 171 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | local_vals.reserve(total); |
| 172 |
2/2✓ Branch 0 taken 57 times.
✓ Branch 1 taken 20 times.
|
77 | for (auto &stage : stages) { |
| 173 |
1/2✓ Branch 1 taken 57 times.
✗ Branch 2 not taken.
|
57 | local_cols.insert(local_cols.end(), stage.cols.begin(), stage.cols.end()); |
| 174 | 57 | local_vals.insert(local_vals.end(), stage.vals.begin(), stage.vals.end()); | |
| 175 | } | ||
| 176 | 20 | } | |
| 177 | |||
| 178 |
1/2✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
|
20 | KlimovichVCrsComplexMatMulAll::KlimovichVCrsComplexMatMulAll(const InType &in) { |
| 179 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 180 | 20 | int rank = 0; | |
| 181 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
| 182 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 183 | GetInput() = in; | ||
| 184 | } | ||
| 185 | 20 | GetOutput() = CrsMatrix(); | |
| 186 | 20 | } | |
| 187 | |||
| 188 | 20 | bool KlimovichVCrsComplexMatMulAll::ValidationImpl() { | |
| 189 | 20 | int rank = 0; | |
| 190 | 20 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 191 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank != 0) { |
| 192 | return true; | ||
| 193 | } | ||
| 194 | const auto &lhs = std::get<0>(GetInput()); | ||
| 195 | const auto &rhs = std::get<1>(GetInput()); | ||
| 196 | 10 | return lhs.n_cols == rhs.n_rows; | |
| 197 | } | ||
| 198 | |||
| 199 | 20 | bool KlimovichVCrsComplexMatMulAll::PreProcessingImpl() { | |
| 200 | 20 | return true; | |
| 201 | } | ||
| 202 | |||
| 203 | 20 | bool KlimovichVCrsComplexMatMulAll::RunImpl() { | |
| 204 | 20 | int rank = 0; | |
| 205 | 20 | int world_size = 1; | |
| 206 | 20 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 207 | 20 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); | |
| 208 | |||
| 209 | 20 | CrsMatrix lhs; | |
| 210 | 20 | CrsMatrix rhs; | |
| 211 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 212 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | lhs = std::get<0>(GetInput()); |
| 213 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | rhs = std::get<1>(GetInput()); |
| 214 | } | ||
| 215 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | BroadcastOperand(lhs, 0); |
| 216 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | BroadcastOperand(rhs, 0); |
| 217 | |||
| 218 | int row_begin = 0; | ||
| 219 | int row_end = 0; | ||
| 220 |
2/2✓ Branch 0 taken 15 times.
✓ Branch 1 taken 5 times.
|
20 | RowRange(lhs.n_rows, world_size, rank, row_begin, row_end); |
| 221 | |||
| 222 | 20 | std::vector<int> local_nnz_per_row; | |
| 223 | 20 | std::vector<int> local_cols; | |
| 224 | 20 | std::vector<Cplx> local_vals; | |
| 225 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | ComputeLocalRows(lhs, rhs, row_begin, row_end, local_nnz_per_row, local_cols, local_vals); |
| 226 | |||
| 227 | 20 | std::vector<int> rows_per_proc; | |
| 228 | 20 | std::vector<int> rows_displs; | |
| 229 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | FillRowsPerProc(lhs.n_rows, world_size, rank, rows_per_proc, rows_displs); |
| 230 | |||
| 231 | 20 | std::vector<int> global_nnz_per_row; | |
| 232 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 233 |
1/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
10 | global_nnz_per_row.assign(static_cast<std::size_t>(lhs.n_rows), 0); |
| 234 | } | ||
| 235 |
5/6✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 10 times.
✓ Branch 3 taken 10 times.
✓ Branch 5 taken 20 times.
✗ Branch 6 not taken.
|
40 | MPI_Gatherv(local_nnz_per_row.data(), static_cast<int>(local_nnz_per_row.size()), MPI_INT, |
| 236 | rank == 0 ? global_nnz_per_row.data() : nullptr, rank == 0 ? rows_per_proc.data() : nullptr, | ||
| 237 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | rank == 0 ? rows_displs.data() : nullptr, MPI_INT, 0, MPI_COMM_WORLD); |
| 238 | |||
| 239 | 20 | const int local_payload = static_cast<int>(local_cols.size()); | |
| 240 | 20 | std::vector<int> payload_counts; | |
| 241 | 20 | std::vector<int> payload_displs; | |
| 242 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | int total_payload = GatherPayloadCountsAndDispls(local_payload, world_size, rank, payload_counts, payload_displs); |
| 243 | |||
| 244 | 20 | std::vector<int> all_cols; | |
| 245 | 20 | std::vector<Cplx> all_vals; | |
| 246 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 247 |
2/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
|
10 | all_cols.assign(static_cast<std::size_t>(total_payload), 0); |
| 248 |
0/2✗ Branch 0 not taken.
✗ Branch 1 not taken.
|
10 | all_vals.assign(static_cast<std::size_t>(total_payload), Cplx(0.0, 0.0)); |
| 249 | } | ||
| 250 | |||
| 251 |
5/6✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 10 times.
✓ Branch 3 taken 10 times.
✓ Branch 5 taken 20 times.
✗ Branch 6 not taken.
|
40 | MPI_Gatherv(local_cols.data(), local_payload, MPI_INT, rank == 0 ? all_cols.data() : nullptr, |
| 252 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | rank == 0 ? payload_counts.data() : nullptr, rank == 0 ? payload_displs.data() : nullptr, MPI_INT, 0, |
| 253 | MPI_COMM_WORLD); | ||
| 254 | |||
| 255 | 20 | std::vector<int> payload_counts_d; | |
| 256 | 20 | std::vector<int> payload_displs_d; | |
| 257 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | BuildPayloadCountsD(payload_counts, payload_displs, world_size, rank, payload_counts_d, payload_displs_d); |
| 258 |
5/6✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 10 times.
✓ Branch 3 taken 10 times.
✓ Branch 5 taken 20 times.
✗ Branch 6 not taken.
|
40 | MPI_Gatherv(local_vals.data(), local_payload * 2, MPI_DOUBLE, rank == 0 ? all_vals.data() : nullptr, |
| 259 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | rank == 0 ? payload_counts_d.data() : nullptr, rank == 0 ? payload_displs_d.data() : nullptr, MPI_DOUBLE, |
| 260 | 0, MPI_COMM_WORLD); | ||
| 261 | |||
| 262 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank == 0) { |
| 263 | CrsMatrix &out = GetOutput(); | ||
| 264 |
1/4✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
|
10 | out = CrsMatrix(lhs.n_rows, rhs.n_cols); |
| 265 |
2/2✓ Branch 0 taken 57 times.
✓ Branch 1 taken 10 times.
|
67 | for (int i = 0; i < lhs.n_rows; ++i) { |
| 266 | 57 | out.row_offsets[i + 1] = out.row_offsets[i] + global_nnz_per_row[i]; | |
| 267 | } | ||
| 268 | 10 | out.col_indices = std::move(all_cols); | |
| 269 | 10 | out.data = std::move(all_vals); | |
| 270 | } | ||
| 271 | |||
| 272 | 20 | return true; | |
| 273 | 20 | } | |
| 274 | |||
| 275 | 20 | bool KlimovichVCrsComplexMatMulAll::PostProcessingImpl() { | |
| 276 | 20 | return true; | |
| 277 | } | ||
| 278 | |||
| 279 | } // namespace klimovich_v_crs_complex_mat_mul | ||
| 280 |