GCC Code Coverage Report


Directory: ./
File: tasks/tabalaev_a_matrix_mul_strassen/stl/src/ops_stl.cpp
Date: 2026-05-11 08:26:31
Exec Total Coverage
Lines: 148 161 91.9%
Functions: 25 25 100.0%
Branches: 68 156 43.6%

Line Branch Exec Source
1 #include "tabalaev_a_matrix_mul_strassen/stl/include/ops_stl.hpp"
2
3 #include <algorithm>
4 #include <cstddef>
5 #include <functional>
6 #include <stack>
7 #include <thread>
8 #include <utility>
9 #include <vector>
10
11 #include "tabalaev_a_matrix_mul_strassen/common/include/common.hpp"
12 #include "util/include/util.hpp"
13
14 namespace tabalaev_a_matrix_mul_strassen {
15
16 static constexpr std::size_t kBaseCaseSize = 128;
17 static constexpr std::size_t kParallelThreshold = 65536;
18
19 namespace {
20 template <typename fnc>
21 704 void RunParallel(std::size_t begin, std::size_t end, std::size_t threshold, const fnc &func) {
22 704 std::size_t total = end - begin;
23
24
1/2
✓ Branch 0 taken 352 times.
✗ Branch 1 not taken.
704 if (total < threshold) {
25
2/2
✓ Branch 0 taken 296040 times.
✓ Branch 1 taken 352 times.
592784 for (std::size_t i = begin; i < end; ++i) {
26 592080 func(i);
27 }
28 704 return;
29 }
30
31 auto num_threads = static_cast<std::size_t>(ppc::util::GetNumThreads());
32
33 num_threads = std::min(num_threads, total);
34
35 std::vector<std::thread> threads;
36 std::size_t chunk_size = total / num_threads;
37
38 for (std::size_t i = 0; i < num_threads; ++i) {
39 std::size_t current_start = begin + (i * chunk_size);
40
41 std::size_t current_end = (i == num_threads - 1) ? end : current_start + chunk_size;
42
43 threads.emplace_back([current_start, current_end, &func]() {
44 for (std::size_t j = current_start; j < current_end; ++j) {
45 func(j);
46 }
47 });
48 }
49
50 for (auto &t : threads) {
51 t.join();
52 }
53 }
54
55 } // namespace
56
57
1/2
✓ Branch 1 taken 48 times.
✗ Branch 2 not taken.
48 TabalaevAMatrixMulStrassenSTL::TabalaevAMatrixMulStrassenSTL(const InType &in) {
58 SetTypeOfTask(GetStaticTypeOfTask());
59
1/2
✓ Branch 1 taken 48 times.
✗ Branch 2 not taken.
48 GetInput() = in;
60 GetOutput() = {};
61 48 }
62
63 48 bool TabalaevAMatrixMulStrassenSTL::ValidationImpl() {
64 const auto &in = GetInput();
65
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 48 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 48 times.
48 return in.a_rows > 0 && in.a_cols_b_rows > 0 && in.b_cols > 0 &&
66
2/4
✓ Branch 0 taken 48 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 48 times.
96 in.a.size() == static_cast<size_t>(in.a_rows * in.a_cols_b_rows) &&
67
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 48 times.
48 in.b.size() == static_cast<size_t>(in.a_cols_b_rows * in.b_cols);
68 }
69
70 48 bool TabalaevAMatrixMulStrassenSTL::PreProcessingImpl() {
71 GetOutput() = {};
72 const auto &in = GetInput();
73
74 48 a_rows_ = in.a_rows;
75 48 a_cols_b_rows_ = in.a_cols_b_rows;
76 48 b_cols_ = in.b_cols;
77
78 48 std::size_t max_dim = std::max({a_rows_, a_cols_b_rows_, b_cols_});
79 48 padded_n_ = 1;
80
2/2
✓ Branch 0 taken 224 times.
✓ Branch 1 taken 48 times.
272 while (padded_n_ < max_dim) {
81 224 padded_n_ *= 2;
82 }
83
84 48 padded_a_.assign(padded_n_ * padded_n_, 0.0);
85 48 padded_b_.assign(padded_n_ * padded_n_, 0.0);
86
87 48 RunParallel(0, a_rows_, kParallelThreshold, [&](std::size_t i) {
88 4392 auto src_start = in.a.begin() + static_cast<ptrdiff_t>(i * a_cols_b_rows_);
89 4392 auto dst_start = padded_a_.begin() + static_cast<ptrdiff_t>(i * padded_n_);
90 4392 std::copy(src_start, src_start + static_cast<ptrdiff_t>(a_cols_b_rows_), dst_start);
91 4392 });
92
93 48 RunParallel(0, a_cols_b_rows_, kParallelThreshold, [&](std::size_t i) {
94 4312 auto src_start = in.b.begin() + static_cast<ptrdiff_t>(i * b_cols_);
95 4312 auto dst_start = padded_b_.begin() + static_cast<ptrdiff_t>(i * padded_n_);
96 4312 std::copy(src_start, src_start + static_cast<ptrdiff_t>(b_cols_), dst_start);
97 4312 });
98
99 48 return true;
100 }
101
102 48 bool TabalaevAMatrixMulStrassenSTL::RunImpl() {
103 96 result_c_ = StrassenMultiply(padded_a_, padded_b_, padded_n_);
104
105 auto &out = GetOutput();
106 48 out.assign(a_rows_ * b_cols_, 0.0);
107
108 48 RunParallel(0, a_rows_, kParallelThreshold, [&](std::size_t i) {
109 4392 auto src_start = result_c_.begin() + static_cast<ptrdiff_t>(i * padded_n_);
110 4392 auto dst_start = out.begin() + static_cast<ptrdiff_t>(i * b_cols_);
111 4392 std::copy(src_start, src_start + static_cast<ptrdiff_t>(b_cols_), dst_start);
112 4392 });
113
114 48 return true;
115 }
116
117 48 bool TabalaevAMatrixMulStrassenSTL::PostProcessingImpl() {
118 48 return true;
119 }
120
121 96 std::vector<double> TabalaevAMatrixMulStrassenSTL::Add(const std::vector<double> &mat_a,
122 const std::vector<double> &mat_b) {
123 96 std::vector<double> res(mat_a.size());
124 std::ranges::transform(mat_a, mat_b, res.begin(), std::plus<>());
125 96 return res;
126 }
127
128 64 std::vector<double> TabalaevAMatrixMulStrassenSTL::Subtract(const std::vector<double> &mat_a,
129 const std::vector<double> &mat_b) {
130 64 std::vector<double> res(mat_a.size());
131 std::ranges::transform(mat_a, mat_b, res.begin(), std::minus<>());
132 64 return res;
133 }
134
135 144 std::vector<double> TabalaevAMatrixMulStrassenSTL::BaseMultiply(const std::vector<double> &mat_a,
136 const std::vector<double> &mat_b, std::size_t n) {
137
1/2
✓ Branch 2 taken 144 times.
✗ Branch 3 not taken.
144 std::vector<double> res(n * n, 0.0);
138
139 144 const double *a_ptr = mat_a.data();
140
1/2
✓ Branch 1 taken 144 times.
✗ Branch 2 not taken.
144 const double *b_ptr = mat_b.data();
141 144 double *res_ptr = res.data();
142
143
1/4
✓ Branch 1 taken 144 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
144 RunParallel(0, n, kParallelThreshold, [&](std::size_t i) {
144 14656 std::size_t i_n = i * n;
145
2/2
✓ Branch 0 taken 1839360 times.
✓ Branch 1 taken 14656 times.
1854016 for (std::size_t k = 0; k < n; ++k) {
146 1839360 double temp = a_ptr[i_n + k];
147
2/2
✓ Branch 0 taken 7384 times.
✓ Branch 1 taken 1831976 times.
1839360 if (temp == 0.0) {
148 7384 continue;
149 }
150 1831976 std::size_t k_n = k * n;
151
2/2
✓ Branch 0 taken 234176224 times.
✓ Branch 1 taken 1831976 times.
236008200 for (std::size_t j = 0; j < n; ++j) {
152 234176224 res_ptr[i_n + j] += temp * b_ptr[k_n + j];
153 }
154 }
155 14656 });
156
157 144 return res;
158 }
159
160 32 void TabalaevAMatrixMulStrassenSTL::SplitMatrix(const std::vector<double> &src, std::size_t n, std::vector<double> &c11,
161 std::vector<double> &c12, std::vector<double> &c21,
162 std::vector<double> &c22) {
163 32 std::size_t h = n / 2;
164 32 std::size_t sz = h * h;
165 32 c11.resize(sz);
166 32 c12.resize(sz);
167 32 c21.resize(sz);
168 32 c22.resize(sz);
169
170 32 RunParallel(0, h, kParallelThreshold, [&](std::size_t i) {
171 4096 auto src_row1 = src.begin() + static_cast<ptrdiff_t>(i * n);
172 4096 auto src_row2 = src.begin() + static_cast<ptrdiff_t>((i + h) * n);
173 4096 auto dst_row = static_cast<ptrdiff_t>(i * h);
174
175 4096 std::copy(src_row1, src_row1 + static_cast<ptrdiff_t>(h), c11.begin() + dst_row);
176 4096 std::copy(src_row1 + static_cast<ptrdiff_t>(h), src_row1 + static_cast<ptrdiff_t>(n), c12.begin() + dst_row);
177
178 4096 std::copy(src_row2, src_row2 + static_cast<ptrdiff_t>(h), c21.begin() + dst_row);
179 4096 std::copy(src_row2 + static_cast<ptrdiff_t>(h), src_row2 + static_cast<ptrdiff_t>(n), c22.begin() + dst_row);
180 4096 });
181 32 }
182
183 16 std::vector<double> TabalaevAMatrixMulStrassenSTL::CombineMatrix(const std::vector<double> &c11,
184 const std::vector<double> &c12,
185 const std::vector<double> &c21,
186 const std::vector<double> &c22, std::size_t n) {
187 16 std::size_t h = n / 2;
188 16 std::vector<double> res(n * n);
189
190
1/4
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
16 RunParallel(0, h, kParallelThreshold, [&](std::size_t i) {
191 2048 auto res_row1 = res.begin() + static_cast<ptrdiff_t>(i * n);
192 2048 auto res_row2 = res.begin() + static_cast<ptrdiff_t>((i + h) * n);
193 2048 auto src_row = static_cast<ptrdiff_t>(i * h);
194
195 2048 std::copy(c11.begin() + src_row, c11.begin() + src_row + static_cast<ptrdiff_t>(h), res_row1);
196 2048 std::copy(c12.begin() + src_row, c12.begin() + src_row + static_cast<ptrdiff_t>(h),
197 2048 res_row1 + static_cast<ptrdiff_t>(h));
198
199 2048 std::copy(c21.begin() + src_row, c21.begin() + src_row + static_cast<ptrdiff_t>(h), res_row2);
200 2048 std::copy(c22.begin() + src_row, c22.begin() + src_row + static_cast<ptrdiff_t>(h),
201 2048 res_row2 + static_cast<ptrdiff_t>(h));
202 2048 });
203 16 return res;
204 }
205
206 48 std::vector<double> TabalaevAMatrixMulStrassenSTL::StrassenMultiply(const std::vector<double> &mat_a,
207 const std::vector<double> &mat_b, std::size_t n) {
208 std::stack<StrassenFrameSTL> frames;
209 std::stack<std::vector<double>> results;
210
211 48 frames.push({mat_a, mat_b, n, 0});
212
213
2/2
✓ Branch 0 taken 176 times.
✓ Branch 1 taken 48 times.
224 while (!frames.empty()) {
214 StrassenFrameSTL current = std::move(frames.top());
215 frames.pop();
216
217
2/2
✓ Branch 0 taken 144 times.
✓ Branch 1 taken 32 times.
176 if (current.n <= kBaseCaseSize) {
218
1/2
✓ Branch 1 taken 144 times.
✗ Branch 2 not taken.
144 results.push(BaseMultiply(current.mat_a, current.mat_b, current.n));
219 continue;
220 }
221
222
2/2
✓ Branch 0 taken 16 times.
✓ Branch 1 taken 16 times.
32 if (current.stage == 8) {
223
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
16 std::vector<std::vector<double>> p(7);
224
225
2/2
✓ Branch 0 taken 112 times.
✓ Branch 1 taken 16 times.
128 for (int i = 6; i >= 0; --i) {
226 112 p[i] = std::move(results.top());
227 results.pop();
228 }
229
230 16 std::size_t h = current.n / 2;
231 16 std::size_t sz = h * h;
232
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
16 std::vector<double> c11(sz);
233
1/4
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
16 std::vector<double> c12(sz);
234
1/4
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
16 std::vector<double> c21(sz);
235
2/6
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
16 std::vector<double> c22(sz);
236
237 16 double *c11_ptr = c11.data();
238 16 double *c12_ptr = c12.data();
239 16 double *c21_ptr = c21.data();
240 16 double *c22_ptr = c22.data();
241
242
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
16 RunParallel(0, sz, kParallelThreshold, [&](std::size_t i) {
243 262144 c11_ptr[i] = p[0][i] + p[3][i] - p[4][i] + p[6][i];
244 262144 c12_ptr[i] = p[2][i] + p[4][i];
245 262144 c21_ptr[i] = p[1][i] + p[3][i];
246 262144 c22_ptr[i] = p[0][i] - p[1][i] + p[2][i] + p[5][i];
247 262144 });
248
249
2/6
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 16 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
32 results.push(CombineMatrix(c11, c12, c21, c22, current.n));
250 16 } else {
251 16 std::size_t h = current.n / 2;
252 16 std::vector<double> a11;
253 16 std::vector<double> a12;
254 16 std::vector<double> a21;
255 16 std::vector<double> a22;
256 16 std::vector<double> b11;
257 16 std::vector<double> b12;
258 16 std::vector<double> b21;
259 16 std::vector<double> b22;
260
261
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
16 SplitMatrix(current.mat_a, current.n, a11, a12, a21, a22);
262
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
16 SplitMatrix(current.mat_b, current.n, b11, b12, b21, b22);
263
264
1/2
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
32 frames.push({{}, {}, current.n, 8});
265
266 16 frames.push({Subtract(a12, a22), Add(b21, b22), h, 0});
267 16 frames.push({Subtract(a21, a11), Add(b11, b12), h, 0});
268 16 frames.push({Add(a11, a12), b22, h, 0});
269 16 frames.push({a22, Subtract(b21, b11), h, 0});
270 16 frames.push({a11, Subtract(b12, b22), h, 0});
271 16 frames.push({Add(a21, a22), b11, h, 0});
272
1/4
✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
16 frames.push({Add(a11, a22), Add(b11, b22), h, 0});
273 }
274 176 }
275
276 48 return std::move(results.top());
277
24/48
✓ Branch 1 taken 48 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 48 times.
✗ Branch 5 not taken.
✓ Branch 7 taken 48 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 16 times.
✗ Branch 11 not taken.
✓ Branch 13 taken 16 times.
✗ Branch 14 not taken.
✓ Branch 16 taken 16 times.
✗ Branch 17 not taken.
✓ Branch 19 taken 16 times.
✗ Branch 20 not taken.
✓ Branch 22 taken 16 times.
✗ Branch 23 not taken.
✓ Branch 25 taken 16 times.
✗ Branch 26 not taken.
✓ Branch 28 taken 16 times.
✗ Branch 29 not taken.
✓ Branch 31 taken 16 times.
✗ Branch 32 not taken.
✓ Branch 34 taken 16 times.
✗ Branch 35 not taken.
✓ Branch 37 taken 16 times.
✗ Branch 38 not taken.
✓ Branch 40 taken 16 times.
✗ Branch 41 not taken.
✓ Branch 43 taken 16 times.
✗ Branch 44 not taken.
✓ Branch 46 taken 16 times.
✗ Branch 47 not taken.
✓ Branch 49 taken 16 times.
✗ Branch 50 not taken.
✓ Branch 52 taken 16 times.
✗ Branch 53 not taken.
✓ Branch 55 taken 16 times.
✗ Branch 56 not taken.
✓ Branch 58 taken 16 times.
✗ Branch 59 not taken.
✓ Branch 61 taken 16 times.
✗ Branch 62 not taken.
✓ Branch 64 taken 16 times.
✗ Branch 65 not taken.
✓ Branch 67 taken 16 times.
✗ Branch 68 not taken.
✓ Branch 70 taken 16 times.
✗ Branch 71 not taken.
160 }
278
279 } // namespace tabalaev_a_matrix_mul_strassen
280