| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "pikhotskiy_r_multiplication_of_sparse_matrices/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <cmath> | ||
| 6 | #include <cstddef> | ||
| 7 | #include <utility> | ||
| 8 | #include <vector> | ||
| 9 | |||
| 10 | #include "pikhotskiy_r_multiplication_of_sparse_matrices/common/include/common.hpp" | ||
| 11 | |||
| 12 | namespace pikhotskiy_r_multiplication_of_sparse_matrices { | ||
| 13 | |||
| 14 | namespace { | ||
| 15 | |||
| 16 | 207 | double ComputeRowColProduct(const SparseMatrixCRS &mat_a, const SparseMatrixCRS &mat_bt, int row_a, int row_bt) { | |
| 17 | double sum = 0.0; | ||
| 18 | 207 | int a_idx = mat_a.row_ptr[row_a]; | |
| 19 | 207 | int a_end = mat_a.row_ptr[row_a + 1]; | |
| 20 | 207 | int bt_idx = mat_bt.row_ptr[row_bt]; | |
| 21 | 207 | int bt_end = mat_bt.row_ptr[row_bt + 1]; | |
| 22 | |||
| 23 |
2/2✓ Branch 0 taken 459 times.
✓ Branch 1 taken 207 times.
|
666 | while (a_idx < a_end && bt_idx < bt_end) { |
| 24 |
2/2✓ Branch 0 taken 120 times.
✓ Branch 1 taken 339 times.
|
459 | int a_col = mat_a.col_indices[a_idx]; |
| 25 | 459 | int bt_col = mat_bt.col_indices[bt_idx]; | |
| 26 |
2/2✓ Branch 0 taken 120 times.
✓ Branch 1 taken 339 times.
|
459 | if (a_col == bt_col) { |
| 27 | 120 | sum += mat_a.values[a_idx] * mat_bt.values[bt_idx]; | |
| 28 | 120 | ++a_idx; | |
| 29 | 120 | ++bt_idx; | |
| 30 |
2/2✓ Branch 0 taken 168 times.
✓ Branch 1 taken 171 times.
|
339 | } else if (a_col < bt_col) { |
| 31 | 168 | ++a_idx; | |
| 32 | } else { | ||
| 33 | 171 | ++bt_idx; | |
| 34 | } | ||
| 35 | } | ||
| 36 | 207 | return sum; | |
| 37 | } | ||
| 38 | |||
| 39 | 12 | void ComputeDisplacements(int size, const std::vector<int> &all_nnz, const std::vector<int> &all_num_rows, | |
| 40 | std::vector<int> &nnz_displs, std::vector<int> &row_displs, int &total_nnz) { | ||
| 41 | 12 | nnz_displs[0] = 0; | |
| 42 | 12 | row_displs[0] = 0; | |
| 43 | 12 | total_nnz = 0; | |
| 44 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 12 times.
|
36 | for (int i = 0; i < size; ++i) { |
| 45 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (i > 0) { |
| 46 | 12 | nnz_displs[i] = nnz_displs[i - 1] + all_nnz[i - 1]; | |
| 47 | 12 | row_displs[i] = row_displs[i - 1] + all_num_rows[i - 1]; | |
| 48 | } | ||
| 49 | 24 | total_nnz += all_nnz[i]; | |
| 50 | } | ||
| 51 | 12 | } | |
| 52 | |||
| 53 | 12 | void BuildResultRowPtr(SparseMatrixCRS &result, int size, const std::vector<int> &all_nnz, | |
| 54 | const std::vector<int> &all_num_rows, const std::vector<int> &all_row_ptr_shifted) { | ||
| 55 | 12 | result.row_ptr[0] = 0; | |
| 56 | int current_offset = 0; | ||
| 57 | int row_idx = 0; | ||
| 58 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 12 times.
|
36 | for (int proc = 0; proc < size; ++proc) { |
| 59 |
2/2✓ Branch 0 taken 43 times.
✓ Branch 1 taken 24 times.
|
67 | for (int ii = 0; ii < all_num_rows[proc]; ++ii) { |
| 60 | 43 | result.row_ptr[row_idx + 1] = all_row_ptr_shifted[row_idx] + current_offset; | |
| 61 | ++row_idx; | ||
| 62 | } | ||
| 63 | 24 | current_offset += all_nnz[proc]; | |
| 64 | } | ||
| 65 | 12 | } | |
| 66 | |||
| 67 | } // namespace | ||
| 68 | |||
| 69 |
1/2✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
|
24 | SparseMatrixMultiplicationMPI::SparseMatrixMultiplicationMPI(const InType &in) { |
| 70 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 71 | 24 | int rank = 0; | |
| 72 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
| 73 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 74 | GetInput() = in; | ||
| 75 | } | ||
| 76 | 24 | } | |
| 77 | |||
| 78 | 72 | void SparseMatrixMultiplicationMPI::BroadcastSparseMatrix(SparseMatrixCRS &matrix, int root) { | |
| 79 | 72 | MPI_Bcast(&matrix.rows, 1, MPI_INT, root, MPI_COMM_WORLD); | |
| 80 | 72 | MPI_Bcast(&matrix.cols, 1, MPI_INT, root, MPI_COMM_WORLD); | |
| 81 | |||
| 82 | 72 | int nnz = static_cast<int>(matrix.values.size()); | |
| 83 | 72 | MPI_Bcast(&nnz, 1, MPI_INT, root, MPI_COMM_WORLD); | |
| 84 | |||
| 85 | 72 | int rank = 0; | |
| 86 | 72 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 87 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 36 times.
|
72 | if (rank != root) { |
| 88 | 36 | matrix.values.resize(nnz); | |
| 89 | 36 | matrix.col_indices.resize(nnz); | |
| 90 | 36 | matrix.row_ptr.resize(matrix.rows + 1); | |
| 91 | } | ||
| 92 | |||
| 93 |
2/2✓ Branch 0 taken 68 times.
✓ Branch 1 taken 4 times.
|
72 | if (nnz > 0) { |
| 94 | 68 | MPI_Bcast(matrix.values.data(), nnz, MPI_DOUBLE, root, MPI_COMM_WORLD); | |
| 95 | 68 | MPI_Bcast(matrix.col_indices.data(), nnz, MPI_INT, root, MPI_COMM_WORLD); | |
| 96 | } | ||
| 97 | 72 | MPI_Bcast(matrix.row_ptr.data(), matrix.rows + 1, MPI_INT, root, MPI_COMM_WORLD); | |
| 98 | 72 | } | |
| 99 | |||
| 100 | 24 | bool SparseMatrixMultiplicationMPI::ValidationImpl() { | |
| 101 | 24 | int rank = 0; | |
| 102 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 103 | 24 | int error_flag = 0; | |
| 104 | |||
| 105 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 106 | const auto &mat_a = std::get<0>(GetInput()); | ||
| 107 | const auto &mat_b = std::get<1>(GetInput()); | ||
| 108 | |||
| 109 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if (mat_a.cols != mat_b.rows) { |
| 110 | ✗ | error_flag = 1; | |
| 111 | } | ||
| 112 |
4/8✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
|
12 | if (mat_a.rows <= 0 || mat_a.cols <= 0 || mat_b.rows <= 0 || mat_b.cols <= 0) { |
| 113 | ✗ | error_flag = 1; | |
| 114 | } | ||
| 115 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if (mat_a.row_ptr.size() != static_cast<std::size_t>(mat_a.rows) + 1) { |
| 116 | ✗ | error_flag = 1; | |
| 117 | } | ||
| 118 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if (mat_b.row_ptr.size() != static_cast<std::size_t>(mat_b.rows) + 1) { |
| 119 | ✗ | error_flag = 1; | |
| 120 | } | ||
| 121 | } | ||
| 122 | |||
| 123 | 24 | MPI_Bcast(&error_flag, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 124 | 24 | return error_flag == 0; | |
| 125 | } | ||
| 126 | |||
| 127 | 24 | bool SparseMatrixMultiplicationMPI::PreProcessingImpl() { | |
| 128 | 24 | int rank = 0; | |
| 129 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 130 | |||
| 131 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 132 | 12 | mat_a_ = std::get<0>(GetInput()); | |
| 133 | 12 | mat_b_ = std::get<1>(GetInput()); | |
| 134 | 12 | mat_b_transposed_ = TransposeCRS(mat_b_); | |
| 135 | } | ||
| 136 | |||
| 137 | 24 | BroadcastSparseMatrix(mat_a_, 0); | |
| 138 | 24 | BroadcastSparseMatrix(mat_b_transposed_, 0); | |
| 139 | |||
| 140 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 141 | 12 | mat_b_.rows = std::get<1>(GetInput()).rows; | |
| 142 | 12 | mat_b_.cols = std::get<1>(GetInput()).cols; | |
| 143 | } | ||
| 144 | 24 | MPI_Bcast(&mat_b_.rows, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 145 | 24 | MPI_Bcast(&mat_b_.cols, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 146 | |||
| 147 | 24 | return true; | |
| 148 | } | ||
| 149 | |||
| 150 | 24 | void SparseMatrixMultiplicationMPI::GatherResults(const SparseMatrixCRS &local_result, int my_num_rows) { | |
| 151 | 24 | int rank = 0; | |
| 152 | 24 | int size = 0; | |
| 153 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 154 | 24 | MPI_Comm_size(MPI_COMM_WORLD, &size); | |
| 155 | |||
| 156 | 24 | int local_nnz = static_cast<int>(local_result.values.size()); | |
| 157 | 24 | std::vector<int> all_nnz(size); | |
| 158 |
2/6✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
24 | std::vector<int> all_num_rows(size); |
| 159 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | MPI_Gather(&local_nnz, 1, MPI_INT, all_nnz.data(), 1, MPI_INT, 0, MPI_COMM_WORLD); |
| 160 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | MPI_Gather(&my_num_rows, 1, MPI_INT, all_num_rows.data(), 1, MPI_INT, 0, MPI_COMM_WORLD); |
| 161 | |||
| 162 |
1/4✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
24 | std::vector<int> nnz_displs(size); |
| 163 |
1/4✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
24 | std::vector<int> row_displs(size); |
| 164 | 24 | int total_nnz = 0; | |
| 165 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 166 | 12 | ComputeDisplacements(size, all_nnz, all_num_rows, nnz_displs, row_displs, total_nnz); | |
| 167 | } | ||
| 168 | |||
| 169 | 24 | std::vector<double> all_values; | |
| 170 | 24 | std::vector<int> all_col_indices; | |
| 171 |
4/4✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 11 times.
✓ Branch 3 taken 1 times.
|
24 | if (rank == 0 && total_nnz > 0) { |
| 172 |
1/2✓ Branch 1 taken 11 times.
✗ Branch 2 not taken.
|
11 | all_values.resize(total_nnz); |
| 173 |
1/2✓ Branch 1 taken 11 times.
✗ Branch 2 not taken.
|
11 | all_col_indices.resize(total_nnz); |
| 174 | } | ||
| 175 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | MPI_Gatherv(local_result.values.data(), local_nnz, MPI_DOUBLE, all_values.data(), all_nnz.data(), nnz_displs.data(), |
| 176 | MPI_DOUBLE, 0, MPI_COMM_WORLD); | ||
| 177 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | MPI_Gatherv(local_result.col_indices.data(), local_nnz, MPI_INT, all_col_indices.data(), all_nnz.data(), |
| 178 | nnz_displs.data(), MPI_INT, 0, MPI_COMM_WORLD); | ||
| 179 | |||
| 180 |
1/4✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
24 | std::vector<int> local_row_ptr_shifted(my_num_rows); |
| 181 |
2/2✓ Branch 0 taken 43 times.
✓ Branch 1 taken 24 times.
|
67 | for (int i = 0; i < my_num_rows; ++i) { |
| 182 | 43 | local_row_ptr_shifted[i] = local_result.row_ptr[i + 1]; | |
| 183 | } | ||
| 184 | |||
| 185 | 24 | std::vector<int> all_row_ptr_shifted; | |
| 186 |
3/4✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
24 | if (rank == 0 && mat_a_.rows > 0) { |
| 187 |
1/2✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
|
12 | all_row_ptr_shifted.resize(mat_a_.rows); |
| 188 | } | ||
| 189 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | MPI_Gatherv(local_row_ptr_shifted.data(), my_num_rows, MPI_INT, all_row_ptr_shifted.data(), all_num_rows.data(), |
| 190 | row_displs.data(), MPI_INT, 0, MPI_COMM_WORLD); | ||
| 191 | |||
| 192 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 193 |
1/2✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
|
12 | SparseMatrixCRS result(mat_a_.rows, mat_b_.cols); |
| 194 | result.values = std::move(all_values); | ||
| 195 | result.col_indices = std::move(all_col_indices); | ||
| 196 |
1/2✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
|
12 | result.row_ptr.resize(mat_a_.rows + 1); |
| 197 | 12 | BuildResultRowPtr(result, size, all_nnz, all_num_rows, all_row_ptr_shifted); | |
| 198 |
1/2✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
|
12 | GetOutput() = result; |
| 199 | 12 | } | |
| 200 | |||
| 201 | SparseMatrixCRS &output = GetOutput(); | ||
| 202 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | BroadcastSparseMatrix(output, 0); |
| 203 | 24 | } | |
| 204 | |||
| 205 | 24 | bool SparseMatrixMultiplicationMPI::RunImpl() { | |
| 206 | 24 | int rank = 0; | |
| 207 | 24 | int size = 0; | |
| 208 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 209 | 24 | MPI_Comm_size(MPI_COMM_WORLD, &size); | |
| 210 | |||
| 211 | 24 | int total_rows = mat_a_.rows; | |
| 212 | 24 | int base_rows = total_rows / size; | |
| 213 | 24 | int extra_rows = total_rows % size; | |
| 214 | |||
| 215 | int my_start_row = 0; | ||
| 216 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 24 times.
|
36 | for (int rr = 0; rr < rank; ++rr) { |
| 217 |
2/2✓ Branch 0 taken 5 times.
✓ Branch 1 taken 7 times.
|
17 | my_start_row += base_rows + (rr < extra_rows ? 1 : 0); |
| 218 | } | ||
| 219 |
2/2✓ Branch 0 taken 17 times.
✓ Branch 1 taken 7 times.
|
24 | int my_num_rows = base_rows + (rank < extra_rows ? 1 : 0); |
| 220 | |||
| 221 | 24 | SparseMatrixCRS local_result(my_num_rows, mat_b_.cols); | |
| 222 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | local_result.row_ptr.resize(static_cast<std::size_t>(my_num_rows) + 1); |
| 223 |
1/2✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
|
24 | if (!local_result.row_ptr.empty()) { |
| 224 | 24 | local_result.row_ptr[0] = 0; | |
| 225 | } | ||
| 226 | |||
| 227 |
2/2✓ Branch 0 taken 43 times.
✓ Branch 1 taken 24 times.
|
67 | for (int local_i = 0; local_i < my_num_rows; ++local_i) { |
| 228 | 43 | int global_i = my_start_row + local_i; | |
| 229 |
2/2✓ Branch 0 taken 207 times.
✓ Branch 1 taken 43 times.
|
250 | for (int jj = 0; jj < mat_b_.cols; ++jj) { |
| 230 |
2/2✓ Branch 0 taken 72 times.
✓ Branch 1 taken 135 times.
|
207 | double sum = ComputeRowColProduct(mat_a_, mat_b_transposed_, global_i, jj); |
| 231 |
2/2✓ Branch 0 taken 72 times.
✓ Branch 1 taken 135 times.
|
207 | if (std::abs(sum) > 1e-12) { |
| 232 | local_result.values.push_back(sum); | ||
| 233 | local_result.col_indices.push_back(jj); | ||
| 234 | } | ||
| 235 | } | ||
| 236 | 43 | local_result.row_ptr[local_i + 1] = static_cast<int>(local_result.values.size()); | |
| 237 | } | ||
| 238 | |||
| 239 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | GatherResults(local_result, my_num_rows); |
| 240 | 24 | return true; | |
| 241 | 24 | } | |
| 242 | |||
| 243 | 24 | bool SparseMatrixMultiplicationMPI::PostProcessingImpl() { | |
| 244 | 24 | return true; | |
| 245 | } | ||
| 246 | |||
| 247 | } // namespace pikhotskiy_r_multiplication_of_sparse_matrices | ||
| 248 |