GCC Code Coverage Report


Directory: ./
File: tasks/sannikov_i_shtrassen_algorithm/mpi/src/ops_mpi.cpp
Date: 2026-01-10 02:40:41
Exec Total Coverage
Lines: 351 355 98.9%
Functions: 29 29 100.0%
Branches: 206 330 62.4%

Line Branch Exec Source
1 #include "sannikov_i_shtrassen_algorithm/mpi/include/ops_mpi.hpp"
2
3 #include <mpi.h>
4
5 #include <cstddef>
6 #include <cstdint>
7 #include <limits>
8 #include <tuple>
9 #include <utility>
10 #include <vector>
11
12 #include "sannikov_i_shtrassen_algorithm/common/include/common.hpp"
13
14 namespace sannikov_i_shtrassen_algorithm {
15
16
1/2
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
24 SannikovIShtrassenAlgorithmMPI::SannikovIShtrassenAlgorithmMPI(const InType &in) {
17 SetTypeOfTask(GetStaticTypeOfTask());
18 auto &input_buffer = GetInput();
19 InType tmp(in);
20 input_buffer.swap(tmp);
21 GetOutput().clear();
22 24 }
23
24
1/2
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
24 bool SannikovIShtrassenAlgorithmMPI::ValidationImpl() {
25 const auto &input = GetInput();
26 const auto &mat_a = std::get<0>(input);
27 const auto &mat_b = std::get<1>(input);
28
29
2/4
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 24 times.
24 if (mat_a.empty() || mat_b.empty()) {
30 return false;
31 }
32
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
24 if (mat_a.size() != mat_b.size()) {
33 return false;
34 }
35
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 24 times.
24 if (mat_a.front().empty() || mat_b.front().empty()) {
36 return false;
37 }
38
39 const auto n = mat_a.size();
40
2/2
✓ Branch 0 taken 82 times.
✓ Branch 1 taken 24 times.
106 for (const auto &row : mat_a) {
41
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 82 times.
82 if (row.size() != n) {
42 return false;
43 }
44 }
45
2/2
✓ Branch 0 taken 82 times.
✓ Branch 1 taken 24 times.
106 for (const auto &row : mat_b) {
46
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 82 times.
82 if (row.size() != n) {
47 return false;
48 }
49 }
50
51 24 return GetOutput().empty();
52 }
53
54
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
24 bool SannikovIShtrassenAlgorithmMPI::PreProcessingImpl() {
55 GetOutput().clear();
56 24 return true;
57 }
58
59 namespace {
60
61 using Flat = std::vector<double>;
62 using Matrix = std::vector<std::vector<double>>;
63
64 constexpr std::size_t kClassicThreshold = 2;
65
66 std::size_t NextPow2(std::size_t value) {
67 std::size_t pow2 = 1;
68
2/2
✓ Branch 0 taken 40 times.
✓ Branch 1 taken 24 times.
64 while (pow2 < value) {
69 40 pow2 <<= 1U;
70 }
71 return pow2;
72 }
73
74 bool SizeOkU64(std::uint64_t n64) {
75 24 if (n64 == 0U) {
76 return false;
77 }
78
1/2
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
24 if (n64 > static_cast<std::uint64_t>(std::numeric_limits<int>::max())) {
79 return false;
80 }
81 return true;
82 }
83
84 std::size_t Idx(std::size_t row, std::size_t col, std::size_t ld) {
85 1428 return (row * ld) + col;
86 }
87
88 196 Flat MultiplyClassicFlat(const Flat &a, const Flat &b, std::size_t n) {
89 196 Flat c(n * n, 0.0);
90
2/2
✓ Branch 0 taken 371 times.
✓ Branch 1 taken 196 times.
567 for (std::size_t row = 0; row < n; ++row) {
91 double *crow = &c[Idx(row, 0, n)];
92
2/2
✓ Branch 0 taken 721 times.
✓ Branch 1 taken 371 times.
1092 for (std::size_t mid = 0; mid < n; ++mid) {
93 721 const double aik = a[Idx(row, mid, n)];
94 const double *brow = &b[Idx(mid, 0, n)];
95
2/2
✓ Branch 0 taken 1421 times.
✓ Branch 1 taken 721 times.
2142 for (std::size_t col = 0; col < n; ++col) {
96 1421 crow[col] += aik * brow[col];
97 }
98 }
99 }
100 196 return c;
101 }
102
103 168 Flat ExtractBlockFlat(const Flat &src, std::size_t src_n, std::size_t row0, std::size_t col0, std::size_t blk_n) {
104 168 Flat blk(blk_n * blk_n, 0.0);
105
2/2
✓ Branch 0 taken 336 times.
✓ Branch 1 taken 168 times.
504 for (std::size_t row = 0; row < blk_n; ++row) {
106
2/2
✓ Branch 0 taken 672 times.
✓ Branch 1 taken 336 times.
1008 for (std::size_t col = 0; col < blk_n; ++col) {
107 672 blk[Idx(row, col, blk_n)] = src[Idx(row + row0, col + col0, src_n)];
108 }
109 }
110 168 return blk;
111 }
112
113 void PlaceBlockFlat(const Flat &blk, std::size_t blk_n, Flat *dst, std::size_t dst_n, std::size_t row0,
114 std::size_t col0) {
115
8/8
✓ Branch 0 taken 42 times.
✓ Branch 1 taken 21 times.
✓ Branch 2 taken 42 times.
✓ Branch 3 taken 21 times.
✓ Branch 4 taken 42 times.
✓ Branch 5 taken 21 times.
✓ Branch 6 taken 42 times.
✓ Branch 7 taken 21 times.
252 for (std::size_t row = 0; row < blk_n; ++row) {
116
8/8
✓ Branch 0 taken 84 times.
✓ Branch 1 taken 42 times.
✓ Branch 2 taken 84 times.
✓ Branch 3 taken 42 times.
✓ Branch 4 taken 84 times.
✓ Branch 5 taken 42 times.
✓ Branch 6 taken 84 times.
✓ Branch 7 taken 42 times.
504 for (std::size_t col = 0; col < blk_n; ++col) {
117 336 (*dst)[Idx(row + row0, col + col0, dst_n)] = blk[Idx(row, col, blk_n)];
118 }
119 }
120 }
121
122 126 void AddFlat(const Flat &a, const Flat &b, Flat *out) {
123 const std::size_t n = a.size();
124 126 out->assign(n, 0.0);
125
2/2
✓ Branch 0 taken 504 times.
✓ Branch 1 taken 126 times.
630 for (std::size_t i = 0; i < n; ++i) {
126 504 (*out)[i] = a[i] + b[i];
127 }
128 126 }
129
130 84 void SubFlat(const Flat &a, const Flat &b, Flat *out) {
131 const std::size_t n = a.size();
132 84 out->assign(n, 0.0);
133
2/2
✓ Branch 0 taken 336 times.
✓ Branch 1 taken 84 times.
420 for (std::size_t i = 0; i < n; ++i) {
134 336 (*out)[i] = a[i] - b[i];
135 }
136 84 }
137
138 struct Frame {
139 Flat a;
140 Flat b;
141 std::size_t n = 0;
142
143 bool has_parent = false;
144 std::size_t parent_idx = 0;
145 int parent_slot = 0;
146
147 int stage = 0;
148 Flat res;
149
150 Flat a11, a12, a21, a22;
151 Flat b11, b12, b21, b22;
152
153 Flat m1, m2, m3, m4, m5, m6, m7;
154 };
155
156 147 void AssignToParent(std::vector<Frame> *stack, const Frame &child) {
157
1/2
✓ Branch 0 taken 147 times.
✗ Branch 1 not taken.
147 if (!child.has_parent) {
158 return;
159 }
160
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 126 times.
147 Frame &parent = (*stack)[child.parent_idx];
161 147 const int slot = child.parent_slot;
162
163
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 126 times.
147 if (slot == 1) {
164 21 parent.m1 = child.res;
165 }
166
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 126 times.
147 if (slot == 2) {
167 21 parent.m2 = child.res;
168 }
169
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 126 times.
147 if (slot == 3) {
170 21 parent.m3 = child.res;
171 }
172
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 126 times.
147 if (slot == 4) {
173 21 parent.m4 = child.res;
174 }
175
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 126 times.
147 if (slot == 5) {
176 21 parent.m5 = child.res;
177 }
178
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 126 times.
147 if (slot == 6) {
179 21 parent.m6 = child.res;
180 }
181
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 126 times.
147 if (slot == 7) {
182 21 parent.m7 = child.res;
183 }
184 }
185
186 bool IsLeaf(const Frame &frame) {
187 364 return frame.n <= kClassicThreshold;
188 }
189
190 168 void SplitIfNeeded(Frame *frame) {
191
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 147 times.
168 if (frame->stage != 0) {
192 return;
193 }
194
195 21 const std::size_t half = frame->n / 2;
196
197 21 frame->a11 = ExtractBlockFlat(frame->a, frame->n, 0, 0, half);
198 21 frame->a12 = ExtractBlockFlat(frame->a, frame->n, 0, half, half);
199 21 frame->a21 = ExtractBlockFlat(frame->a, frame->n, half, 0, half);
200 21 frame->a22 = ExtractBlockFlat(frame->a, frame->n, half, half, half);
201
202 21 frame->b11 = ExtractBlockFlat(frame->b, frame->n, 0, 0, half);
203 21 frame->b12 = ExtractBlockFlat(frame->b, frame->n, 0, half, half);
204 21 frame->b21 = ExtractBlockFlat(frame->b, frame->n, half, 0, half);
205 21 frame->b22 = ExtractBlockFlat(frame->b, frame->n, half, half, half);
206
207 21 frame->stage = 1;
208 }
209
210
7/7
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 21 times.
✓ Branch 2 taken 21 times.
✓ Branch 3 taken 21 times.
✓ Branch 4 taken 21 times.
✓ Branch 5 taken 21 times.
✓ Branch 6 taken 21 times.
147 void BuildChildOperands(const Frame &parent, int slot, Flat *left, Flat *right) {
211 if (slot == 1) {
212 21 AddFlat(parent.a11, parent.a22, left);
213 21 AddFlat(parent.b11, parent.b22, right);
214 21 return;
215 }
216 if (slot == 2) {
217 21 AddFlat(parent.a21, parent.a22, left);
218 21 *right = parent.b11;
219 21 return;
220 }
221 if (slot == 3) {
222 21 *left = parent.a11;
223 21 SubFlat(parent.b12, parent.b22, right);
224 21 return;
225 }
226 if (slot == 4) {
227 21 *left = parent.a22;
228 21 SubFlat(parent.b21, parent.b11, right);
229 21 return;
230 }
231 if (slot == 5) {
232 21 AddFlat(parent.a11, parent.a12, left);
233 21 *right = parent.b22;
234 21 return;
235 }
236 if (slot == 6) {
237 21 SubFlat(parent.a21, parent.a11, left);
238 21 AddFlat(parent.b11, parent.b12, right);
239 21 return;
240 }
241 21 SubFlat(parent.a12, parent.a22, left);
242 21 AddFlat(parent.b21, parent.b22, right);
243 }
244
245 147 Frame MakeChild(const Frame &parent, std::size_t parent_idx, int slot) {
246 147 Flat left;
247 147 Flat right;
248
1/2
✓ Branch 1 taken 147 times.
✗ Branch 2 not taken.
147 BuildChildOperands(parent, slot, &left, &right);
249
250 147 Frame child;
251 147 child.a = std::move(left);
252 147 child.b = std::move(right);
253 147 child.n = parent.n / 2;
254
255 147 child.has_parent = true;
256 147 child.parent_idx = parent_idx;
257 147 child.parent_slot = slot;
258
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 147 times.
147 child.stage = 0;
259 147 return child;
260 }
261
262 21 void CombineOnFrame(Frame *frame) {
263 21 const std::size_t half = frame->n / 2;
264 21 const std::size_t kk = half * half;
265
266 21 Flat c11(kk, 0.0);
267
1/4
✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
21 Flat c12(kk, 0.0);
268
1/4
✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
21 Flat c21(kk, 0.0);
269
1/4
✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
21 Flat c22(kk, 0.0);
270
271
2/2
✓ Branch 0 taken 84 times.
✓ Branch 1 taken 21 times.
105 for (std::size_t i = 0; i < kk; ++i) {
272 84 c11[i] = frame->m1[i] + frame->m4[i] - frame->m5[i] + frame->m7[i];
273 }
274
2/2
✓ Branch 0 taken 84 times.
✓ Branch 1 taken 21 times.
105 for (std::size_t i = 0; i < kk; ++i) {
275 84 c12[i] = frame->m3[i] + frame->m5[i];
276 }
277
2/2
✓ Branch 0 taken 84 times.
✓ Branch 1 taken 21 times.
105 for (std::size_t i = 0; i < kk; ++i) {
278 84 c21[i] = frame->m2[i] + frame->m4[i];
279 }
280
2/2
✓ Branch 0 taken 84 times.
✓ Branch 1 taken 21 times.
105 for (std::size_t i = 0; i < kk; ++i) {
281 84 c22[i] = frame->m1[i] - frame->m2[i] + frame->m3[i] + frame->m6[i];
282 }
283
284
1/4
✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
21 frame->res.assign(frame->n * frame->n, 0.0);
285 21 PlaceBlockFlat(c11, half, &frame->res, frame->n, 0, 0);
286 PlaceBlockFlat(c12, half, &frame->res, frame->n, 0, half);
287 PlaceBlockFlat(c21, half, &frame->res, frame->n, half, 0);
288 PlaceBlockFlat(c22, half, &frame->res, frame->n, half, half);
289 21 }
290
291 70 Flat ShtrassenIterativeFlat(const Flat &a0, const Flat &b0, std::size_t n0) {
292 70 std::vector<Frame> frames;
293
1/2
✓ Branch 1 taken 70 times.
✗ Branch 2 not taken.
70 frames.reserve(64);
294
295 70 Frame root;
296
1/2
✓ Branch 1 taken 70 times.
✗ Branch 2 not taken.
70 root.a = a0;
297
1/2
✓ Branch 1 taken 70 times.
✗ Branch 2 not taken.
70 root.b = b0;
298 70 root.n = n0;
299 70 root.has_parent = false;
300
1/2
✓ Branch 1 taken 70 times.
✗ Branch 2 not taken.
70 root.stage = 0;
301 frames.push_back(std::move(root));
302
303
1/2
✓ Branch 0 taken 364 times.
✗ Branch 1 not taken.
364 while (!frames.empty()) {
304 Frame &cur = frames.back();
305
306
2/2
✓ Branch 0 taken 196 times.
✓ Branch 1 taken 168 times.
364 if (IsLeaf(cur)) {
307
1/2
✓ Branch 1 taken 196 times.
✗ Branch 2 not taken.
196 cur.res = MultiplyClassicFlat(cur.a, cur.b, cur.n);
308
309
1/2
✓ Branch 1 taken 196 times.
✗ Branch 2 not taken.
196 const Frame finished = cur;
310 frames.pop_back();
311
312
2/2
✓ Branch 0 taken 49 times.
✓ Branch 1 taken 147 times.
196 if (frames.empty()) {
313
1/2
✓ Branch 1 taken 49 times.
✗ Branch 2 not taken.
49 return finished.res;
314 }
315
1/2
✓ Branch 1 taken 147 times.
✗ Branch 2 not taken.
147 AssignToParent(&frames, finished);
316 continue;
317 196 }
318
319
1/2
✓ Branch 1 taken 168 times.
✗ Branch 2 not taken.
168 SplitIfNeeded(&cur);
320
321
2/2
✓ Branch 0 taken 147 times.
✓ Branch 1 taken 21 times.
168 if ((cur.stage >= 1) && (cur.stage <= 7)) {
322 const int slot = cur.stage;
323 147 const std::size_t parent_idx = frames.size() - 1;
324
325
1/2
✓ Branch 1 taken 147 times.
✗ Branch 2 not taken.
147 Frame child = MakeChild(cur, parent_idx, slot);
326
1/2
✓ Branch 1 taken 147 times.
✗ Branch 2 not taken.
147 cur.stage += 1;
327 frames.push_back(std::move(child));
328 continue;
329 147 }
330
331
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 21 times.
21 if (cur.stage != 8) {
332 cur.stage = 8;
333 continue;
334 }
335
336
1/2
✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
21 CombineOnFrame(&cur);
337
338
1/2
✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
21 const Frame finished = cur;
339 frames.pop_back();
340
341
1/2
✓ Branch 0 taken 21 times.
✗ Branch 1 not taken.
21 if (frames.empty()) {
342
1/2
✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
21 return finished.res;
343 }
344 AssignToParent(&frames, finished);
345 21 }
346
347 return Flat{};
348 70 }
349 70 void MulStrassenOrClassic(const Flat &left, const Flat &right, int k, Flat *out) {
350 70 const auto kk = static_cast<std::size_t>(k);
351 70 *out = ShtrassenIterativeFlat(left, right, kk);
352 70 }
353
354 20 void PadToFlatOnRoot(const Matrix &src, int n0, int m, Flat *flat) {
355 20 flat->assign(static_cast<std::size_t>(m) * static_cast<std::size_t>(m), 0.0);
356 const auto mm = static_cast<std::size_t>(m);
357
2/2
✓ Branch 0 taken 78 times.
✓ Branch 1 taken 20 times.
98 for (int row = 0; row < n0; ++row) {
358
2/2
✓ Branch 0 taken 374 times.
✓ Branch 1 taken 78 times.
452 for (int col = 0; col < n0; ++col) {
359 374 const auto rr = static_cast<std::size_t>(row);
360 374 const auto cc = static_cast<std::size_t>(col);
361 374 (*flat)[Idx(rr, cc, mm)] = src[rr][cc];
362 }
363 }
364 20 }
365
366 160 void ExtractBlock(const Flat &src, int m, int block_row, int block_col, int k, Flat *block) {
367 160 block->assign(static_cast<std::size_t>(k) * static_cast<std::size_t>(k), 0.0);
368
369 160 const auto mm = static_cast<std::size_t>(m);
370 const auto kk = static_cast<std::size_t>(k);
371 160 const auto ro = static_cast<std::size_t>(block_row) * static_cast<std::size_t>(k);
372 160 const auto co = static_cast<std::size_t>(block_col) * static_cast<std::size_t>(k);
373
374
2/2
✓ Branch 0 taken 368 times.
✓ Branch 1 taken 160 times.
528 for (int row = 0; row < k; ++row) {
375
2/2
✓ Branch 0 taken 1072 times.
✓ Branch 1 taken 368 times.
1440 for (int col = 0; col < k; ++col) {
376 1072 const auto rr = ro + static_cast<std::size_t>(row);
377 1072 const auto cc = co + static_cast<std::size_t>(col);
378 1072 (*block)[Idx(static_cast<std::size_t>(row), static_cast<std::size_t>(col), kk)] = src[Idx(rr, cc, mm)];
379 }
380 }
381 160 }
382
383 60 void AddVec(const Flat &a, const Flat &b, Flat *out) {
384 const auto n = a.size();
385 60 out->assign(n, 0.0);
386
2/2
✓ Branch 0 taken 402 times.
✓ Branch 1 taken 60 times.
462 for (std::size_t ii = 0; ii < n; ++ii) {
387 402 (*out)[ii] = a[ii] + b[ii];
388 }
389 60 }
390
391 40 void SubVec(const Flat &a, const Flat &b, Flat *out) {
392 const auto n = a.size();
393 40 out->assign(n, 0.0);
394
2/2
✓ Branch 0 taken 268 times.
✓ Branch 1 taken 40 times.
308 for (std::size_t ii = 0; ii < n; ++ii) {
395 268 (*out)[ii] = a[ii] - b[ii];
396 }
397 40 }
398
399 int OwnerForTask(int task_id, int comm_size) {
400
1/2
✓ Branch 0 taken 140 times.
✗ Branch 1 not taken.
140 if (comm_size <= 0) {
401 return 0;
402 }
403 140 return (task_id - 1) % comm_size;
404 }
405
406 70 void ComputeOneStrassenTask(int task_id, const Flat &a11, const Flat &a12, const Flat &a21, const Flat &a22,
407 const Flat &b11, const Flat &b12, const Flat &b21, const Flat &b22, int k, Flat *out_m) {
408 70 Flat tmp1;
409
7/7
✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
✓ Branch 2 taken 10 times.
✓ Branch 3 taken 10 times.
✓ Branch 4 taken 10 times.
✓ Branch 5 taken 10 times.
✓ Branch 6 taken 10 times.
70 Flat tmp2;
410
411 if (task_id == 1) {
412
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 AddVec(a11, a22, &tmp1);
413
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 AddVec(b11, b22, &tmp2);
414
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 MulStrassenOrClassic(tmp1, tmp2, k, out_m);
415 return;
416 }
417 if (task_id == 2) {
418
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 AddVec(a21, a22, &tmp1);
419
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 MulStrassenOrClassic(tmp1, b11, k, out_m);
420 return;
421 }
422 if (task_id == 3) {
423
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 SubVec(b12, b22, &tmp2);
424
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 MulStrassenOrClassic(a11, tmp2, k, out_m);
425 return;
426 }
427 if (task_id == 4) {
428
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 SubVec(b21, b11, &tmp2);
429
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 MulStrassenOrClassic(a22, tmp2, k, out_m);
430 return;
431 }
432 if (task_id == 5) {
433
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 AddVec(a11, a12, &tmp1);
434
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 MulStrassenOrClassic(tmp1, b22, k, out_m);
435 return;
436 }
437 if (task_id == 6) {
438
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 SubVec(a21, a11, &tmp1);
439
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 AddVec(b11, b12, &tmp2);
440
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 MulStrassenOrClassic(tmp1, tmp2, k, out_m);
441 return;
442 }
443
444
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 SubVec(a12, a22, &tmp1);
445
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 AddVec(b21, b22, &tmp2);
446
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 MulStrassenOrClassic(tmp1, tmp2, k, out_m);
447 }
448
449 10 void AssembleCOnRoot(const Flat &m1, const Flat &m2, const Flat &m3, const Flat &m4, const Flat &m5, const Flat &m6,
450 const Flat &m7, int m, int k, Flat *c_full) {
451 10 c_full->assign(static_cast<std::size_t>(m) * static_cast<std::size_t>(m), 0.0);
452 const auto mm = static_cast<std::size_t>(m);
453 10 const auto kk = static_cast<std::size_t>(k);
454
455
2/2
✓ Branch 0 taken 23 times.
✓ Branch 1 taken 10 times.
33 for (int row = 0; row < k; ++row) {
456 23 const auto rr = static_cast<std::size_t>(row);
457
2/2
✓ Branch 0 taken 67 times.
✓ Branch 1 taken 23 times.
90 for (int col = 0; col < k; ++col) {
458 67 const auto cc = static_cast<std::size_t>(col);
459 const auto id = Idx(rr, cc, kk);
460 67 const double c11 = m1[id] + m4[id] - m5[id] + m7[id];
461 67 const double c12 = m3[id] + m5[id];
462 67 const double c21 = m2[id] + m4[id];
463 67 const double c22 = m1[id] - m2[id] + m3[id] + m6[id];
464 67 (*c_full)[Idx(rr, cc, mm)] = c11;
465 67 (*c_full)[Idx(rr, cc + kk, mm)] = c12;
466 67 (*c_full)[Idx(rr + kk, cc, mm)] = c21;
467 67 (*c_full)[Idx(rr + kk, cc + kk, mm)] = c22;
468 }
469 }
470 10 }
471
472 4 bool FastPath1x1(int rank, const Matrix &a, const Matrix &b, Matrix *out) {
473 4 double value = 0.0;
474
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 2 times.
4 if (rank == 0) {
475 2 value = a[0][0] * b[0][0];
476 }
477 4 MPI_Bcast(&value, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
478
1/2
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 *out = Matrix(1, std::vector<double>(1, value));
479 4 return true;
480 }
481
482 20 void BroadcastFullPads(int rank, int m, const Matrix &a_in, const Matrix &b_in, int n0, Flat *a_full, Flat *b_full) {
483 20 const auto total = static_cast<std::size_t>(m) * static_cast<std::size_t>(m);
484
2/2
✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
20 if (rank == 0) {
485 10 PadToFlatOnRoot(a_in, n0, m, a_full);
486 10 PadToFlatOnRoot(b_in, n0, m, b_full);
487 } else {
488 10 a_full->assign(total, 0.0);
489 10 b_full->assign(total, 0.0);
490 }
491
492 20 MPI_Bcast(a_full->data(), static_cast<int>(a_full->size()), MPI_DOUBLE, 0, MPI_COMM_WORLD);
493 20 MPI_Bcast(b_full->data(), static_cast<int>(b_full->size()), MPI_DOUBLE, 0, MPI_COMM_WORLD);
494 20 }
495
496 20 void SplitBlocks(const Flat &a_full, const Flat &b_full, int m, int k, Flat *a11, Flat *a12, Flat *a21, Flat *a22,
497 Flat *b11, Flat *b12, Flat *b21, Flat *b22) {
498 20 ExtractBlock(a_full, m, 0, 0, k, a11);
499 20 ExtractBlock(a_full, m, 0, 1, k, a12);
500 20 ExtractBlock(a_full, m, 1, 0, k, a21);
501 20 ExtractBlock(a_full, m, 1, 1, k, a22);
502
503 20 ExtractBlock(b_full, m, 0, 0, k, b11);
504 20 ExtractBlock(b_full, m, 0, 1, k, b12);
505 20 ExtractBlock(b_full, m, 1, 0, k, b21);
506 20 ExtractBlock(b_full, m, 1, 1, k, b22);
507 20 }
508
509 20 void ComputeLocalMi(int rank, int comm_size, int k, const Flat &a11, const Flat &a12, const Flat &a21, const Flat &a22,
510 const Flat &b11, const Flat &b12, const Flat &b21, const Flat &b22, Flat *m1_loc, Flat *m2_loc,
511 Flat *m3_loc, Flat *m4_loc, Flat *m5_loc, Flat *m6_loc, Flat *m7_loc) {
512
2/2
✓ Branch 0 taken 140 times.
✓ Branch 1 taken 20 times.
160 for (int task_id = 1; task_id <= 7; ++task_id) {
513 const int owner = OwnerForTask(task_id, comm_size);
514
2/2
✓ Branch 0 taken 70 times.
✓ Branch 1 taken 70 times.
140 if (rank != owner) {
515 70 continue;
516 }
517
518 70 Flat tmp_out;
519
1/2
✓ Branch 1 taken 70 times.
✗ Branch 2 not taken.
70 ComputeOneStrassenTask(task_id, a11, a12, a21, a22, b11, b12, b21, b22, k, &tmp_out);
520
521 if (task_id == 1) {
522 m1_loc->swap(tmp_out);
523 } else if (task_id == 2) {
524 m2_loc->swap(tmp_out);
525 } else if (task_id == 3) {
526 m3_loc->swap(tmp_out);
527 } else if (task_id == 4) {
528 m4_loc->swap(tmp_out);
529 } else if (task_id == 5) {
530 m5_loc->swap(tmp_out);
531 } else if (task_id == 6) {
532 m6_loc->swap(tmp_out);
533 } else {
534 m7_loc->swap(tmp_out);
535 }
536 }
537 20 }
538
539 20 void ReduceMiToRoot(std::size_t kk, const Flat &m1_loc, const Flat &m2_loc, const Flat &m3_loc, const Flat &m4_loc,
540 const Flat &m5_loc, const Flat &m6_loc, const Flat &m7_loc, Flat *m1, Flat *m2, Flat *m3, Flat *m4,
541 Flat *m5, Flat *m6, Flat *m7) {
542 20 MPI_Reduce(m1_loc.data(), m1->data(), static_cast<int>(kk), MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
543 20 MPI_Reduce(m2_loc.data(), m2->data(), static_cast<int>(kk), MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
544 20 MPI_Reduce(m3_loc.data(), m3->data(), static_cast<int>(kk), MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
545 20 MPI_Reduce(m4_loc.data(), m4->data(), static_cast<int>(kk), MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
546 20 MPI_Reduce(m5_loc.data(), m5->data(), static_cast<int>(kk), MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
547 20 MPI_Reduce(m6_loc.data(), m6->data(), static_cast<int>(kk), MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
548 20 MPI_Reduce(m7_loc.data(), m7->data(), static_cast<int>(kk), MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
549 20 }
550
551 20 void RootCropToFlat(int rank, int m, int n0, const Flat &c_full, Flat *c_crop_flat) {
552 20 const auto total = static_cast<std::size_t>(n0) * static_cast<std::size_t>(n0);
553
2/2
✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
20 if (rank != 0) {
554 10 c_crop_flat->assign(total, 0.0);
555 10 return;
556 }
557
558 10 c_crop_flat->assign(total, 0.0);
559
560 10 const auto mm = static_cast<std::size_t>(m);
561 const auto nn = static_cast<std::size_t>(n0);
562
563
2/2
✓ Branch 0 taken 39 times.
✓ Branch 1 taken 10 times.
49 for (int row = 0; row < n0; ++row) {
564
2/2
✓ Branch 0 taken 187 times.
✓ Branch 1 taken 39 times.
226 for (int col = 0; col < n0; ++col) {
565 187 const auto rr = static_cast<std::size_t>(row);
566 187 const auto cc = static_cast<std::size_t>(col);
567 187 (*c_crop_flat)[Idx(rr, cc, nn)] = c_full[Idx(rr, cc, mm)];
568 }
569 }
570 }
571
572 20 Matrix FlatToMatrix(int n0, const Flat &flat) {
573
1/2
✓ Branch 2 taken 20 times.
✗ Branch 3 not taken.
20 Matrix out(static_cast<std::size_t>(n0), std::vector<double>(static_cast<std::size_t>(n0), 0.0));
574 const auto nn = static_cast<std::size_t>(n0);
575
2/2
✓ Branch 0 taken 78 times.
✓ Branch 1 taken 20 times.
98 for (int row = 0; row < n0; ++row) {
576
2/2
✓ Branch 0 taken 374 times.
✓ Branch 1 taken 78 times.
452 for (int col = 0; col < n0; ++col) {
577 374 const auto rr = static_cast<std::size_t>(row);
578 374 const auto cc = static_cast<std::size_t>(col);
579 374 out[rr][cc] = flat[Idx(rr, cc, nn)];
580 }
581 }
582 20 return out;
583 }
584
585 } // namespace
586
587 24 bool SannikovIShtrassenAlgorithmMPI::RunImpl() {
588 const auto &input = GetInput();
589 const auto &mat_a = std::get<0>(input);
590 const auto &mat_b = std::get<1>(input);
591
592 24 int rank = 0;
593 24 int comm_size = 1;
594 24 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
595 24 MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
596
597 24 std::uint64_t n0_64 = 0U;
598
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
24 if (rank == 0) {
599 12 n0_64 = static_cast<std::uint64_t>(mat_a.size());
600 }
601 24 MPI_Bcast(&n0_64, 1, MPI_UINT64_T, 0, MPI_COMM_WORLD);
602
603
1/2
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
24 if (!SizeOkU64(n0_64)) {
604 return false;
605 }
606
607 24 const int n0 = static_cast<int>(n0_64);
608
609 24 const auto m_sz = NextPow2(static_cast<std::size_t>(n0));
610
1/2
✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
24 if (m_sz > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
611 return false;
612 }
613 24 const int m = static_cast<int>(m_sz);
614
615
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 20 times.
24 if (m == 1) {
616 4 Matrix out;
617
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 (void)FastPath1x1(rank, mat_a, mat_b, &out);
618
1/2
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
4 GetOutput() = out;
619 return true;
620 4 }
621
622 20 Flat a_full;
623 20 Flat b_full;
624
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 BroadcastFullPads(rank, m, mat_a, mat_b, n0, &a_full, &b_full);
625
626 20 const int k = m / 2;
627
628 20 Flat a11;
629 20 Flat a12;
630 20 Flat a21;
631 20 Flat a22;
632 20 Flat b11;
633 20 Flat b12;
634 20 Flat b21;
635 20 Flat b22;
636
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 SplitBlocks(a_full, b_full, m, k, &a11, &a12, &a21, &a22, &b11, &b12, &b21, &b22);
637
638 20 const auto kk = static_cast<std::size_t>(k) * static_cast<std::size_t>(k);
639
640
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m1_loc(kk, 0.0);
641
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m2_loc(kk, 0.0);
642
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m3_loc(kk, 0.0);
643
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m4_loc(kk, 0.0);
644
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m5_loc(kk, 0.0);
645
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m6_loc(kk, 0.0);
646
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m7_loc(kk, 0.0);
647
648
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 ComputeLocalMi(rank, comm_size, k, a11, a12, a21, a22, b11, b12, b21, b22, &m1_loc, &m2_loc, &m3_loc, &m4_loc,
649 &m5_loc, &m6_loc, &m7_loc);
650
651
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m1(kk, 0.0);
652
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m2(kk, 0.0);
653
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m3(kk, 0.0);
654
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m4(kk, 0.0);
655
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m5(kk, 0.0);
656
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m6(kk, 0.0);
657
1/4
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
20 Flat m7(kk, 0.0);
658
659
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 ReduceMiToRoot(kk, m1_loc, m2_loc, m3_loc, m4_loc, m5_loc, m6_loc, m7_loc, &m1, &m2, &m3, &m4, &m5, &m6, &m7);
660
661 20 Flat c_full;
662
2/2
✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
20 if (rank == 0) {
663
1/2
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
10 AssembleCOnRoot(m1, m2, m3, m4, m5, m6, m7, m, k, &c_full);
664 }
665
666 20 Flat c_crop_flat;
667
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 RootCropToFlat(rank, m, n0, c_full, &c_crop_flat);
668
1/2
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
20 MPI_Bcast(c_crop_flat.data(), static_cast<int>(c_crop_flat.size()), MPI_DOUBLE, 0, MPI_COMM_WORLD);
669
670
2/6
✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
✓ Branch 5 taken 20 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
20 GetOutput() = FlatToMatrix(n0, c_crop_flat);
671
1/2
✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
20 return !GetOutput().empty();
672 }
673
674 24 bool SannikovIShtrassenAlgorithmMPI::PostProcessingImpl() {
675 24 return !GetOutput().empty();
676 }
677
678 } // namespace sannikov_i_shtrassen_algorithm
679