| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "liulin_y_complex_ccs/tbb/include/ops_tbb.hpp" | ||
| 2 | |||
| 3 | #include <tbb/tbb.h> | ||
| 4 | |||
| 5 | #include <atomic> | ||
| 6 | #include <cmath> | ||
| 7 | #include <complex> | ||
| 8 | #include <cstddef> | ||
| 9 | #include <iterator> | ||
| 10 | #include <utility> | ||
| 11 | #include <vector> | ||
| 12 | |||
| 13 | #include "liulin_y_complex_ccs/common/include/common.hpp" | ||
| 14 | |||
| 15 | namespace liulin_y_complex_ccs { | ||
| 16 | |||
| 17 | namespace { | ||
| 18 | |||
| 19 | constexpr double kEpsilon = 1e-10; | ||
| 20 | |||
| 21 | 48 | bool IsValidCCS(const CCSMatrix &mat) { | |
| 22 |
2/4✓ Branch 0 taken 48 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 48 times.
✗ Branch 3 not taken.
|
48 | if (mat.count_rows <= 0 || mat.count_cols <= 0) { |
| 23 | return false; | ||
| 24 | } | ||
| 25 |
1/2✓ Branch 0 taken 48 times.
✗ Branch 1 not taken.
|
48 | if (mat.col_index.size() != static_cast<size_t>(mat.count_cols) + 1) { |
| 26 | return false; | ||
| 27 | } | ||
| 28 |
1/2✓ Branch 0 taken 48 times.
✗ Branch 1 not taken.
|
48 | if (mat.col_index[0] != 0) { |
| 29 | return false; | ||
| 30 | } | ||
| 31 |
1/2✓ Branch 0 taken 48 times.
✗ Branch 1 not taken.
|
48 | if (mat.values.size() != mat.row_index.size()) { |
| 32 | return false; | ||
| 33 | } | ||
| 34 |
1/2✓ Branch 0 taken 48 times.
✗ Branch 1 not taken.
|
48 | if (std::cmp_not_equal(mat.col_index.back(), mat.values.size())) { |
| 35 | ✗ | return false; | |
| 36 | } | ||
| 37 | return true; | ||
| 38 | } | ||
| 39 | |||
| 40 | 24 | void TransposeCCS(const CCSMatrix &mat, CCSMatrix &mat_t) { | |
| 41 | 24 | mat_t.count_rows = mat.count_cols; | |
| 42 | 24 | mat_t.count_cols = mat.count_rows; | |
| 43 | const size_t nnz = mat.values.size(); | ||
| 44 | 24 | mat_t.values.resize(nnz); | |
| 45 | 24 | mat_t.row_index.resize(nnz); | |
| 46 | 24 | mat_t.col_index.assign(static_cast<size_t>(mat_t.count_cols) + 1, 0); | |
| 47 | |||
| 48 |
2/2✓ Branch 0 taken 60 times.
✓ Branch 1 taken 24 times.
|
84 | for (size_t i = 0; i < nnz; ++i) { |
| 49 | 60 | mat_t.col_index[static_cast<size_t>(mat.row_index[i]) + 1]++; | |
| 50 | } | ||
| 51 |
2/2✓ Branch 0 taken 48 times.
✓ Branch 1 taken 24 times.
|
72 | for (int i = 0; i < mat_t.count_cols; ++i) { |
| 52 | 48 | mat_t.col_index[static_cast<size_t>(i) + 1] += mat_t.col_index[static_cast<size_t>(i)]; | |
| 53 | } | ||
| 54 | |||
| 55 | 24 | std::vector<int> current_pos(mat_t.col_index.begin(), mat_t.col_index.end()); | |
| 56 |
2/2✓ Branch 0 taken 52 times.
✓ Branch 1 taken 24 times.
|
76 | for (int j = 0; j < mat.count_cols; ++j) { |
| 57 |
2/2✓ Branch 0 taken 60 times.
✓ Branch 1 taken 52 times.
|
112 | for (int k = mat.col_index[static_cast<size_t>(j)]; k < mat.col_index[static_cast<size_t>(j) + 1]; ++k) { |
| 58 | 60 | const int row = mat.row_index[static_cast<size_t>(k)]; | |
| 59 | 60 | const int dest = current_pos[static_cast<size_t>(row)]++; | |
| 60 | 60 | mat_t.row_index[static_cast<size_t>(dest)] = j; | |
| 61 | 60 | mat_t.values[static_cast<size_t>(dest)] = mat.values[static_cast<size_t>(k)]; | |
| 62 | } | ||
| 63 | } | ||
| 64 | 24 | } | |
| 65 | |||
| 66 | 76 | std::complex<double> ComputeDotProduct(int a_start, int a_end, int b_start, int b_end, const CCSMatrix &mat_at, | |
| 67 | const CCSMatrix &mat_b) { | ||
| 68 | int ptr_a = a_start; | ||
| 69 | int ptr_b = b_start; | ||
| 70 | std::complex<double> sum(0.0, 0.0); | ||
| 71 | |||
| 72 |
2/2✓ Branch 0 taken 88 times.
✓ Branch 1 taken 76 times.
|
164 | while (ptr_a < a_end && ptr_b < b_end) { |
| 73 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 64 times.
|
88 | const int idx_a = mat_at.row_index[static_cast<size_t>(ptr_a)]; |
| 74 | 88 | const int idx_b = mat_b.row_index[static_cast<size_t>(ptr_b)]; | |
| 75 | |||
| 76 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 64 times.
|
88 | if (idx_a < idx_b) { |
| 77 | 24 | ptr_a++; | |
| 78 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 40 times.
|
64 | } else if (idx_a > idx_b) { |
| 79 | 24 | ptr_b++; | |
| 80 | } else { | ||
| 81 | sum += mat_at.values[static_cast<size_t>(ptr_a)] * mat_b.values[static_cast<size_t>(ptr_b)]; | ||
| 82 | 40 | ptr_a++; | |
| 83 | 40 | ptr_b++; | |
| 84 | } | ||
| 85 | } | ||
| 86 | 76 | return sum; | |
| 87 | } | ||
| 88 | |||
| 89 | 48 | void ProcessColumn(int j, int res_rows, const CCSMatrix &mat_at, const CCSMatrix &mat_b, | |
| 90 | std::vector<std::complex<double>> &col_values, std::vector<int> &col_rows) { | ||
| 91 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 12 times.
|
48 | const int b_start = mat_b.col_index[static_cast<size_t>(j)]; |
| 92 | 48 | const int b_end = mat_b.col_index[static_cast<size_t>(j) + 1]; | |
| 93 | |||
| 94 |
2/2✓ Branch 0 taken 36 times.
✓ Branch 1 taken 12 times.
|
48 | if (b_start == b_end) { |
| 95 | return; | ||
| 96 | } | ||
| 97 | |||
| 98 |
2/2✓ Branch 0 taken 76 times.
✓ Branch 1 taken 36 times.
|
112 | for (int i = 0; i < res_rows; ++i) { |
| 99 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 76 times.
|
76 | const int a_start = mat_at.col_index[static_cast<size_t>(i)]; |
| 100 | 76 | const int a_end = mat_at.col_index[static_cast<size_t>(i) + 1]; | |
| 101 | |||
| 102 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 76 times.
|
76 | if (a_start == a_end) { |
| 103 | ✗ | continue; | |
| 104 | } | ||
| 105 | |||
| 106 |
2/2✓ Branch 0 taken 40 times.
✓ Branch 1 taken 36 times.
|
76 | std::complex<double> sum = ComputeDotProduct(a_start, a_end, b_start, b_end, mat_at, mat_b); |
| 107 | |||
| 108 |
3/4✓ Branch 0 taken 40 times.
✓ Branch 1 taken 36 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 40 times.
|
76 | if (std::abs(sum.real()) > kEpsilon || std::abs(sum.imag()) > kEpsilon) { |
| 109 | col_values.push_back(sum); | ||
| 110 | col_rows.push_back(i); | ||
| 111 | } | ||
| 112 | } | ||
| 113 | } | ||
| 114 | |||
| 115 | } // namespace | ||
| 116 | |||
| 117 |
1/2✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
|
24 | LiulinYComplexCcsTbb::LiulinYComplexCcsTbb(const InType &in) : BaseTask() { |
| 118 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 119 | GetInput() = in; | ||
| 120 | 24 | } | |
| 121 | |||
| 122 | 24 | bool LiulinYComplexCcsTbb::ValidationImpl() { | |
| 123 | 24 | const auto &mat_a = GetInput().first; | |
| 124 | 24 | const auto &mat_b = GetInput().second; | |
| 125 | |||
| 126 |
2/4✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
|
24 | if (!IsValidCCS(mat_a) || !IsValidCCS(mat_b)) { |
| 127 | return false; | ||
| 128 | } | ||
| 129 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
|
24 | if (mat_a.count_cols != mat_b.count_rows) { |
| 130 | ✗ | return false; | |
| 131 | } | ||
| 132 | |||
| 133 | return true; | ||
| 134 | } | ||
| 135 | |||
| 136 | 24 | bool LiulinYComplexCcsTbb::PreProcessingImpl() { | |
| 137 | 24 | return true; | |
| 138 | } | ||
| 139 | |||
| 140 | 24 | bool LiulinYComplexCcsTbb::RunImpl() { | |
| 141 | 24 | const auto &mat_a = GetInput().first; | |
| 142 | 24 | const auto &mat_b = GetInput().second; | |
| 143 | auto &mat_res = GetOutput(); | ||
| 144 | |||
| 145 | 24 | const int res_rows = mat_a.count_rows; | |
| 146 | 24 | const int res_cols = mat_b.count_cols; | |
| 147 | |||
| 148 | 24 | CCSMatrix mat_at; | |
| 149 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | TransposeCCS(mat_a, mat_at); |
| 150 | |||
| 151 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | std::vector<std::vector<std::complex<double>>> thread_values(static_cast<size_t>(res_cols)); |
| 152 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | std::vector<std::vector<int>> thread_row_indices(static_cast<size_t>(res_cols)); |
| 153 | 24 | std::atomic<bool> success{true}; | |
| 154 | |||
| 155 | try { | ||
| 156 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
72 | tbb::parallel_for(tbb::blocked_range<int>(0, res_cols), [&](const tbb::blocked_range<int> &r) { |
| 157 |
2/2✓ Branch 0 taken 48 times.
✓ Branch 1 taken 48 times.
|
96 | for (int j = r.begin(); j != r.end(); ++j) { |
| 158 | 48 | ProcessColumn(j, res_rows, mat_at, mat_b, thread_values[static_cast<size_t>(j)], | |
| 159 | 48 | thread_row_indices[static_cast<size_t>(j)]); | |
| 160 | } | ||
| 161 | 48 | }); | |
| 162 | ✗ | } catch (...) { | |
| 163 | success = false; | ||
| 164 | ✗ | } | |
| 165 | |||
| 166 |
1/2✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
|
24 | if (!success) { |
| 167 | return false; | ||
| 168 | } | ||
| 169 | |||
| 170 | 24 | mat_res.count_rows = res_rows; | |
| 171 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
|
24 | mat_res.count_cols = res_cols; |
| 172 | mat_res.values.clear(); | ||
| 173 | mat_res.row_index.clear(); | ||
| 174 |
1/2✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
|
24 | mat_res.col_index.assign(static_cast<size_t>(res_cols) + 1, 0); |
| 175 | |||
| 176 |
2/2✓ Branch 0 taken 48 times.
✓ Branch 1 taken 24 times.
|
72 | for (int j = 0; j < res_cols; ++j) { |
| 177 | 48 | const auto u_j = static_cast<size_t>(j); | |
| 178 |
1/2✓ Branch 1 taken 48 times.
✗ Branch 2 not taken.
|
48 | mat_res.values.insert(mat_res.values.end(), std::make_move_iterator(thread_values[u_j].begin()), |
| 179 | std::make_move_iterator(thread_values[u_j].end())); | ||
| 180 |
1/2✓ Branch 1 taken 48 times.
✗ Branch 2 not taken.
|
48 | mat_res.row_index.insert(mat_res.row_index.end(), std::make_move_iterator(thread_row_indices[u_j].begin()), |
| 181 | std::make_move_iterator(thread_row_indices[u_j].end())); | ||
| 182 | 48 | mat_res.col_index[u_j + 1] = static_cast<int>(mat_res.values.size()); | |
| 183 | } | ||
| 184 | |||
| 185 | return true; | ||
| 186 | 24 | } | |
| 187 | |||
| 188 | 24 | bool LiulinYComplexCcsTbb::PostProcessingImpl() { | |
| 189 | const auto &mat_res = GetOutput(); | ||
| 190 |
2/4✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
|
24 | if (mat_res.count_rows <= 0 || mat_res.count_cols <= 0) { |
| 191 | return false; | ||
| 192 | } | ||
| 193 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
|
24 | if (mat_res.col_index.size() != static_cast<size_t>(mat_res.count_cols) + 1) { |
| 194 | ✗ | return false; | |
| 195 | } | ||
| 196 | return true; | ||
| 197 | } | ||
| 198 | |||
| 199 | } // namespace liulin_y_complex_ccs | ||
| 200 |