| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "sakharov_a_cannon_algorithm/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <cstddef> | ||
| 6 | #include <vector> | ||
| 7 | |||
| 8 | #include "sakharov_a_cannon_algorithm/common/include/common.hpp" | ||
| 9 | |||
| 10 | namespace sakharov_a_cannon_algorithm { | ||
| 11 | |||
| 12 | namespace { | ||
| 13 | |||
| 14 | 16 | void LocalMultiply(const std::vector<double> &a_block, const std::vector<double> &b_block, std::vector<double> &c_block, | |
| 15 | int local_rows, int k_dim, int local_cols) { | ||
| 16 |
2/2✓ Branch 0 taken 20 times.
✓ Branch 1 taken 16 times.
|
36 | for (int ii = 0; ii < local_rows; ++ii) { |
| 17 |
2/2✓ Branch 0 taken 50 times.
✓ Branch 1 taken 20 times.
|
70 | for (int kk = 0; kk < k_dim; ++kk) { |
| 18 | 50 | double a_val = a_block[Idx(k_dim, ii, kk)]; | |
| 19 |
2/2✓ Branch 0 taken 156 times.
✓ Branch 1 taken 50 times.
|
206 | for (int jj = 0; jj < local_cols; ++jj) { |
| 20 | 156 | c_block[Idx(local_cols, ii, jj)] += a_val * b_block[Idx(local_cols, kk, jj)]; | |
| 21 | } | ||
| 22 | } | ||
| 23 | } | ||
| 24 | 16 | } | |
| 25 | |||
| 26 | } // namespace | ||
| 27 | |||
| 28 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | SakharovACannonAlgorithmMPI::SakharovACannonAlgorithmMPI(const InType &in) { |
| 29 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 30 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | GetInput() = in; |
| 31 | 16 | } | |
| 32 | |||
| 33 | 16 | bool SakharovACannonAlgorithmMPI::ValidationImpl() { | |
| 34 | 16 | return IsValidInput(GetInput()); | |
| 35 | } | ||
| 36 | |||
| 37 | 16 | bool SakharovACannonAlgorithmMPI::PreProcessingImpl() { | |
| 38 | const auto &input = GetInput(); | ||
| 39 | 16 | auto out_size = static_cast<std::size_t>(input.rows_a) * static_cast<std::size_t>(input.cols_b); | |
| 40 | 16 | GetOutput().assign(out_size, 0.0); | |
| 41 | 16 | return true; | |
| 42 | } | ||
| 43 | |||
| 44 | 16 | bool SakharovACannonAlgorithmMPI::RunImpl() { | |
| 45 | 16 | int rank = 0; | |
| 46 | 16 | int world_size = 0; | |
| 47 | 16 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 48 | 16 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); | |
| 49 | |||
| 50 | const auto &input = GetInput(); | ||
| 51 | 16 | const int m = input.rows_a; | |
| 52 | 16 | const int k = input.cols_a; | |
| 53 | 16 | const int n = input.cols_b; | |
| 54 | |||
| 55 | 16 | int base_rows = m / world_size; | |
| 56 | 16 | int extra_rows = m % world_size; | |
| 57 | |||
| 58 | 16 | std::vector<int> row_counts(world_size); | |
| 59 |
1/4✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
16 | std::vector<int> row_displs(world_size); |
| 60 | int offset = 0; | ||
| 61 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 16 times.
|
48 | for (int idx = 0; idx < world_size; ++idx) { |
| 62 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 8 times.
|
56 | row_counts[idx] = base_rows + (idx < extra_rows ? 1 : 0); |
| 63 | 32 | row_displs[idx] = offset; | |
| 64 | 32 | offset += row_counts[idx]; | |
| 65 | } | ||
| 66 | |||
| 67 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | int local_rows = row_counts[rank]; |
| 68 | |||
| 69 |
1/4✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
16 | std::vector<double> local_a(static_cast<std::size_t>(local_rows) * static_cast<std::size_t>(k)); |
| 70 |
1/4✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
16 | std::vector<double> local_c(static_cast<std::size_t>(local_rows) * static_cast<std::size_t>(n), 0.0); |
| 71 | |||
| 72 |
1/4✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
16 | std::vector<int> send_counts_a(world_size); |
| 73 |
1/4✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
16 | std::vector<int> displs_a(world_size); |
| 74 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 16 times.
|
48 | for (int idx = 0; idx < world_size; ++idx) { |
| 75 | 32 | send_counts_a[idx] = row_counts[idx] * k; | |
| 76 | 32 | displs_a[idx] = row_displs[idx] * k; | |
| 77 | } | ||
| 78 | |||
| 79 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | MPI_Scatterv(input.a.data(), send_counts_a.data(), displs_a.data(), MPI_DOUBLE, local_a.data(), local_rows * k, |
| 80 | MPI_DOUBLE, 0, MPI_COMM_WORLD); | ||
| 81 | |||
| 82 |
1/4✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
16 | std::vector<double> b_data(static_cast<std::size_t>(k) * static_cast<std::size_t>(n)); |
| 83 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
|
16 | if (rank == 0) { |
| 84 |
1/2✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
|
8 | b_data = input.b; |
| 85 | } | ||
| 86 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | MPI_Bcast(b_data.data(), k * n, MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 87 | |||
| 88 | 16 | LocalMultiply(local_a, b_data, local_c, local_rows, k, n); | |
| 89 | |||
| 90 |
1/4✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
16 | std::vector<int> recv_counts_c(world_size); |
| 91 |
1/4✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
16 | std::vector<int> displs_c(world_size); |
| 92 |
2/2✓ Branch 0 taken 32 times.
✓ Branch 1 taken 16 times.
|
48 | for (int idx = 0; idx < world_size; ++idx) { |
| 93 | 32 | recv_counts_c[idx] = row_counts[idx] * n; | |
| 94 | 32 | displs_c[idx] = row_displs[idx] * n; | |
| 95 | } | ||
| 96 | |||
| 97 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | MPI_Gatherv(local_c.data(), local_rows * n, MPI_DOUBLE, GetOutput().data(), recv_counts_c.data(), displs_c.data(), |
| 98 | MPI_DOUBLE, 0, MPI_COMM_WORLD); | ||
| 99 | |||
| 100 |
1/2✓ Branch 1 taken 16 times.
✗ Branch 2 not taken.
|
16 | MPI_Bcast(GetOutput().data(), m * n, MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 101 | |||
| 102 | 16 | return true; | |
| 103 | } | ||
| 104 | |||
| 105 | 16 | bool SakharovACannonAlgorithmMPI::PostProcessingImpl() { | |
| 106 | 16 | return true; | |
| 107 | } | ||
| 108 | |||
| 109 | } // namespace sakharov_a_cannon_algorithm | ||
| 110 |