| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "kazennova_a_fox_algorithm/all/include/ops_all.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <atomic> | ||
| 7 | #include <cstddef> | ||
| 8 | #include <thread> | ||
| 9 | #include <vector> | ||
| 10 | |||
| 11 | #include "kazennova_a_fox_algorithm/common/include/common.hpp" | ||
| 12 | |||
| 13 | namespace kazennova_a_fox_algorithm { | ||
| 14 | |||
| 15 | namespace { | ||
| 16 | |||
/// Copies one block_size x block_size tile of a row-major matrix into block_buf.
///
/// The tile's top-left element is (block_row * block_size, block_col * block_size).
/// The tile is written row-major with a stride of block_size; every cell of the
/// tile that falls outside the rows x cols matrix is set to 0.0.
///
/// @param mat        Row-major matrix storage, read at index (row * cols) + col.
/// @param rows       Number of matrix rows.
/// @param cols       Number of matrix columns (also the row stride of mat).
/// @param block_row  Tile index along the row axis.
/// @param block_col  Tile index along the column axis.
/// @param block_size Edge length of the tile; nothing is written when it is <= 0.
/// @param block_buf  Destination with room for block_size * block_size doubles.
void GetBlock(const std::vector<double> &mat, int rows, int cols, int block_row, int block_col, int block_size,
              double *block_buf) {
  if (block_size <= 0) {
    return;
  }
  const auto stride = static_cast<std::size_t>(block_size);

  // Zero the whole tile first so the part past the matrix edge is padded.
  std::fill_n(block_buf, stride * stride, 0.0);

  const int start_row = block_row * block_size;
  const int start_col = block_col * block_size;
  const int copy_rows = std::min(block_size, rows - start_row);
  const int copy_cols = std::min(block_size, cols - start_col);
  if (copy_rows <= 0 || copy_cols <= 0) {
    return;  // Tile lies completely outside the matrix.
  }

  // Offsets are computed in std::size_t: (row * cols) can exceed INT_MAX for large matrices.
  const auto src_stride = static_cast<std::size_t>(cols);
  const auto src_col = static_cast<std::size_t>(start_col);
  for (int i = 0; i < copy_rows; ++i) {
    const std::size_t src_offset = (static_cast<std::size_t>(start_row + i) * src_stride) + src_col;
    std::copy_n(mat.data() + src_offset, copy_cols, block_buf + (static_cast<std::size_t>(i) * stride));
  }
}
| 35 | |||
/// Multiplies the top-left max_i x max_k part of block_a by the top-left
/// max_k x max_j part of block_b and adds the product into local_c.
///
/// Both tiles are row-major with a stride of block_size. The product lands in
/// local_c (row stride n) at rows bi * block_size + r and columns bj * block_size + col.
/// Each dot product is accumulated from 0.0 in increasing k order and then added
/// to the existing local_c value.
void MultiplyBlock(const std::vector<double> &block_a, const std::vector<double> &block_b, int block_size, int max_i,
                   int max_j, int max_k, int bi, int bj, int n, std::vector<double> &local_c) {
  const int row_base = bi * block_size;
  for (int r = 0; r < max_i; ++r) {
    const int dst_row = row_base + r;
    for (int col = 0; col < max_j; ++col) {
      double acc = 0.0;
      for (int step = 0; step < max_k; ++step) {
        const double lhs = block_a[(r * block_size) + step];
        const double rhs = block_b[(step * block_size) + col];
        acc += lhs * rhs;
      }
      const int dst_col = (bj * block_size) + col;
      local_c[(dst_row * n) + dst_col] += acc;
    }
  }
}
| 49 | |||
| 50 | 8 | void RunLocalMultiplication(const std::vector<double> &a, const std::vector<double> &b, int rows_total, int cols_a, | |
| 51 | int cols_b, int start_row, int local_rows, int block_size, std::vector<double> &local_c) { | ||
| 52 | 8 | const int blocks_i_local = (local_rows + block_size - 1) / block_size; | |
| 53 | 8 | const int blocks_j = (cols_b + block_size - 1) / block_size; | |
| 54 | 8 | const int blocks_k = (cols_a + block_size - 1) / block_size; | |
| 55 | |||
| 56 | 8 | int num_threads = static_cast<int>(std::thread::hardware_concurrency()); | |
| 57 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if (num_threads <= 0) { |
| 58 | num_threads = 2; | ||
| 59 | } | ||
| 60 | |||
| 61 | 8 | std::vector<std::thread> threads; | |
| 62 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | threads.reserve(static_cast<size_t>(num_threads)); |
| 63 | |||
| 64 | 8 | std::atomic<size_t> next_block_idx(0); | |
| 65 | 8 | const size_t total_blocks = static_cast<size_t>(blocks_i_local) * blocks_j; | |
| 66 | |||
| 67 | 32 | auto worker = [&]() { | |
| 68 | 32 | std::vector<double> block_a(static_cast<size_t>(block_size) * block_size); | |
| 69 |
1/4✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
32 | std::vector<double> block_b(static_cast<size_t>(block_size) * block_size); |
| 70 | |||
| 71 | while (true) { | ||
| 72 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 32 times.
|
40 | const size_t idx = next_block_idx.fetch_add(1); |
| 73 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 32 times.
|
40 | if (idx >= total_blocks) { |
| 74 | break; | ||
| 75 | } | ||
| 76 | |||
| 77 | 8 | const int bi = static_cast<int>(idx / blocks_j); | |
| 78 | 8 | const int bj = static_cast<int>(idx % blocks_j); | |
| 79 | |||
| 80 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
|
16 | for (int bk = 0; bk < blocks_k; ++bk) { |
| 81 | 8 | const int bi_global = (start_row / block_size) + bi; | |
| 82 | 8 | GetBlock(a, rows_total, cols_a, bi_global, bk, block_size, block_a.data()); | |
| 83 | 8 | GetBlock(b, cols_a, cols_b, bk, bj, block_size, block_b.data()); | |
| 84 | |||
| 85 | 8 | const int offset = start_row % block_size; | |
| 86 |
1/2✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
|
8 | const int max_i = std::min(block_size, local_rows - (bi * block_size) - offset); |
| 87 |
1/2✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
|
8 | const int max_j = std::min(block_size, cols_b - (bj * block_size)); |
| 88 |
1/2✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
|
8 | const int max_k = std::min(block_size, cols_a - (bk * block_size)); |
| 89 | |||
| 90 | 8 | MultiplyBlock(block_a, block_b, block_size, max_i, max_j, max_k, bi, bj, cols_b, local_c); | |
| 91 | } | ||
| 92 | } | ||
| 93 | 32 | }; | |
| 94 | |||
| 95 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 8 times.
|
40 | for (int thread_idx = 0; thread_idx < num_threads; ++thread_idx) { |
| 96 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | threads.emplace_back(worker); |
| 97 | } | ||
| 98 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 8 times.
|
40 | for (auto &thr : threads) { |
| 99 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | thr.join(); |
| 100 | } | ||
| 101 | 8 | } | |
| 102 | |||
/// Fills the per-process element counts and displacements for a row-wise gather.
///
/// Process p contributes (rows_per_proc + 1) rows when p < remainder and
/// rows_per_proc rows otherwise, each row holding cols_b elements.
/// recv_counts[p] receives that element count, displs[p] the running sum of the
/// counts before p, and total_elements the sum over all world_size processes.
/// Both vectors must already hold at least world_size entries.
void ComputeRecvCounts(int world_size, int rows_per_proc, int remainder, int cols_b, std::vector<int> &recv_counts,
                       std::vector<int> &displs, int &total_elements) {
  int running_offset = 0;
  for (int p = 0; p < world_size; ++p) {
    int rows_here = rows_per_proc;
    if (p < remainder) {
      ++rows_here;
    }
    const int count = rows_here * cols_b;
    displs[p] = running_offset;
    recv_counts[p] = count;
    running_offset += count;
  }
  total_elements = running_offset;
}
| 113 | |||
| 114 | 8 | void GatherAndAssemble(int rank, int world_size, int rows_per_proc, int remainder, int cols_b, | |
| 115 | const std::vector<double> &local_c, std::vector<double> &c) { | ||
| 116 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
|
8 | if (rank == 0) { |
| 117 | 4 | std::vector<int> recv_counts(world_size); | |
| 118 |
1/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
4 | std::vector<int> displs(world_size); |
| 119 | int total_elements = 0; | ||
| 120 | 4 | ComputeRecvCounts(world_size, rows_per_proc, remainder, cols_b, recv_counts, displs, total_elements); | |
| 121 |
2/6✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
4 | std::vector<double> gathered(static_cast<size_t>(total_elements)); |
| 122 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | MPI_Gatherv(local_c.data(), static_cast<int>(local_c.size()), MPI_DOUBLE, gathered.data(), recv_counts.data(), |
| 123 | displs.data(), MPI_DOUBLE, 0, MPI_COMM_WORLD); | ||
| 124 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
|
12 | for (int proc = 0; proc < world_size; ++proc) { |
| 125 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
|
8 | const int proc_start_row = (proc * rows_per_proc) + std::min(proc, remainder); |
| 126 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
|
8 | const int proc_local_rows = rows_per_proc + (proc < remainder ? 1 : 0); |
| 127 |
2/2✓ Branch 0 taken 20 times.
✓ Branch 1 taken 8 times.
|
28 | for (int i = 0; i < proc_local_rows; ++i) { |
| 128 |
2/2✓ Branch 0 taken 138 times.
✓ Branch 1 taken 20 times.
|
158 | for (int j = 0; j < cols_b; ++j) { |
| 129 | 138 | c[((proc_start_row + i) * cols_b) + j] = gathered[displs[proc] + (i * cols_b) + j]; | |
| 130 | } | ||
| 131 | } | ||
| 132 | } | ||
| 133 | } else { | ||
| 134 | 4 | std::vector<int> recv_counts(world_size); | |
| 135 |
1/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
4 | std::vector<int> displs(world_size); |
| 136 | int total_elements = 0; | ||
| 137 | 4 | ComputeRecvCounts(world_size, rows_per_proc, remainder, cols_b, recv_counts, displs, total_elements); | |
| 138 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | MPI_Gatherv(local_c.data(), static_cast<int>(local_c.size()), MPI_DOUBLE, nullptr, nullptr, nullptr, MPI_DOUBLE, 0, |
| 139 | MPI_COMM_WORLD); | ||
| 140 | } | ||
| 141 | 8 | } | |
| 142 | |||
| 143 | } // namespace | ||
| 144 | |||
| 145 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | KazennovaATestTaskALL::KazennovaATestTaskALL(const InType &in) { |
| 146 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 147 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | GetInput() = in; |
| 148 | 8 | } | |
| 149 | |||
| 150 |
1/2✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
|
8 | bool KazennovaATestTaskALL::ValidationImpl() { |
| 151 | const auto &in = GetInput(); | ||
| 152 |
2/4✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
8 | if (in.A.data.empty() || in.B.data.empty()) { |
| 153 | return false; | ||
| 154 | } | ||
| 155 |
4/8✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
|
8 | if (in.A.rows <= 0 || in.A.cols <= 0 || in.B.rows <= 0 || in.B.cols <= 0) { |
| 156 | return false; | ||
| 157 | } | ||
| 158 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if (in.A.cols != in.B.rows) { |
| 159 | ✗ | return false; | |
| 160 | } | ||
| 161 | return true; | ||
| 162 | } | ||
| 163 | |||
| 164 | 8 | bool KazennovaATestTaskALL::PreProcessingImpl() { | |
| 165 | const auto &in = GetInput(); | ||
| 166 | auto &out = GetOutput(); | ||
| 167 | 8 | out.rows = in.A.rows; | |
| 168 | 8 | out.cols = in.B.cols; | |
| 169 | 8 | out.data.assign(static_cast<size_t>(out.rows) * out.cols, 0.0); | |
| 170 | 8 | return true; | |
| 171 | } | ||
| 172 | |||
| 173 | 8 | bool KazennovaATestTaskALL::RunImpl() { | |
| 174 | 8 | int rank = -1; | |
| 175 | 8 | int world_size = 1; | |
| 176 | 8 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 177 | 8 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); | |
| 178 | |||
| 179 | const auto &in = GetInput(); | ||
| 180 | auto &out = GetOutput(); | ||
| 181 | |||
| 182 | 8 | const int rows_total = in.A.rows; | |
| 183 | 8 | const int cols_a = in.A.cols; | |
| 184 | 8 | const int cols_b = in.B.cols; | |
| 185 | 8 | const auto &a = in.A.data; | |
| 186 | 8 | const auto &b = in.B.data; | |
| 187 | 8 | auto &c = out.data; | |
| 188 | |||
| 189 | const int block_size = kBlockSize; | ||
| 190 | |||
| 191 | 8 | const int rows_per_proc = rows_total / world_size; | |
| 192 | 8 | const int remainder = rows_total % world_size; | |
| 193 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
|
8 | const int start_row = (rank * rows_per_proc) + std::min(rank, remainder); |
| 194 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
|
8 | const int local_rows = rows_per_proc + (rank < remainder ? 1 : 0); |
| 195 | |||
| 196 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if (local_rows == 0) { |
| 197 | ✗ | if (rank == 0) { | |
| 198 | return true; | ||
| 199 | } | ||
| 200 | ✗ | MPI_Barrier(MPI_COMM_WORLD); | |
| 201 | ✗ | return true; | |
| 202 | } | ||
| 203 | |||
| 204 | 8 | std::vector<double> local_c(static_cast<size_t>(local_rows) * cols_b, 0.0); | |
| 205 | |||
| 206 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | RunLocalMultiplication(a, b, rows_total, cols_a, cols_b, start_row, local_rows, block_size, local_c); |
| 207 | |||
| 208 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | GatherAndAssemble(rank, world_size, rows_per_proc, remainder, cols_b, local_c, c); |
| 209 | |||
| 210 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | MPI_Barrier(MPI_COMM_WORLD); |
| 211 | return true; | ||
| 212 | } | ||
| 213 | |||
| 214 | 8 | bool KazennovaATestTaskALL::PostProcessingImpl() { | |
| 215 | 8 | return !GetOutput().data.empty(); | |
| 216 | } | ||
| 217 | |||
| 218 | } // namespace kazennova_a_fox_algorithm | ||
| 219 |