| Line | Branch | Exec | Source |
|---|---|---|---|
#include "posternak_a_crs_mul_complex_matrix/stl/include/ops_stl.hpp"

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <exception>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "posternak_a_crs_mul_complex_matrix/common/include/common.hpp"
| 13 | |||
| 14 | namespace { | ||
| 15 | |||
| 16 | 80 | size_t ComputeRowNoZeroCount(const posternak_a_crs_mul_complex_matrix::CRSMatrix &a, | |
| 17 | const posternak_a_crs_mul_complex_matrix::CRSMatrix &b, int row, double threshold) { | ||
| 18 | std::unordered_map<int, std::complex<double>> row_sum; | ||
| 19 | |||
| 20 |
2/2✓ Branch 0 taken 96 times.
✓ Branch 1 taken 80 times.
|
176 | for (int idx_a = a.index_row[row]; idx_a < a.index_row[row + 1]; ++idx_a) { |
| 21 | 96 | int col_a = a.index_col[idx_a]; | |
| 22 | 96 | auto val_a = a.values[idx_a]; | |
| 23 | |||
| 24 |
2/2✓ Branch 0 taken 136 times.
✓ Branch 1 taken 96 times.
|
232 | for (int idx_b = b.index_row[col_a]; idx_b < b.index_row[col_a + 1]; ++idx_b) { |
| 25 |
1/2✓ Branch 1 taken 136 times.
✗ Branch 2 not taken.
|
136 | int col_b = b.index_col[idx_b]; |
| 26 |
1/2✓ Branch 1 taken 136 times.
✗ Branch 2 not taken.
|
136 | auto val_b = b.values[idx_b]; |
| 27 | row_sum[col_b] += val_a * val_b; | ||
| 28 | } | ||
| 29 | } | ||
| 30 | |||
| 31 | size_t local = 0; | ||
| 32 |
2/2✓ Branch 0 taken 112 times.
✓ Branch 1 taken 80 times.
|
192 | for (const auto &[col, val] : row_sum) { |
| 33 |
1/2✓ Branch 0 taken 112 times.
✗ Branch 1 not taken.
|
112 | if (std::abs(val) > threshold) { |
| 34 | 112 | ++local; | |
| 35 | } | ||
| 36 | } | ||
| 37 | 80 | return local; | |
| 38 | } | ||
| 39 | |||
| 40 | 32 | void BuildResultStructure(posternak_a_crs_mul_complex_matrix::CRSMatrix &res, std::vector<size_t> &row_prefix) { | |
| 41 |
2/2✓ Branch 0 taken 48 times.
✓ Branch 1 taken 32 times.
|
80 | for (int i = 1; i < res.rows; ++i) { |
| 42 | 48 | row_prefix[i] += row_prefix[i - 1]; | |
| 43 | } | ||
| 44 | |||
| 45 |
1/2✓ Branch 0 taken 32 times.
✗ Branch 1 not taken.
|
32 | const size_t total = row_prefix.empty() ? 0 : row_prefix.back(); |
| 46 | 32 | res.values.resize(total); | |
| 47 | 32 | res.index_col.resize(total); | |
| 48 | 32 | res.index_row.resize(res.rows + 1); | |
| 49 | |||
| 50 |
2/2✓ Branch 0 taken 112 times.
✓ Branch 1 taken 32 times.
|
144 | for (int i = 0; i <= res.rows; ++i) { |
| 51 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 32 times.
|
112 | res.index_row[i] = (i == 0 ? 0 : static_cast<int>(row_prefix[i - 1])); |
| 52 | } | ||
| 53 | 32 | } | |
| 54 | |||
| 55 | 80 | void ComputeAndWriteRow(const posternak_a_crs_mul_complex_matrix::CRSMatrix &a, | |
| 56 | const posternak_a_crs_mul_complex_matrix::CRSMatrix &b, | ||
| 57 | posternak_a_crs_mul_complex_matrix::CRSMatrix &res, int row, double threshold) { | ||
| 58 | std::unordered_map<int, std::complex<double>> row_sum; | ||
| 59 | |||
| 60 |
2/2✓ Branch 0 taken 96 times.
✓ Branch 1 taken 80 times.
|
176 | for (int idx_a = a.index_row[row]; idx_a < a.index_row[row + 1]; ++idx_a) { |
| 61 | 96 | int col_a = a.index_col[idx_a]; | |
| 62 | 96 | auto val_a = a.values[idx_a]; | |
| 63 | |||
| 64 |
2/2✓ Branch 0 taken 136 times.
✓ Branch 1 taken 96 times.
|
232 | for (int idx_b = b.index_row[col_a]; idx_b < b.index_row[col_a + 1]; ++idx_b) { |
| 65 |
1/2✓ Branch 1 taken 136 times.
✗ Branch 2 not taken.
|
136 | int col_b = b.index_col[idx_b]; |
| 66 |
1/2✓ Branch 1 taken 136 times.
✗ Branch 2 not taken.
|
136 | auto val_b = b.values[idx_b]; |
| 67 | row_sum[col_b] += val_a * val_b; | ||
| 68 | } | ||
| 69 | } | ||
| 70 | |||
| 71 |
1/2✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
|
80 | std::vector<std::pair<int, std::complex<double>>> sorted(row_sum.begin(), row_sum.end()); |
| 72 | |||
| 73 |
1/22✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 32 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
|
32 | std::ranges::sort(sorted, [](const auto &p1, const auto &p2) { return p1.first < p2.first; }); |
| 74 | |||
| 75 | 80 | size_t pos = res.index_row[row]; | |
| 76 |
2/2✓ Branch 0 taken 112 times.
✓ Branch 1 taken 80 times.
|
192 | for (const auto &[col_idx, value] : sorted) { |
| 77 |
1/2✓ Branch 0 taken 112 times.
✗ Branch 1 not taken.
|
112 | if (std::abs(value) > threshold) { |
| 78 | 112 | res.values[pos] = value; | |
| 79 | 112 | res.index_col[pos] = col_idx; | |
| 80 | 112 | ++pos; | |
| 81 | } | ||
| 82 | } | ||
| 83 | 80 | } | |
| 84 | |||
| 85 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
|
16 | bool HandleEmptyInput(posternak_a_crs_mul_complex_matrix::CRSMatrix &res) { |
| 86 | res.values.clear(); | ||
| 87 | res.index_col.clear(); | ||
| 88 | 16 | res.index_row.assign(res.rows + 1, 0); | |
| 89 | 16 | return true; | |
| 90 | } | ||
| 91 | |||
| 92 | 32 | std::vector<size_t> CountNonZeroElementsParallel(const posternak_a_crs_mul_complex_matrix::CRSMatrix &a, | |
| 93 | const posternak_a_crs_mul_complex_matrix::CRSMatrix &b, int total_rows, | ||
| 94 | double threshold) { | ||
| 95 | 32 | std::vector<size_t> no_zero_rows(total_rows); | |
| 96 | 32 | unsigned int num_threads = std::thread::hardware_concurrency(); | |
| 97 | 32 | if (num_threads == 0) { | |
| 98 | num_threads = 1; | ||
| 99 | } | ||
| 100 | |||
| 101 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | const unsigned int chunk_size = std::max(1U, (static_cast<unsigned int>(total_rows) + num_threads - 1) / num_threads); |
| 102 | |||
| 103 | 32 | std::vector<std::thread> threads; | |
| 104 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | threads.reserve(num_threads); |
| 105 | |||
| 106 |
2/2✓ Branch 0 taken 104 times.
✓ Branch 1 taken 8 times.
|
112 | for (unsigned int thr = 0; thr < num_threads; ++thr) { |
| 107 | 104 | const int start_row = static_cast<int>(thr * chunk_size); | |
| 108 | const int end_row = static_cast<int>( | ||
| 109 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 24 times.
|
104 | std::min(static_cast<unsigned int>(start_row) + chunk_size, static_cast<unsigned int>(total_rows))); |
| 110 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 24 times.
|
104 | if (start_row >= total_rows) { |
| 111 | break; | ||
| 112 | } | ||
| 113 | |||
| 114 |
1/2✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
|
80 | threads.emplace_back([&, start_row, end_row]() { |
| 115 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 80 times.
|
160 | for (int row = start_row; row < end_row; ++row) { |
| 116 | 80 | no_zero_rows[row] = ComputeRowNoZeroCount(a, b, row, threshold); | |
| 117 | } | ||
| 118 | 80 | }); | |
| 119 | } | ||
| 120 | |||
| 121 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 32 times.
|
112 | for (auto &thread : threads) { |
| 122 |
1/2✓ Branch 0 taken 80 times.
✗ Branch 1 not taken.
|
80 | if (thread.joinable()) { |
| 123 |
1/2✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
|
80 | thread.join(); |
| 124 | } | ||
| 125 | } | ||
| 126 | |||
| 127 | 32 | return no_zero_rows; | |
| 128 | 32 | } | |
| 129 | |||
| 130 | 32 | void ComputeResultValuesParallel(const posternak_a_crs_mul_complex_matrix::CRSMatrix &a, | |
| 131 | const posternak_a_crs_mul_complex_matrix::CRSMatrix &b, | ||
| 132 | posternak_a_crs_mul_complex_matrix::CRSMatrix &res, int total_rows, double threshold) { | ||
| 133 | 32 | unsigned int num_threads = std::thread::hardware_concurrency(); | |
| 134 | 32 | if (num_threads == 0) { | |
| 135 | num_threads = 1; | ||
| 136 | } | ||
| 137 | |||
| 138 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | const unsigned int chunk_size = std::max(1U, (static_cast<unsigned int>(total_rows) + num_threads - 1) / num_threads); |
| 139 | |||
| 140 | 32 | std::vector<std::thread> threads; | |
| 141 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | threads.reserve(num_threads); |
| 142 | |||
| 143 |
2/2✓ Branch 0 taken 104 times.
✓ Branch 1 taken 8 times.
|
112 | for (unsigned int thr = 0; thr < num_threads; ++thr) { |
| 144 | 104 | const int start_row = static_cast<int>(thr * chunk_size); | |
| 145 | const int end_row = static_cast<int>( | ||
| 146 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 24 times.
|
104 | std::min(static_cast<unsigned int>(start_row) + chunk_size, static_cast<unsigned int>(total_rows))); |
| 147 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 24 times.
|
104 | if (start_row >= total_rows) { |
| 148 | break; | ||
| 149 | } | ||
| 150 | |||
| 151 |
1/2✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
|
80 | threads.emplace_back([&, start_row, end_row]() { |
| 152 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 80 times.
|
160 | for (int row = start_row; row < end_row; ++row) { |
| 153 | 80 | ComputeAndWriteRow(a, b, res, row, threshold); | |
| 154 | } | ||
| 155 | 80 | }); | |
| 156 | } | ||
| 157 | |||
| 158 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 32 times.
|
112 | for (auto &thread : threads) { |
| 159 |
1/2✓ Branch 0 taken 80 times.
✗ Branch 1 not taken.
|
80 | if (thread.joinable()) { |
| 160 |
1/2✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
|
80 | thread.join(); |
| 161 | } | ||
| 162 | } | ||
| 163 | 32 | } | |
| 164 | |||
| 165 | } // namespace | ||
| 166 | |||
| 167 | namespace posternak_a_crs_mul_complex_matrix { | ||
| 168 | |||
| 169 |
1/2✓ Branch 2 taken 48 times.
✗ Branch 3 not taken.
|
48 | PosternakACRSMulComplexMatrixSTL::PosternakACRSMulComplexMatrixSTL(const InType &in) { |
| 170 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 171 | GetInput() = in; | ||
| 172 | 48 | GetOutput() = CRSMatrix{}; | |
| 173 | 48 | } | |
| 174 | |||
| 175 | 48 | bool PosternakACRSMulComplexMatrixSTL::ValidationImpl() { | |
| 176 | const auto &input = GetInput(); | ||
| 177 | 48 | const auto &a = input.first; | |
| 178 | 48 | const auto &b = input.second; | |
| 179 |
3/6✓ Branch 0 taken 48 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 48 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 48 times.
|
48 | return a.IsValid() && b.IsValid() && a.cols == b.rows; |
| 180 | } | ||
| 181 | |||
| 182 | 48 | bool PosternakACRSMulComplexMatrixSTL::PreProcessingImpl() { | |
| 183 | const auto &input = GetInput(); | ||
| 184 | const auto &a = input.first; | ||
| 185 | const auto &b = input.second; | ||
| 186 | auto &res = GetOutput(); | ||
| 187 | |||
| 188 | 48 | res.rows = a.rows; | |
| 189 | 48 | res.cols = b.cols; | |
| 190 | 48 | return true; | |
| 191 | } | ||
| 192 | |||
| 193 | 48 | bool PosternakACRSMulComplexMatrixSTL::RunImpl() { | |
| 194 | const auto &input = GetInput(); | ||
| 195 | 48 | const auto &a = input.first; | |
| 196 |
2/2✓ Branch 0 taken 40 times.
✓ Branch 1 taken 8 times.
|
48 | const auto &b = input.second; |
| 197 | auto &res = GetOutput(); | ||
| 198 | |||
| 199 |
4/4✓ Branch 0 taken 40 times.
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 32 times.
|
48 | if (a.values.empty() || b.values.empty()) { |
| 200 | 16 | return HandleEmptyInput(res); | |
| 201 | } | ||
| 202 | |||
| 203 | constexpr double kThreshold = 1e-12; | ||
| 204 | 32 | res.rows = a.rows; | |
| 205 | 32 | res.cols = b.cols; | |
| 206 | |||
| 207 | 32 | std::vector<size_t> no_zero_rows = CountNonZeroElementsParallel(a, b, res.rows, kThreshold); | |
| 208 | |||
| 209 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | BuildResultStructure(res, no_zero_rows); |
| 210 | |||
| 211 |
1/2✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
|
32 | ComputeResultValuesParallel(a, b, res, res.rows, kThreshold); |
| 212 | |||
| 213 |
1/2✓ Branch 0 taken 32 times.
✗ Branch 1 not taken.
|
32 | return res.IsValid(); |
| 214 | } | ||
| 215 | |||
| 216 | 48 | bool PosternakACRSMulComplexMatrixSTL::PostProcessingImpl() { | |
| 217 | 48 | return GetOutput().IsValid(); | |
| 218 | } | ||
| 219 | |||
| 220 | } // namespace posternak_a_crs_mul_complex_matrix | ||
| 221 |