| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "ermakov_a_spar_mat_mult/all/include/ops_all.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <array> | ||
| 7 | #include <complex> | ||
| 8 | #include <cstddef> | ||
| 9 | #include <cstdint> | ||
| 10 | #include <numeric> | ||
| 11 | #include <thread> | ||
| 12 | #include <utility> | ||
| 13 | #include <vector> | ||
| 14 | |||
| 15 | #include "ermakov_a_spar_mat_mult/common/include/common.hpp" | ||
| 16 | #include "task/include/task.hpp" | ||
| 17 | #include "util/include/util.hpp" | ||
| 18 | |||
| 19 | namespace ermakov_a_spar_mat_mult { | ||
| 20 | |||
| 21 | namespace { | ||
| 22 | |||
/// Holds one finished row of the local product: the surviving column indices and
/// their values, stored index-for-index in matching order.
struct LocalRowData {
  std::vector<int> cols{};                   // column indices of stored entries
  std::vector<std::complex<double>> vals{};  // vals[i] is the entry at column cols[i]
};
| 27 | |||
| 28 | 4 | std::vector<int> BuildRowBounds(const MatrixCRS &a, const MatrixCRS &b, int proc_count) { | |
| 29 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
|
4 | if (proc_count <= 0) { |
| 30 | ✗ | return {}; | |
| 31 | } | ||
| 32 | |||
| 33 | 4 | std::vector<int> bounds(static_cast<std::size_t>(proc_count) + 1ULL, 0); | |
| 34 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | bounds.back() = a.rows; |
| 35 | |||
| 36 |
1/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
4 | std::vector<int> row_costs(static_cast<std::size_t>(a.rows), 0); |
| 37 | std::int64_t total_cost = 0; | ||
| 38 |
2/2✓ Branch 0 taken 63 times.
✓ Branch 1 taken 4 times.
|
67 | for (int row = 0; row < a.rows; ++row) { |
| 39 | int row_cost = 0; | ||
| 40 |
2/2✓ Branch 0 taken 704 times.
✓ Branch 1 taken 63 times.
|
767 | for (int ak = a.row_ptr[static_cast<std::size_t>(row)]; ak < a.row_ptr[static_cast<std::size_t>(row) + 1ULL]; |
| 41 | ++ak) { | ||
| 42 | 704 | const int b_row = a.col_index[static_cast<std::size_t>(ak)]; | |
| 43 | 704 | row_cost += b.row_ptr[static_cast<std::size_t>(b_row) + 1ULL] - b.row_ptr[static_cast<std::size_t>(b_row)]; | |
| 44 | } | ||
| 45 | 63 | row_costs[static_cast<std::size_t>(row)] = row_cost; | |
| 46 | 63 | total_cost += row_cost; | |
| 47 | } | ||
| 48 | |||
| 49 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | if (proc_count <= 1 || total_cost == 0) { |
| 50 | ✗ | for (int proc = 0; proc <= proc_count; ++proc) { | |
| 51 | ✗ | bounds[static_cast<std::size_t>(proc)] = (proc * a.rows) / proc_count; | |
| 52 | } | ||
| 53 | return bounds; | ||
| 54 | } | ||
| 55 | |||
| 56 | int next_proc = 1; | ||
| 57 | std::int64_t prefix_cost = 0; | ||
| 58 |
3/4✓ Branch 0 taken 33 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 29 times.
✓ Branch 3 taken 4 times.
|
33 | for (int row = 0; row < a.rows && next_proc < proc_count; ++row) { |
| 59 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 25 times.
|
29 | prefix_cost += row_costs[static_cast<std::size_t>(row)]; |
| 60 | 29 | const std::int64_t target_cost = (static_cast<std::int64_t>(next_proc) * total_cost) / proc_count; | |
| 61 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 25 times.
|
29 | if (prefix_cost >= target_cost) { |
| 62 | 4 | bounds[static_cast<std::size_t>(next_proc)] = row + 1; | |
| 63 | 4 | ++next_proc; | |
| 64 | } | ||
| 65 | } | ||
| 66 | |||
| 67 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
|
4 | while (next_proc < proc_count) { |
| 68 | ✗ | bounds[static_cast<std::size_t>(next_proc)] = a.rows; | |
| 69 | ✗ | ++next_proc; | |
| 70 | } | ||
| 71 | |||
| 72 | return bounds; | ||
| 73 | } | ||
| 74 | |||
/// Converts range boundaries into range sizes: counts[p] = bounds[p + 1] - bounds[p].
///
/// @param bounds  boundary list, e.g. as produced by BuildRowBounds.
/// @return bounds.size() - 1 counts. Returns empty when fewer than two boundaries are
///         given; this also guards the size_t underflow of `bounds.size() - 1` on
///         empty input.
std::vector<int> BuildCountsFromBounds(const std::vector<int> &bounds) {
  if (bounds.size() < 2) {
    return {};
  }

  std::vector<int> counts(bounds.size() - 1, 0);
  for (std::size_t proc = 0; proc < counts.size(); ++proc) {
    counts[proc] = bounds[proc + 1] - bounds[proc];
  }
  return counts;
}
| 82 | |||
/// Exclusive prefix sum of `counts`: displs[p] is the offset where range p begins in
/// a packed buffer. This is the displacement array MPI_Scatterv/MPI_Gatherv take.
std::vector<int> BuildDisplacements(const std::vector<int> &counts) {
  std::vector<int> displs(counts.size(), 0);
  std::exclusive_scan(counts.begin(), counts.end(), displs.begin(), 0);
  return displs;
}
| 90 | |||
| 91 | 4 | std::vector<int> BuildNNZCounts(const MatrixCRS &matrix, const std::vector<int> &bounds) { | |
| 92 | 4 | std::vector<int> counts(bounds.size() - 1ULL, 0); | |
| 93 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
|
12 | for (std::size_t proc = 0; proc + 1 < bounds.size(); ++proc) { |
| 94 | 8 | counts[proc] = matrix.row_ptr[static_cast<std::size_t>(bounds[proc + 1])] - | |
| 95 | 8 | matrix.row_ptr[static_cast<std::size_t>(bounds[proc])]; | |
| 96 | } | ||
| 97 | 4 | return counts; | |
| 98 | } | ||
| 99 | |||
| 100 | 48 | MPI_Datatype GetComplexBytesType() { | |
| 101 | static MPI_Datatype datatype = MPI_DATATYPE_NULL; | ||
| 102 | static bool initialized = false; | ||
| 103 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 46 times.
|
48 | if (!initialized) { |
| 104 | 2 | MPI_Type_contiguous(static_cast<int>(sizeof(std::complex<double>)), MPI_BYTE, &datatype); | |
| 105 | 2 | MPI_Type_commit(&datatype); | |
| 106 | 2 | initialized = true; | |
| 107 | } | ||
| 108 | 48 | return datatype; | |
| 109 | } | ||
| 110 | |||
| 111 | 8 | MatrixCRS ScatterRows(const MatrixCRS &matrix, const std::vector<int> &row_bounds, const std::vector<int> &nnz_counts, | |
| 112 | int rank) { | ||
| 113 | 8 | const std::vector<int> row_counts = BuildCountsFromBounds(row_bounds); | |
| 114 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const std::vector<int> row_displs = BuildDisplacements(row_counts); |
| 115 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const std::vector<int> nnz_displs = BuildDisplacements(nnz_counts); |
| 116 | |||
| 117 | 8 | MatrixCRS local; | |
| 118 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | local.rows = row_counts[static_cast<std::size_t>(rank)]; |
| 119 | 8 | local.cols = matrix.cols; | |
| 120 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | local.row_ptr.assign(static_cast<std::size_t>(local.rows) + 1ULL, 0); |
| 121 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | local.col_index.resize(static_cast<std::size_t>(nnz_counts[static_cast<std::size_t>(rank)])); |
| 122 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | local.values.resize(static_cast<std::size_t>(nnz_counts[static_cast<std::size_t>(rank)])); |
| 123 | |||
| 124 | 8 | std::vector<int> all_row_lengths; | |
| 125 | |||
| 126 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
|
8 | if (rank == 0) { |
| 127 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | all_row_lengths.resize(static_cast<std::size_t>(matrix.rows), 0); |
| 128 |
2/2✓ Branch 0 taken 63 times.
✓ Branch 1 taken 4 times.
|
67 | for (int row = 0; row < matrix.rows; ++row) { |
| 129 | 63 | all_row_lengths[static_cast<std::size_t>(row)] = | |
| 130 | 63 | matrix.row_ptr[static_cast<std::size_t>(row) + 1ULL] - matrix.row_ptr[static_cast<std::size_t>(row)]; | |
| 131 | } | ||
| 132 | } | ||
| 133 | |||
| 134 |
1/4✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
8 | std::vector<int> local_row_lengths(static_cast<std::size_t>(local.rows), 0); |
| 135 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Scatterv(all_row_lengths.data(), row_counts.data(), row_displs.data(), MPI_INT, local_row_lengths.data(), |
| 136 | local.rows, MPI_INT, 0, MPI_COMM_WORLD); | ||
| 137 | |||
| 138 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const int local_nnz = nnz_counts[static_cast<std::size_t>(rank)]; |
| 139 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Scatterv(matrix.col_index.data(), nnz_counts.data(), nnz_displs.data(), MPI_INT, local.col_index.data(), |
| 140 | local_nnz, MPI_INT, 0, MPI_COMM_WORLD); | ||
| 141 | |||
| 142 |
3/6✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
|
8 | MPI_Scatterv(matrix.values.data(), nnz_counts.data(), nnz_displs.data(), GetComplexBytesType(), local.values.data(), |
| 143 | local_nnz, GetComplexBytesType(), 0, MPI_COMM_WORLD); | ||
| 144 | |||
| 145 | int prefix = 0; | ||
| 146 |
2/2✓ Branch 0 taken 63 times.
✓ Branch 1 taken 8 times.
|
71 | for (int row = 0; row < local.rows; ++row) { |
| 147 | 63 | local.row_ptr[static_cast<std::size_t>(row)] = prefix; | |
| 148 | 63 | prefix += local_row_lengths[static_cast<std::size_t>(row)]; | |
| 149 | } | ||
| 150 |
1/2✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
|
8 | local.row_ptr[static_cast<std::size_t>(local.rows)] = prefix; |
| 151 | |||
| 152 | 8 | return local; | |
| 153 | ✗ | } | |
| 154 | |||
| 155 | 16 | void BroadcastMatrix(MatrixCRS &matrix, int rank) { | |
| 156 | 16 | std::array<int, 3> dims = {matrix.rows, matrix.cols, static_cast<int>(matrix.values.size())}; | |
| 157 | 16 | MPI_Bcast(dims.data(), static_cast<int>(dims.size()), MPI_INT, 0, MPI_COMM_WORLD); | |
| 158 | |||
| 159 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
|
16 | if (rank != 0) { |
| 160 | 8 | matrix.rows = dims[0]; | |
| 161 | 8 | matrix.cols = dims[1]; | |
| 162 | 8 | matrix.values.resize(static_cast<std::size_t>(dims[2])); | |
| 163 | 8 | matrix.col_index.resize(static_cast<std::size_t>(dims[2])); | |
| 164 | 8 | matrix.row_ptr.resize(static_cast<std::size_t>(matrix.rows) + 1ULL); | |
| 165 | } | ||
| 166 | |||
| 167 | 16 | MPI_Bcast(matrix.col_index.data(), dims[2], MPI_INT, 0, MPI_COMM_WORLD); | |
| 168 | 16 | MPI_Bcast(matrix.row_ptr.data(), matrix.rows + 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 169 | 16 | MPI_Bcast(matrix.values.data(), dims[2], GetComplexBytesType(), 0, MPI_COMM_WORLD); | |
| 170 | 16 | } | |
| 171 | |||
| 172 |
2/2✓ Branch 0 taken 42 times.
✓ Branch 1 taken 21 times.
|
63 | void AccumulateRowProducts(const MatrixCRS &a, const MatrixCRS &b, int row_index, |
| 173 | std::vector<std::complex<double>> &row_vals, std::vector<int> &row_mark, | ||
| 174 | std::vector<int> &used_cols) { | ||
| 175 | used_cols.clear(); | ||
| 176 | |||
| 177 | 767 | for (int ak = a.row_ptr[static_cast<std::size_t>(row_index)]; | |
| 178 |
2/2✓ Branch 0 taken 704 times.
✓ Branch 1 taken 63 times.
|
767 | ak < a.row_ptr[static_cast<std::size_t>(row_index) + 1ULL]; ++ak) { |
| 179 | 704 | const int b_row = a.col_index[static_cast<std::size_t>(ak)]; | |
| 180 | 704 | const auto a_val = a.values[static_cast<std::size_t>(ak)]; | |
| 181 | |||
| 182 |
2/2✓ Branch 0 taken 13539 times.
✓ Branch 1 taken 704 times.
|
14243 | for (int bk = b.row_ptr[static_cast<std::size_t>(b_row)]; bk < b.row_ptr[static_cast<std::size_t>(b_row) + 1ULL]; |
| 183 | ++bk) { | ||
| 184 |
2/2✓ Branch 0 taken 1080 times.
✓ Branch 1 taken 12459 times.
|
13539 | const int col = b.col_index[static_cast<std::size_t>(bk)]; |
| 185 | 13539 | const auto product = a_val * b.values[static_cast<std::size_t>(bk)]; | |
| 186 | |||
| 187 |
2/2✓ Branch 0 taken 1080 times.
✓ Branch 1 taken 12459 times.
|
13539 | if (row_mark[static_cast<std::size_t>(col)] != row_index) { |
| 188 |
1/2✓ Branch 0 taken 1080 times.
✗ Branch 1 not taken.
|
1080 | row_mark[static_cast<std::size_t>(col)] = row_index; |
| 189 |
1/2✓ Branch 0 taken 1080 times.
✗ Branch 1 not taken.
|
1080 | row_vals[static_cast<std::size_t>(col)] = product; |
| 190 | used_cols.push_back(col); | ||
| 191 | } else { | ||
| 192 | row_vals[static_cast<std::size_t>(col)] += product; | ||
| 193 | } | ||
| 194 | } | ||
| 195 | } | ||
| 196 | 63 | } | |
| 197 | |||
| 198 | 63 | void CollectRowValues(const std::vector<std::complex<double>> &row_vals, std::vector<int> &used_cols, | |
| 199 | LocalRowData &row) { | ||
| 200 | std::ranges::sort(used_cols); | ||
| 201 | row.cols.clear(); | ||
| 202 | row.vals.clear(); | ||
| 203 | 63 | row.cols.reserve(used_cols.size()); | |
| 204 | 63 | row.vals.reserve(used_cols.size()); | |
| 205 | |||
| 206 |
2/2✓ Branch 0 taken 1080 times.
✓ Branch 1 taken 63 times.
|
1143 | for (int col : used_cols) { |
| 207 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1080 times.
|
1080 | const auto &value = row_vals[static_cast<std::size_t>(col)]; |
| 208 | if (value != std::complex<double>(0.0, 0.0)) { | ||
| 209 | row.cols.push_back(col); | ||
| 210 | row.vals.push_back(value); | ||
| 211 | } | ||
| 212 | } | ||
| 213 | 63 | } | |
| 214 | |||
| 215 | 8 | MatrixCRS MultiplyLocalOMP(const MatrixCRS &a, const MatrixCRS &b) { | |
| 216 | 8 | MatrixCRS result; | |
| 217 | 8 | result.rows = a.rows; | |
| 218 | 8 | result.cols = b.cols; | |
| 219 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | result.row_ptr.assign(static_cast<std::size_t>(result.rows) + 1ULL, 0); |
| 220 | |||
| 221 |
2/4✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | if (a.rows == 0 || b.cols == 0) { |
| 222 | return result; | ||
| 223 | } | ||
| 224 | |||
| 225 | 8 | int rank_count = 1; | |
| 226 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Comm_size(MPI_COMM_WORLD, &rank_count); |
| 227 | |||
| 228 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | int thread_count = ppc::util::GetNumThreads(); |
| 229 | 8 | const unsigned hw_threads = std::thread::hardware_concurrency(); | |
| 230 |
2/4✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | if (hw_threads > 0U && rank_count > 1) { |
| 231 | 8 | const unsigned per_rank_cap = std::max(1U, hw_threads / static_cast<unsigned>(rank_count)); | |
| 232 | 8 | thread_count = std::min(thread_count, static_cast<int>(per_rank_cap)); | |
| 233 | } | ||
| 234 |
4/4✓ Branch 0 taken 1 times.
✓ Branch 1 taken 7 times.
✓ Branch 2 taken 7 times.
✓ Branch 3 taken 1 times.
|
9 | thread_count = std::max(1, std::min(thread_count, a.rows)); |
| 235 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | std::vector<LocalRowData> rows_data(static_cast<std::size_t>(a.rows)); |
| 236 | |||
| 237 | 8 | #pragma omp parallel default(none) shared(a, b, rows_data) num_threads(thread_count) if (thread_count > 1) | |
| 238 | { | ||
| 239 | std::vector<std::complex<double>> row_vals(static_cast<std::size_t>(b.cols), std::complex<double>(0.0, 0.0)); | ||
| 240 | std::vector<int> row_mark(static_cast<std::size_t>(b.cols), -1); | ||
| 241 | std::vector<int> used_cols; | ||
| 242 | used_cols.reserve(256); | ||
| 243 | |||
| 244 | #pragma omp for | ||
| 245 | for (int row = 0; row < a.rows; ++row) { | ||
| 246 | AccumulateRowProducts(a, b, row, row_vals, row_mark, used_cols); | ||
| 247 | CollectRowValues(row_vals, used_cols, rows_data[static_cast<std::size_t>(row)]); | ||
| 248 | } | ||
| 249 | } | ||
| 250 | |||
| 251 | int total_nnz = 0; | ||
| 252 |
2/2✓ Branch 0 taken 63 times.
✓ Branch 1 taken 8 times.
|
71 | for (int row = 0; row < result.rows; ++row) { |
| 253 | 63 | result.row_ptr[static_cast<std::size_t>(row)] = total_nnz; | |
| 254 | 63 | total_nnz += static_cast<int>(rows_data[static_cast<std::size_t>(row)].vals.size()); | |
| 255 | } | ||
| 256 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | result.row_ptr[static_cast<std::size_t>(result.rows)] = total_nnz; |
| 257 | |||
| 258 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | result.values.reserve(static_cast<std::size_t>(total_nnz)); |
| 259 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | result.col_index.reserve(static_cast<std::size_t>(total_nnz)); |
| 260 | |||
| 261 |
2/2✓ Branch 0 taken 63 times.
✓ Branch 1 taken 8 times.
|
71 | for (int row = 0; row < result.rows; ++row) { |
| 262 |
1/2✓ Branch 1 taken 63 times.
✗ Branch 2 not taken.
|
63 | const auto &row_data = rows_data[static_cast<std::size_t>(row)]; |
| 263 |
1/2✓ Branch 1 taken 63 times.
✗ Branch 2 not taken.
|
63 | result.col_index.insert(result.col_index.end(), row_data.cols.begin(), row_data.cols.end()); |
| 264 | 63 | result.values.insert(result.values.end(), row_data.vals.begin(), row_data.vals.end()); | |
| 265 | } | ||
| 266 | |||
| 267 | return result; | ||
| 268 | 8 | } | |
| 269 | |||
| 270 | 8 | void GatherMatrix(const MatrixCRS &local, MatrixCRS &global, const std::vector<int> &row_bounds, int rank, int size, | |
| 271 | int total_rows) { | ||
| 272 | 8 | const std::vector<int> row_counts = BuildCountsFromBounds(row_bounds); | |
| 273 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const std::vector<int> row_displs = BuildDisplacements(row_counts); |
| 274 |
2/6✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
8 | std::vector<int> nnz_counts(static_cast<std::size_t>(size), 0); |
| 275 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const int local_nnz = static_cast<int>(local.values.size()); |
| 276 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Gather(&local_nnz, 1, MPI_INT, nnz_counts.data(), 1, MPI_INT, 0, MPI_COMM_WORLD); |
| 277 | |||
| 278 |
1/4✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
8 | std::vector<int> local_row_lengths(static_cast<std::size_t>(local.rows), 0); |
| 279 |
2/2✓ Branch 0 taken 63 times.
✓ Branch 1 taken 8 times.
|
71 | for (int row = 0; row < local.rows; ++row) { |
| 280 | 63 | local_row_lengths[static_cast<std::size_t>(row)] = | |
| 281 | 63 | local.row_ptr[static_cast<std::size_t>(row) + 1ULL] - local.row_ptr[static_cast<std::size_t>(row)]; | |
| 282 | } | ||
| 283 | |||
| 284 | 8 | std::vector<int> nnz_displs; | |
| 285 | 8 | std::vector<int> gathered_row_lengths; | |
| 286 | 8 | std::vector<int> gathered_cols; | |
| 287 | 8 | std::vector<std::complex<double>> gathered_values; | |
| 288 | |||
| 289 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
|
8 | if (rank == 0) { |
| 290 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | nnz_displs.resize(static_cast<std::size_t>(size), 0); |
| 291 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
|
8 | for (int proc = 1; proc < size; ++proc) { |
| 292 | 4 | nnz_displs[static_cast<std::size_t>(proc)] = | |
| 293 | 4 | nnz_displs[static_cast<std::size_t>(proc - 1)] + nnz_counts[static_cast<std::size_t>(proc - 1)]; | |
| 294 | } | ||
| 295 | |||
| 296 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | gathered_row_lengths.resize(static_cast<std::size_t>(total_rows), 0); |
| 297 |
2/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
|
8 | gathered_cols.resize(static_cast<std::size_t>(std::accumulate(nnz_counts.begin(), nnz_counts.end(), 0)), 0); |
| 298 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | gathered_values.resize(gathered_cols.size()); |
| 299 | } | ||
| 300 | |||
| 301 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Gatherv(local_row_lengths.data(), local.rows, MPI_INT, gathered_row_lengths.data(), row_counts.data(), |
| 302 | row_displs.data(), MPI_INT, 0, MPI_COMM_WORLD); | ||
| 303 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Gatherv(local.col_index.data(), local_nnz, MPI_INT, gathered_cols.data(), nnz_counts.data(), nnz_displs.data(), |
| 304 | MPI_INT, 0, MPI_COMM_WORLD); | ||
| 305 | |||
| 306 |
3/6✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
|
8 | MPI_Gatherv(local.values.data(), local_nnz, GetComplexBytesType(), gathered_values.data(), nnz_counts.data(), |
| 307 | nnz_displs.data(), GetComplexBytesType(), 0, MPI_COMM_WORLD); | ||
| 308 | |||
| 309 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
|
8 | if (rank != 0) { |
| 310 | return; | ||
| 311 | } | ||
| 312 | |||
| 313 |
1/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
4 | global.row_ptr.assign(static_cast<std::size_t>(total_rows) + 1ULL, 0); |
| 314 | int prefix = 0; | ||
| 315 |
2/2✓ Branch 0 taken 63 times.
✓ Branch 1 taken 4 times.
|
67 | for (int row = 0; row < total_rows; ++row) { |
| 316 | 63 | global.row_ptr[static_cast<std::size_t>(row)] = prefix; | |
| 317 | 63 | prefix += gathered_row_lengths[static_cast<std::size_t>(row)]; | |
| 318 | } | ||
| 319 | 4 | global.row_ptr[static_cast<std::size_t>(total_rows)] = prefix; | |
| 320 | |||
| 321 | 4 | global.col_index = std::move(gathered_cols); | |
| 322 | 4 | global.values = std::move(gathered_values); | |
| 323 | } | ||
| 324 | |||
| 325 | } // namespace | ||
| 326 | |||
| 327 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | ErmakovASparMatMultALL::ErmakovASparMatMultALL(const InType &in) { |
| 328 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 329 | GetInput() = in; | ||
| 330 | 8 | } | |
| 331 | |||
| 332 | 16 | bool ErmakovASparMatMultALL::ValidateMatrix(const MatrixCRS &m) { | |
| 333 |
2/4✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 16 times.
|
16 | if (m.rows < 0 || m.cols < 0) { |
| 334 | return false; | ||
| 335 | } | ||
| 336 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
|
16 | if (m.row_ptr.size() != static_cast<std::size_t>(m.rows) + 1ULL) { |
| 337 | return false; | ||
| 338 | } | ||
| 339 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
|
16 | if (m.values.size() != m.col_index.size()) { |
| 340 | return false; | ||
| 341 | } | ||
| 342 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
|
16 | if (m.row_ptr.empty()) { |
| 343 | return false; | ||
| 344 | } | ||
| 345 | |||
| 346 | 16 | const int nnz = static_cast<int>(m.values.size()); | |
| 347 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 16 times.
|
16 | if (m.row_ptr.front() != 0 || m.row_ptr.back() != nnz) { |
| 348 | return false; | ||
| 349 | } | ||
| 350 | |||
| 351 |
2/2✓ Branch 0 taken 252 times.
✓ Branch 1 taken 16 times.
|
268 | for (int row = 0; row < m.rows; ++row) { |
| 352 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 252 times.
|
252 | if (m.row_ptr[static_cast<std::size_t>(row)] > m.row_ptr[static_cast<std::size_t>(row) + 1ULL]) { |
| 353 | return false; | ||
| 354 | } | ||
| 355 | } | ||
| 356 | |||
| 357 |
2/2✓ Branch 0 taken 2846 times.
✓ Branch 1 taken 16 times.
|
2862 | for (int idx = 0; idx < nnz; ++idx) { |
| 358 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2846 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2846 times.
|
2846 | if (m.col_index[static_cast<std::size_t>(idx)] < 0 || m.col_index[static_cast<std::size_t>(idx)] >= m.cols) { |
| 359 | return false; | ||
| 360 | } | ||
| 361 | } | ||
| 362 | |||
| 363 | return true; | ||
| 364 | } | ||
| 365 | |||
| 366 | 8 | bool ErmakovASparMatMultALL::ValidationImpl() { | |
| 367 | 8 | const auto &a = GetInput().A; | |
| 368 | 8 | const auto &b = GetInput().B; | |
| 369 |
2/4✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | return a.cols == b.rows && ValidateMatrix(a) && ValidateMatrix(b); |
| 370 | } | ||
| 371 | |||
| 372 | 8 | bool ErmakovASparMatMultALL::PreProcessingImpl() { | |
| 373 | 8 | a_ = GetInput().A; | |
| 374 | 8 | b_ = GetInput().B; | |
| 375 | 8 | c_.rows = a_.rows; | |
| 376 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | c_.cols = b_.cols; |
| 377 | c_.values.clear(); | ||
| 378 | c_.col_index.clear(); | ||
| 379 | 8 | c_.row_ptr.assign(static_cast<std::size_t>(c_.rows) + 1ULL, 0); | |
| 380 | 8 | return true; | |
| 381 | } | ||
| 382 | |||
| 383 | 8 | bool ErmakovASparMatMultALL::RunImpl() { | |
| 384 | 8 | int rank = 0; | |
| 385 | 8 | int size = 1; | |
| 386 | 8 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 387 | 8 | MPI_Comm_size(MPI_COMM_WORLD, &size); | |
| 388 | |||
| 389 |
3/4✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
8 | if (rank == 0 && a_.cols != b_.rows) { |
| 390 | return false; | ||
| 391 | } | ||
| 392 | |||
| 393 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if (size == 1) { |
| 394 | ✗ | c_ = MultiplyLocalOMP(a_, b_); | |
| 395 | ✗ | return true; | |
| 396 | } | ||
| 397 | |||
| 398 | 8 | BroadcastMatrix(b_, rank); | |
| 399 | |||
| 400 | 8 | std::vector<int> row_bounds(static_cast<std::size_t>(size) + 1ULL, 0); | |
| 401 |
1/4✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
8 | std::vector<int> nnz_counts(static_cast<std::size_t>(size), 0); |
| 402 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
|
8 | if (rank == 0) { |
| 403 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | row_bounds = BuildRowBounds(a_, b_, size); |
| 404 |
1/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
8 | nnz_counts = BuildNNZCounts(a_, row_bounds); |
| 405 | } | ||
| 406 | |||
| 407 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Bcast(row_bounds.data(), size + 1, MPI_INT, 0, MPI_COMM_WORLD); |
| 408 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Bcast(nnz_counts.data(), size, MPI_INT, 0, MPI_COMM_WORLD); |
| 409 | |||
| 410 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const MatrixCRS local_a = ScatterRows(a_, row_bounds, nnz_counts, rank); |
| 411 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | const MatrixCRS local_c = MultiplyLocalOMP(local_a, b_); |
| 412 | |||
| 413 | 8 | c_.rows = a_.rows; | |
| 414 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | c_.cols = b_.cols; |
| 415 | c_.values.clear(); | ||
| 416 | c_.col_index.clear(); | ||
| 417 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | c_.row_ptr.assign(static_cast<std::size_t>(c_.rows) + 1ULL, 0); |
| 418 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | GatherMatrix(local_c, c_, row_bounds, rank, size, a_.rows); |
| 419 | |||
| 420 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if (GetStateOfTesting() == ppc::task::StateOfTesting::kPerf) { |
| 421 | ✗ | MPI_Barrier(MPI_COMM_WORLD); | |
| 422 | return true; | ||
| 423 | } | ||
| 424 | |||
| 425 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | BroadcastMatrix(c_, rank); |
| 426 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Barrier(MPI_COMM_WORLD); |
| 427 | return true; | ||
| 428 | 8 | } | |
| 429 | |||
| 430 | 8 | bool ErmakovASparMatMultALL::PostProcessingImpl() { | |
| 431 | 8 | GetOutput() = c_; | |
| 432 | 8 | return true; | |
| 433 | } | ||
| 434 | |||
| 435 | } // namespace ermakov_a_spar_mat_mult | ||
| 436 |