| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "ermakov_a_spar_mat_mult/tbb/include/ops_tbb.hpp" | ||
| 2 | |||
| 3 | #include <algorithm> | ||
| 4 | #include <complex> | ||
| 5 | #include <cstddef> | ||
| 6 | #include <vector> | ||
| 7 | |||
| 8 | #include "ermakov_a_spar_mat_mult/common/include/common.hpp" | ||
| 9 | #include "oneapi/tbb/blocked_range.h" | ||
| 10 | #include "oneapi/tbb/enumerable_thread_specific.h" | ||
| 11 | #include "oneapi/tbb/parallel_for.h" | ||
| 12 | |||
| 13 | namespace ermakov_a_spar_mat_mult { | ||
| 14 | |||
| 15 | namespace { | ||
| 16 | |||
| 17 | struct RowWorkspace { | ||
| 18 | std::vector<std::complex<double>> row_vals; | ||
| 19 | std::vector<int> row_mark; | ||
| 20 | std::vector<int> used_cols; | ||
| 21 | |||
| 22 | 28 | explicit RowWorkspace(int cols) | |
| 23 | 28 | : row_vals(static_cast<std::size_t>(cols), std::complex<double>(0.0, 0.0)), | |
| 24 |
2/6✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
28 | row_mark(static_cast<std::size_t>(cols), -1) { |
| 25 |
1/2✓ Branch 1 taken 28 times.
✗ Branch 2 not taken.
|
28 | used_cols.reserve(256); |
| 26 | 28 | } | |
| 27 | }; | ||
| 28 | |||
| 29 | int ResolveGrainSize(int rows) { | ||
| 30 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
|
16 | if (rows <= 0) { |
| 31 | return 1; | ||
| 32 | } | ||
| 33 | |||
| 34 | constexpr int kTargetChunks = 16; | ||
| 35 |
1/2✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
|
16 | return std::max(1, rows / kTargetChunks); |
| 36 | } | ||
| 37 | |||
| 38 | } // namespace | ||
| 39 | |||
| 40 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | ErmakovASparMatMultTBB::ErmakovASparMatMultTBB(const InType &in) { |
| 41 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 42 | GetInput() = in; | ||
| 43 | 16 | } | |
| 44 | |||
| 45 | 32 | bool ErmakovASparMatMultTBB::ValidateMatrix(const MatrixCRS &m) { | |
| 46 |
2/4✓ Branch 0 taken 32 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 32 times.
|
32 | if (m.rows < 0 || m.cols < 0) { |
| 47 | return false; | ||
| 48 | } | ||
| 49 | |||
| 50 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
|
32 | if (m.row_ptr.size() != static_cast<std::size_t>(m.rows) + 1) { |
| 51 | return false; | ||
| 52 | } | ||
| 53 | |||
| 54 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
|
32 | if (m.values.size() != m.col_index.size()) { |
| 55 | return false; | ||
| 56 | } | ||
| 57 | |||
| 58 | 32 | const int nnz = static_cast<int>(m.values.size()); | |
| 59 | |||
| 60 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
|
32 | if (m.row_ptr.empty()) { |
| 61 | return false; | ||
| 62 | } | ||
| 63 | |||
| 64 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 32 times.
|
32 | if (m.row_ptr.front() != 0 || m.row_ptr.back() != nnz) { |
| 65 | return false; | ||
| 66 | } | ||
| 67 | |||
| 68 |
2/2✓ Branch 0 taken 504 times.
✓ Branch 1 taken 32 times.
|
536 | for (int i = 0; i < m.rows; ++i) { |
| 69 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 504 times.
|
504 | if (m.row_ptr[i] > m.row_ptr[i + 1]) { |
| 70 | return false; | ||
| 71 | } | ||
| 72 | } | ||
| 73 | |||
| 74 |
2/2✓ Branch 0 taken 5692 times.
✓ Branch 1 taken 32 times.
|
5724 | for (int k = 0; k < nnz; ++k) { |
| 75 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 5692 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 5692 times.
|
5692 | if (m.col_index[k] < 0 || m.col_index[k] >= m.cols) { |
| 76 | return false; | ||
| 77 | } | ||
| 78 | } | ||
| 79 | |||
| 80 | return true; | ||
| 81 | } | ||
| 82 | |||
| 83 | 16 | bool ErmakovASparMatMultTBB::ValidationImpl() { | |
| 84 | 16 | const auto &a = GetInput().A; | |
| 85 | 16 | const auto &b = GetInput().B; | |
| 86 | |||
| 87 |
1/2✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
|
16 | if (a.cols != b.rows) { |
| 88 | return false; | ||
| 89 | } | ||
| 90 | |||
| 91 |
1/2✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
|
16 | if (!ValidateMatrix(a)) { |
| 92 | return false; | ||
| 93 | } | ||
| 94 | |||
| 95 | 16 | if (!ValidateMatrix(b)) { | |
| 96 | return false; | ||
| 97 | } | ||
| 98 | |||
| 99 | return true; | ||
| 100 | } | ||
| 101 | |||
| 102 | 16 | bool ErmakovASparMatMultTBB::PreProcessingImpl() { | |
| 103 | 16 | a_ = GetInput().A; | |
| 104 | 16 | b_ = GetInput().B; | |
| 105 | |||
| 106 | 16 | c_.rows = a_.rows; | |
| 107 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
|
16 | c_.cols = b_.cols; |
| 108 | c_.values.clear(); | ||
| 109 | c_.col_index.clear(); | ||
| 110 | 16 | c_.row_ptr.assign(static_cast<std::size_t>(c_.rows) + 1, 0); | |
| 111 | |||
| 112 | 16 | return true; | |
| 113 | } | ||
| 114 | |||
| 115 |
2/2✓ Branch 0 taken 198 times.
✓ Branch 1 taken 54 times.
|
252 | void ErmakovASparMatMultTBB::AccumulateRowProducts(int row_index, std::vector<std::complex<double>> &row_vals, |
| 116 | std::vector<int> &row_mark, std::vector<int> &used_cols) const { | ||
| 117 | used_cols.clear(); | ||
| 118 | |||
| 119 | 252 | const int a_start = a_.row_ptr[row_index]; | |
| 120 | 252 | const int a_end = a_.row_ptr[row_index + 1]; | |
| 121 | |||
| 122 |
2/2✓ Branch 0 taken 2816 times.
✓ Branch 1 taken 252 times.
|
3068 | for (int ak = a_start; ak < a_end; ++ak) { |
| 123 | 2816 | const int j = a_.col_index[ak]; | |
| 124 | 2816 | const auto a_ij = a_.values[ak]; | |
| 125 | |||
| 126 | 2816 | const int b_start = b_.row_ptr[j]; | |
| 127 | 2816 | const int b_end = b_.row_ptr[j + 1]; | |
| 128 | |||
| 129 |
2/2✓ Branch 0 taken 54156 times.
✓ Branch 1 taken 2816 times.
|
56972 | for (int bk = b_start; bk < b_end; ++bk) { |
| 130 |
2/2✓ Branch 0 taken 4320 times.
✓ Branch 1 taken 49836 times.
|
54156 | const int k = b_.col_index[bk]; |
| 131 | 54156 | const auto b_jk = b_.values[bk]; | |
| 132 | |||
| 133 |
2/2✓ Branch 0 taken 4320 times.
✓ Branch 1 taken 49836 times.
|
54156 | if (row_mark[k] != row_index) { |
| 134 |
1/2✓ Branch 0 taken 4320 times.
✗ Branch 1 not taken.
|
4320 | row_mark[k] = row_index; |
| 135 |
1/2✓ Branch 0 taken 4320 times.
✗ Branch 1 not taken.
|
4320 | row_vals[k] = a_ij * b_jk; |
| 136 | used_cols.push_back(k); | ||
| 137 | } else { | ||
| 138 | row_vals[k] += a_ij * b_jk; | ||
| 139 | } | ||
| 140 | } | ||
| 141 | } | ||
| 142 | 252 | } | |
| 143 | |||
| 144 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 252 times.
|
252 | void ErmakovASparMatMultTBB::CollectRowValues(const std::vector<std::complex<double>> &row_vals, |
| 145 | const std::vector<int> &used_cols, std::vector<int> &cols, | ||
| 146 | std::vector<std::complex<double>> &vals) { | ||
| 147 | cols.clear(); | ||
| 148 | vals.clear(); | ||
| 149 | |||
| 150 | 252 | cols.reserve(used_cols.size()); | |
| 151 | 252 | vals.reserve(used_cols.size()); | |
| 152 | |||
| 153 |
2/2✓ Branch 0 taken 4320 times.
✓ Branch 1 taken 252 times.
|
4572 | for (int col : used_cols) { |
| 154 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4320 times.
|
4320 | const auto &v = row_vals[static_cast<std::size_t>(col)]; |
| 155 | if (v != std::complex<double>(0.0, 0.0)) { | ||
| 156 | cols.push_back(col); | ||
| 157 | vals.push_back(v); | ||
| 158 | } | ||
| 159 | } | ||
| 160 | 252 | } | |
| 161 | |||
| 162 | ✗ | void ErmakovASparMatMultTBB::SortUsedCols(std::vector<int> &cols) { | |
| 163 | std::ranges::sort(cols); | ||
| 164 | ✗ | } | |
| 165 | |||
| 166 | 16 | bool ErmakovASparMatMultTBB::RunImpl() { | |
| 167 | 16 | const int m = a_.rows; | |
| 168 | 16 | const int p = b_.cols; | |
| 169 | |||
| 170 |
1/2✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
|
16 | if (a_.cols != b_.rows) { |
| 171 | return false; | ||
| 172 | } | ||
| 173 | |||
| 174 | c_.values.clear(); | ||
| 175 | c_.col_index.clear(); | ||
| 176 | std::ranges::fill(c_.row_ptr, 0); | ||
| 177 | |||
| 178 |
2/4✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 16 times.
|
16 | if (m == 0 || p == 0) { |
| 179 | return true; | ||
| 180 | } | ||
| 181 | |||
| 182 | 16 | std::vector<std::vector<std::complex<double>>> row_values(static_cast<std::size_t>(m)); | |
| 183 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | std::vector<std::vector<int>> row_cols(static_cast<std::size_t>(m)); |
| 184 |
1/2✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
|
44 | tbb::enumerable_thread_specific<RowWorkspace> workspace([&] { return RowWorkspace(p); }); |
| 185 | const int grain_size = ResolveGrainSize(m); | ||
| 186 | |||
| 187 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | tbb::parallel_for(tbb::blocked_range<int>(0, m, grain_size), [&](const tbb::blocked_range<int> &range) { |
| 188 | 252 | auto &local = workspace.local(); | |
| 189 | |||
| 190 |
2/2✓ Branch 0 taken 252 times.
✓ Branch 1 taken 252 times.
|
504 | for (int i = range.begin(); i != range.end(); ++i) { |
| 191 | 252 | AccumulateRowProducts(i, local.row_vals, local.row_mark, local.used_cols); | |
| 192 | SortUsedCols(local.used_cols); | ||
| 193 | |||
| 194 | 252 | const auto row_i = static_cast<std::size_t>(i); | |
| 195 | 252 | CollectRowValues(local.row_vals, local.used_cols, row_cols[row_i], row_values[row_i]); | |
| 196 | } | ||
| 197 | 252 | }); | |
| 198 | |||
| 199 | int nnz = 0; | ||
| 200 |
2/2✓ Branch 0 taken 252 times.
✓ Branch 1 taken 16 times.
|
268 | for (int i = 0; i < m; ++i) { |
| 201 | 252 | const auto row_i = static_cast<std::size_t>(i); | |
| 202 | 252 | c_.row_ptr[row_i] = nnz; | |
| 203 | 252 | nnz += static_cast<int>(row_values[row_i].size()); | |
| 204 | } | ||
| 205 | |||
| 206 | 16 | c_.row_ptr[static_cast<std::size_t>(m)] = nnz; | |
| 207 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | c_.values.reserve(static_cast<std::size_t>(nnz)); |
| 208 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | c_.col_index.reserve(static_cast<std::size_t>(nnz)); |
| 209 | |||
| 210 |
2/2✓ Branch 0 taken 252 times.
✓ Branch 1 taken 16 times.
|
268 | for (int i = 0; i < m; ++i) { |
| 211 |
1/2✓ Branch 1 taken 252 times.
✗ Branch 2 not taken.
|
252 | const auto row_i = static_cast<std::size_t>(i); |
| 212 |
1/2✓ Branch 1 taken 252 times.
✗ Branch 2 not taken.
|
252 | c_.values.insert(c_.values.end(), row_values[row_i].begin(), row_values[row_i].end()); |
| 213 | 252 | c_.col_index.insert(c_.col_index.end(), row_cols[row_i].begin(), row_cols[row_i].end()); | |
| 214 | } | ||
| 215 | |||
| 216 | return true; | ||
| 217 | 16 | } | |
| 218 | |||
| 219 | 16 | bool ErmakovASparMatMultTBB::PostProcessingImpl() { | |
| 220 | 16 | GetOutput() = c_; | |
| 221 | 16 | return true; | |
| 222 | } | ||
| 223 | |||
| 224 | } // namespace ermakov_a_spar_mat_mult | ||
| 225 |