GCC Code Coverage Report


Directory: ./
File: tasks/tabalaev_a_matrix_mul_strassen/tbb/src/ops_tbb.cpp
Date: 2026-05-11 08:26:31
Exec Total Coverage
Lines: 152 171 88.9%
Functions: 16 18 88.9%
Branches: 97 178 54.5%

Line Branch Exec Source
1 #include "tabalaev_a_matrix_mul_strassen/tbb/include/ops_tbb.hpp"
2
3 #include <tbb/tbb.h>
4
5 #include <algorithm>
6 #include <cmath>
7 #include <cstddef>
8 #include <utility>
9 #include <vector>
10
11 #include "tabalaev_a_matrix_mul_strassen/common/include/common.hpp"
12
13 namespace tabalaev_a_matrix_mul_strassen {
14
namespace {

// Sub-problems whose order is at most this value are multiplied directly by
// BaseMultiply instead of being split further by the Strassen recursion.
constexpr size_t kBaseCaseSize = 128;

// Element count (512 * 512) at or above which Add, Subtract, SplitMatrix and
// CombineMatrix switch from a serial loop to tbb::parallel_for.
constexpr size_t kParallelThreshold = 512 * 512;

}  // namespace
17
18
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 TabalaevAMatrixMulStrassenTBB::TabalaevAMatrixMulStrassenTBB(const InType &in) {
19 SetTypeOfTask(GetStaticTypeOfTask());
20
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 GetInput() = in;
21 GetOutput() = {};
22 24 }
23
24 24 bool TabalaevAMatrixMulStrassenTBB::ValidationImpl() {
25 const auto &in = GetInput();
26
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 24 times.
24 return (in.a_rows > 0) && (in.a_cols_b_rows > 0) && (in.b_cols > 0) &&
27
2/4
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 24 times.
48 (in.a.size() == static_cast<size_t>(in.a_rows * in.a_cols_b_rows)) &&
28
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
24 (in.b.size() == static_cast<size_t>(in.a_cols_b_rows * in.b_cols));
29 }
30
31 24 bool TabalaevAMatrixMulStrassenTBB::PreProcessingImpl() {
32 const auto &in = GetInput();
33
34 24 a_rows_ = in.a_rows;
35 24 a_cols_b_rows_ = in.a_cols_b_rows;
36 24 b_cols_ = in.b_cols;
37
38 24 size_t max_dim = std::max({a_rows_, a_cols_b_rows_, b_cols_});
39
40 24 padded_n_ = 1;
41
2/2
✓ Branch 0 taken 112 times.
✓ Branch 1 taken 24 times.
136 while (padded_n_ < max_dim) {
42 112 padded_n_ *= 2;
43 }
44
45 24 padded_a_.assign(padded_n_ * padded_n_, 0.0);
46 24 padded_b_.assign(padded_n_ * padded_n_, 0.0);
47
48 24 const size_t n = padded_n_;
49
50 24 tbb::parallel_for(static_cast<size_t>(0), a_rows_, [this, n, &in](size_t i) {
51
2/2
✓ Branch 0 taken 523668 times.
✓ Branch 1 taken 2196 times.
525864 for (size_t j = 0; j < a_cols_b_rows_; ++j) {
52 523668 padded_a_[(i * n) + j] = in.a[(i * a_cols_b_rows_) + j];
53 }
54 2196 });
55
56 24 tbb::parallel_for(static_cast<size_t>(0), a_cols_b_rows_, [this, n, &in](size_t i) {
57
2/2
✓ Branch 0 taken 523668 times.
✓ Branch 1 taken 2156 times.
525824 for (size_t j = 0; j < b_cols_; ++j) {
58 523668 padded_b_[(i * n) + j] = in.b[(i * b_cols_) + j];
59 }
60 2156 });
61
62 24 return true;
63 }
64
65 24 bool TabalaevAMatrixMulStrassenTBB::RunImpl() {
66 48 result_c_ = StrassenMultiply(padded_a_, padded_b_, padded_n_);
67
68 auto &out = GetOutput();
69 24 out.assign(a_rows_ * b_cols_, 0.0);
70
71 24 const size_t n = padded_n_;
72
73 24 tbb::parallel_for(static_cast<size_t>(0), a_rows_, [this, n, &out](size_t i) {
74
2/2
✓ Branch 0 taken 524268 times.
✓ Branch 1 taken 2196 times.
526464 for (size_t j = 0; j < b_cols_; ++j) {
75 524268 out[(i * b_cols_) + j] = result_c_[(i * n) + j];
76 }
77 2196 });
78
79 24 return true;
80 }
81
82 24 bool TabalaevAMatrixMulStrassenTBB::PostProcessingImpl() {
83 24 return true;
84 }
85
86 48 void TabalaevAMatrixMulStrassenTBB::Add(const std::vector<double> &mat_a, const std::vector<double> &mat_b,
87 std::vector<double> &res) {
88 const size_t n = mat_a.size();
89 48 res.resize(n);
90
91
1/2
✓ Branch 0 taken 48 times.
✗ Branch 1 not taken.
48 if (n >= kParallelThreshold) {
92 tbb::parallel_for(static_cast<size_t>(0), n, [&](size_t i) { res[i] = mat_a[i] + mat_b[i]; });
93 } else {
94
2/2
✓ Branch 0 taken 786432 times.
✓ Branch 1 taken 48 times.
786480 for (size_t i = 0; i < n; ++i) {
95 786432 res[i] = mat_a[i] + mat_b[i];
96 }
97 }
98 48 }
99
100 32 void TabalaevAMatrixMulStrassenTBB::Subtract(const std::vector<double> &mat_a, const std::vector<double> &mat_b,
101 std::vector<double> &res) {
102 const size_t n = mat_a.size();
103 32 res.resize(n);
104
105
1/2
✓ Branch 0 taken 32 times.
✗ Branch 1 not taken.
32 if (n >= kParallelThreshold) {
106 tbb::parallel_for(static_cast<size_t>(0), n, [&](size_t i) { res[i] = mat_a[i] - mat_b[i]; });
107 } else {
108
2/2
✓ Branch 0 taken 524288 times.
✓ Branch 1 taken 32 times.
524320 for (size_t i = 0; i < n; ++i) {
109 524288 res[i] = mat_a[i] - mat_b[i];
110 }
111 }
112 32 }
113
114 72 std::vector<double> TabalaevAMatrixMulStrassenTBB::BaseMultiply(const std::vector<double> &mat_a,
115 const std::vector<double> &mat_b, size_t n) {
116 72 std::vector<double> res(n * n, 0.0);
117
118
1/4
✓ Branch 1 taken 72 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
72 tbb::parallel_for(static_cast<size_t>(0), n, [&](size_t i) {
119
2/2
✓ Branch 0 taken 919680 times.
✓ Branch 1 taken 7328 times.
927008 for (size_t k = 0; k < n; ++k) {
120
2/2
✓ Branch 0 taken 915988 times.
✓ Branch 1 taken 3692 times.
919680 double temp = mat_a[(i * n) + k];
121
2/2
✓ Branch 0 taken 915988 times.
✓ Branch 1 taken 3692 times.
919680 if (temp == 0.0) {
122 3692 continue;
123 }
124
2/2
✓ Branch 0 taken 117088112 times.
✓ Branch 1 taken 915988 times.
118004100 for (size_t j = 0; j < n; ++j) {
125 117088112 res[(i * n) + j] += temp * mat_b[(k * n) + j];
126 }
127 }
128 7328 });
129
130 72 return res;
131 }
132
133 16 void TabalaevAMatrixMulStrassenTBB::SplitMatrix(const std::vector<double> &src, size_t n, std::vector<double> &c11,
134 std::vector<double> &c12, std::vector<double> &c21,
135 std::vector<double> &c22) {
136 16 size_t h = n / 2;
137 16 size_t sz = h * h;
138
139 16 c11.resize(sz);
140 16 c12.resize(sz);
141 16 c21.resize(sz);
142 16 c22.resize(sz);
143
144
1/2
✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
16 if (n * n >= kParallelThreshold) {
145 tbb::parallel_for(static_cast<size_t>(0), h, [&](size_t i) {
146 for (size_t j = 0; j < h; ++j) {
147 size_t src_idx = (i * n) + j;
148 size_t dst_idx = (i * h) + j;
149
150 c11[dst_idx] = src[src_idx];
151 c12[dst_idx] = src[src_idx + h];
152 c21[dst_idx] = src[src_idx + (h * n)];
153 c22[dst_idx] = src[src_idx + (h * n) + h];
154 }
155 });
156 } else {
157
2/2
✓ Branch 0 taken 2048 times.
✓ Branch 1 taken 16 times.
2064 for (size_t i = 0; i < h; ++i) {
158
2/2
✓ Branch 0 taken 262144 times.
✓ Branch 1 taken 2048 times.
264192 for (size_t j = 0; j < h; ++j) {
159 262144 size_t src_idx = (i * n) + j;
160 262144 size_t dst_idx = (i * h) + j;
161
162 262144 c11[dst_idx] = src[src_idx];
163 262144 c12[dst_idx] = src[src_idx + h];
164 262144 c21[dst_idx] = src[src_idx + (h * n)];
165 262144 c22[dst_idx] = src[src_idx + (h * n) + h];
166 }
167 }
168 }
169 16 }
170
171 8 std::vector<double> TabalaevAMatrixMulStrassenTBB::CombineMatrix(const std::vector<double> &c11,
172 const std::vector<double> &c12,
173 const std::vector<double> &c21,
174 const std::vector<double> &c22, size_t n) {
175 8 size_t h = n / 2;
176
177 8 std::vector<double> res(n * n);
178
179
1/2
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
8 if (n * n >= kParallelThreshold) {
180 tbb::parallel_for(static_cast<size_t>(0), h, [&](size_t i) {
181 for (size_t j = 0; j < h; ++j) {
182 size_t src_idx = (i * h) + j;
183
184 res[(i * n) + j] = c11[src_idx];
185 res[(i * n) + j + h] = c12[src_idx];
186 res[((i + h) * n) + j] = c21[src_idx];
187 res[((i + h) * n) + j + h] = c22[src_idx];
188 }
189 });
190 } else {
191
2/2
✓ Branch 0 taken 1024 times.
✓ Branch 1 taken 8 times.
1032 for (size_t i = 0; i < h; ++i) {
192
2/2
✓ Branch 0 taken 131072 times.
✓ Branch 1 taken 1024 times.
132096 for (size_t j = 0; j < h; ++j) {
193 131072 size_t src_idx = (i * h) + j;
194
195 131072 res[(i * n) + j] = c11[src_idx];
196 131072 res[(i * n) + j + h] = c12[src_idx];
197 131072 res[((i + h) * n) + j] = c21[src_idx];
198 131072 res[((i + h) * n) + j + h] = c22[src_idx];
199 }
200 }
201 }
202
203 8 return res;
204 }
205
206 24 std::vector<double> TabalaevAMatrixMulStrassenTBB::StrassenMultiply(const std::vector<double> &mat_a,
207 const std::vector<double> &mat_b, size_t n) {
208 24 std::vector<StrassenFrameTBB> frames;
209 24 std::vector<std::vector<double>> results;
210
211
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 frames.reserve(64);
212
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 results.reserve(64);
213
214 24 frames.push_back({mat_a, mat_b, n, 0});
215
216 24 std::vector<double> t1;
217 24 std::vector<double> t2;
218
219
2/2
✓ Branch 0 taken 88 times.
✓ Branch 1 taken 24 times.
112 while (!frames.empty()) {
220 StrassenFrameTBB current = std::move(frames.back());
221 frames.pop_back();
222
223
2/2
✓ Branch 0 taken 72 times.
✓ Branch 1 taken 16 times.
88 if (current.n <= kBaseCaseSize) {
224
1/2
✓ Branch 1 taken 72 times.
✗ Branch 2 not taken.
72 results.push_back(BaseMultiply(current.mat_a, current.mat_b, current.n));
225 continue;
226 }
227
228
2/2
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
16 if (current.stage == 8) {
229
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 std::vector<std::vector<double>> p(7);
230
231
2/2
✓ Branch 0 taken 56 times.
✓ Branch 1 taken 8 times.
64 for (int i = 6; i >= 0; --i) {
232 56 p[i] = std::move(results.back());
233 56 results.pop_back();
234 }
235
236 8 size_t h = current.n / 2;
237 8 size_t sz = h * h;
238
239
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 std::vector<double> c11(sz);
240
1/4
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
8 std::vector<double> c12(sz);
241
1/4
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
8 std::vector<double> c21(sz);
242
1/4
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
8 std::vector<double> c22(sz);
243
244
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 tbb::parallel_for(static_cast<size_t>(0), sz, [&](size_t i) {
245 131072 c11[i] = p[0][i] + p[3][i] - p[4][i] + p[6][i];
246 131072 c12[i] = p[2][i] + p[4][i];
247 131072 c21[i] = p[1][i] + p[3][i];
248 131072 c22[i] = p[0][i] - p[1][i] + p[2][i] + p[5][i];
249 131072 });
250
251
2/6
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
16 results.push_back(CombineMatrix(c11, c12, c21, c22, current.n));
252 8 } else {
253 8 size_t h = current.n / 2;
254
255 8 std::vector<double> a11;
256 8 std::vector<double> a12;
257 8 std::vector<double> a21;
258 8 std::vector<double> a22;
259 8 std::vector<double> b11;
260 8 std::vector<double> b12;
261 8 std::vector<double> b21;
262 8 std::vector<double> b22;
263
264
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 SplitMatrix(current.mat_a, current.n, a11, a12, a21, a22);
265
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 SplitMatrix(current.mat_b, current.n, b11, b12, b21, b22);
266
267
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 frames.push_back({{}, {}, current.n, 8});
268
269
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Subtract(a12, a22, t1);
270
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Add(b21, b22, t2);
271 8 frames.push_back({t1, t2, h, 0});
272
273
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Subtract(a21, a11, t1);
274
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Add(b11, b12, t2);
275 8 frames.push_back({t1, t2, h, 0});
276
277
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Add(a11, a12, t1);
278 8 frames.push_back({t1, b22, h, 0});
279
280
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Subtract(b21, b11, t2);
281 8 frames.push_back({a22, t2, h, 0});
282
283
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Subtract(b12, b22, t2);
284 8 frames.push_back({a11, t2, h, 0});
285
286
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Add(a21, a22, t1);
287 8 frames.push_back({t1, b11, h, 0});
288
289
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Add(a11, a22, t1);
290
1/2
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
8 Add(b11, b22, t2);
291
1/4
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
8 frames.push_back({t1, t2, h, 0});
292 }
293 88 }
294
295 24 return std::move(results.back());
296
24/48
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 8 times.
✗ Branch 14 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 17 not taken.
✓ Branch 19 taken 8 times.
✗ Branch 20 not taken.
✓ Branch 22 taken 8 times.
✗ Branch 23 not taken.
✓ Branch 25 taken 8 times.
✗ Branch 26 not taken.
✓ Branch 28 taken 8 times.
✗ Branch 29 not taken.
✓ Branch 31 taken 8 times.
✗ Branch 32 not taken.
✓ Branch 34 taken 8 times.
✗ Branch 35 not taken.
✓ Branch 37 taken 8 times.
✗ Branch 38 not taken.
✓ Branch 40 taken 8 times.
✗ Branch 41 not taken.
✓ Branch 43 taken 8 times.
✗ Branch 44 not taken.
✓ Branch 46 taken 8 times.
✗ Branch 47 not taken.
✓ Branch 49 taken 8 times.
✗ Branch 50 not taken.
✓ Branch 52 taken 8 times.
✗ Branch 53 not taken.
✓ Branch 55 taken 8 times.
✗ Branch 56 not taken.
✓ Branch 58 taken 8 times.
✗ Branch 59 not taken.
✓ Branch 61 taken 8 times.
✗ Branch 62 not taken.
✓ Branch 64 taken 8 times.
✗ Branch 65 not taken.
✓ Branch 67 taken 8 times.
✗ Branch 68 not taken.
✓ Branch 70 taken 8 times.
✗ Branch 71 not taken.
104 }
297
298 } // namespace tabalaev_a_matrix_mul_strassen
299