| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "romanov_a_crs_product/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <array> | ||
| 7 | #include <cstdint> | ||
| 8 | #include <utility> | ||
| 9 | #include <vector> | ||
| 10 | |||
| 11 | #include "romanov_a_crs_product/common/include/common.hpp" | ||
| 12 | |||
| 13 | namespace romanov_a_crs_product { | ||
| 14 | |||
| 15 |
1/2✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
|
10 | RomanovACRSProductMPI::RomanovACRSProductMPI(const InType &in) { |
| 16 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 17 | GetInput() = in; | ||
| 18 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | GetOutput() = CRS(static_cast<uint64_t>(0)); |
| 19 | 10 | } | |
| 20 | |||
| 21 | 10 | bool RomanovACRSProductMPI::ValidationImpl() { | |
| 22 | 10 | return (std::get<0>(GetInput()).GetCols() == std::get<1>(GetInput()).GetRows()); | |
| 23 | } | ||
| 24 | |||
| 25 | 10 | bool RomanovACRSProductMPI::PreProcessingImpl() { | |
| 26 | 10 | return true; | |
| 27 | } | ||
| 28 | |||
| 29 | namespace { | ||
| 30 | |||
| 31 | 20 | void BroadcastCRS(CRS &m, int root, MPI_Comm comm) { | |
| 32 | 20 | int rank = 0; | |
| 33 | 20 | MPI_Comm_rank(comm, &rank); | |
| 34 | |||
| 35 | 20 | std::array<uint64_t, 2> dims{m.n, m.m}; | |
| 36 | 20 | MPI_Bcast(dims.data(), dims.size(), MPI_UINT64_T, root, comm); | |
| 37 | |||
| 38 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank != root) { |
| 39 | 10 | m.n = dims[0]; | |
| 40 | 10 | m.m = dims[1]; | |
| 41 | 10 | m.row_index.resize(m.n + 1); | |
| 42 | } | ||
| 43 | |||
| 44 | 20 | uint64_t nnz = m.Nnz(); | |
| 45 | 20 | MPI_Bcast(&nnz, 1, MPI_UINT64_T, root, comm); | |
| 46 | |||
| 47 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 10 times.
|
20 | if (rank != root) { |
| 48 | 10 | m.value.resize(nnz); | |
| 49 | 10 | m.column.resize(nnz); | |
| 50 | } | ||
| 51 | |||
| 52 |
1/2✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
|
20 | if (nnz > 0) { |
| 53 | 20 | MPI_Bcast(m.value.data(), static_cast<int>(nnz), MPI_DOUBLE, root, comm); | |
| 54 |
1/2✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
|
20 | if (nnz > 0) { |
| 55 | 20 | MPI_Bcast(reinterpret_cast<uint64_t *>(m.column.data()), static_cast<int>(nnz), MPI_UINT64_T, root, comm); | |
| 56 | } | ||
| 57 | } | ||
| 58 | |||
| 59 |
1/2✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
|
20 | if (m.n + 1 > 0) { |
| 60 | 20 | MPI_Bcast(reinterpret_cast<uint64_t *>(m.row_index.data()), static_cast<int>(m.n + 1), MPI_UINT64_T, root, comm); | |
| 61 | } | ||
| 62 | 20 | } | |
| 63 | |||
| 64 | 10 | void SendCRS(const CRS &m, int dest, int tag, MPI_Comm comm) { | |
| 65 | 10 | std::array<uint64_t, 2> dims{m.n, m.m}; | |
| 66 | 10 | MPI_Send(dims.data(), dims.size(), MPI_UINT64_T, dest, tag, comm); | |
| 67 | |||
| 68 | 10 | uint64_t nnz = m.Nnz(); | |
| 69 | 10 | MPI_Send(&nnz, 1, MPI_UINT64_T, dest, tag + 1, comm); | |
| 70 | |||
| 71 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 2 times.
|
10 | if (nnz > 0) { |
| 72 | 8 | MPI_Send(m.value.data(), static_cast<int>(nnz), MPI_DOUBLE, dest, tag + 2, comm); | |
| 73 |
1/2✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
|
8 | if (nnz > 0) { |
| 74 | 8 | MPI_Send(reinterpret_cast<const uint64_t *>(m.column.data()), static_cast<int>(nnz), MPI_UINT64_T, dest, tag + 3, | |
| 75 | comm); | ||
| 76 | } | ||
| 77 | } | ||
| 78 | |||
| 79 |
1/2✓ Branch 0 taken 10 times.
✗ Branch 1 not taken.
|
10 | if (m.n + 1 > 0) { |
| 80 | 10 | MPI_Send(reinterpret_cast<const uint64_t *>(m.row_index.data()), static_cast<int>(m.n + 1), MPI_UINT64_T, dest, | |
| 81 | tag + 4, comm); | ||
| 82 | } | ||
| 83 | 10 | } | |
| 84 | |||
| 85 | 10 | void RecvCRS(CRS &m, int src, int tag, MPI_Comm comm) { | |
| 86 | 10 | std::array<uint64_t, 2> dims{}; | |
| 87 | 10 | MPI_Recv(dims.data(), dims.size(), MPI_UINT64_T, src, tag, comm, MPI_STATUS_IGNORE); | |
| 88 | 10 | m.n = dims[0]; | |
| 89 | 10 | m.m = dims[1]; | |
| 90 | |||
| 91 | 10 | uint64_t nnz = 0; | |
| 92 | 10 | MPI_Recv(&nnz, 1, MPI_UINT64_T, src, tag + 1, comm, MPI_STATUS_IGNORE); | |
| 93 | |||
| 94 | 10 | m.value.resize(nnz); | |
| 95 | 10 | m.column.resize(nnz); | |
| 96 | 10 | m.row_index.resize(m.n + 1); | |
| 97 | |||
| 98 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 2 times.
|
10 | if (nnz > 0) { |
| 99 | 8 | MPI_Recv(m.value.data(), static_cast<int>(nnz), MPI_DOUBLE, src, tag + 2, comm, MPI_STATUS_IGNORE); | |
| 100 |
1/2✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
|
8 | if (nnz > 0) { |
| 101 | 8 | MPI_Recv(reinterpret_cast<uint64_t *>(m.column.data()), static_cast<int>(nnz), MPI_UINT64_T, src, tag + 3, comm, | |
| 102 | MPI_STATUS_IGNORE); | ||
| 103 | } | ||
| 104 | } | ||
| 105 | |||
| 106 |
1/2✓ Branch 0 taken 10 times.
✗ Branch 1 not taken.
|
10 | if (m.n + 1 > 0) { |
| 107 | 10 | MPI_Recv(reinterpret_cast<uint64_t *>(m.row_index.data()), static_cast<int>(m.n + 1), MPI_UINT64_T, src, tag + 4, | |
| 108 | comm, MPI_STATUS_IGNORE); | ||
| 109 | } | ||
| 110 | 10 | } | |
| 111 | |||
| 112 | } // namespace | ||
| 113 | |||
| 114 | 10 | bool RomanovACRSProductMPI::RunImpl() { | |
| 115 | 10 | int rank = 0; | |
| 116 | 10 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 117 | |||
| 118 | 10 | int num_processes = 1; | |
| 119 | 10 | MPI_Comm_size(MPI_COMM_WORLD, &num_processes); | |
| 120 | |||
| 121 | 10 | CRS a; | |
| 122 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | CRS b; |
| 123 | |||
| 124 |
2/2✓ Branch 0 taken 5 times.
✓ Branch 1 taken 5 times.
|
10 | if (rank == 0) { |
| 125 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | a = std::get<0>(GetInput()); |
| 126 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | b = std::get<1>(GetInput()); |
| 127 | } | ||
| 128 | |||
| 129 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | BroadcastCRS(b, 0, MPI_COMM_WORLD); |
| 130 | |||
| 131 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | CRS a_local; |
| 132 | |||
| 133 |
2/2✓ Branch 0 taken 5 times.
✓ Branch 1 taken 5 times.
|
10 | if (rank == 0) { |
| 134 | 5 | uint64_t n = a.n; | |
| 135 | 5 | uint64_t rows_per_proc = (n + num_processes - 1) / num_processes; | |
| 136 | |||
| 137 |
2/2✓ Branch 0 taken 5 times.
✓ Branch 1 taken 5 times.
|
10 | for (int pn = 1; pn < num_processes; pn++) { |
| 138 | 5 | uint64_t start = pn * rows_per_proc; | |
| 139 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 4 times.
|
5 | if (start >= n) { |
| 140 |
1/2✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
|
1 | CRS empty(0, a.m); |
| 141 |
1/2✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
|
1 | SendCRS(empty, pn, 100 + pn, MPI_COMM_WORLD); |
| 142 | continue; | ||
| 143 | 1 | } | |
| 144 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | uint64_t end = std::min(n, start + rows_per_proc); |
| 145 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | CRS part = a.ExtractRows(start, end); |
| 146 |
1/2✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
|
4 | SendCRS(part, pn, 100 + pn, MPI_COMM_WORLD); |
| 147 | 4 | } | |
| 148 | |||
| 149 | uint64_t end0 = std::min(a.n, rows_per_proc); | ||
| 150 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | a_local = a.ExtractRows(0, end0); |
| 151 | |||
| 152 | } else { | ||
| 153 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | RecvCRS(a_local, 0, 100 + rank, MPI_COMM_WORLD); |
| 154 | } | ||
| 155 | |||
| 156 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | CRS c_local; |
| 157 |
3/4✓ Branch 0 taken 9 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
|
10 | if (a_local.n > 0 && b.n > 0) { |
| 158 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | c_local = a_local * b; |
| 159 | } else { | ||
| 160 |
1/2✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
|
1 | c_local = CRS(0, b.m); |
| 161 | } | ||
| 162 | |||
| 163 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | CRS c_total; |
| 164 | |||
| 165 |
2/2✓ Branch 0 taken 5 times.
✓ Branch 1 taken 5 times.
|
10 | if (rank == 0) { |
| 166 | 5 | std::vector<CRS> parts; | |
| 167 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | parts.reserve(static_cast<uint64_t>(num_processes)); |
| 168 | |||
| 169 | parts.push_back(std::move(c_local)); | ||
| 170 | |||
| 171 |
2/2✓ Branch 0 taken 5 times.
✓ Branch 1 taken 5 times.
|
10 | for (int pn = 1; pn < num_processes; pn++) { |
| 172 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | CRS temp; |
| 173 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | RecvCRS(temp, pn, 200 + pn, MPI_COMM_WORLD); |
| 174 | parts.push_back(std::move(temp)); | ||
| 175 | 5 | } | |
| 176 | |||
| 177 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | c_total = CRS::ConcatRows(parts); |
| 178 | |||
| 179 | 5 | } else { | |
| 180 |
1/2✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
|
5 | SendCRS(c_local, 0, 200 + rank, MPI_COMM_WORLD); |
| 181 | } | ||
| 182 | |||
| 183 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | BroadcastCRS(c_total, 0, MPI_COMM_WORLD); |
| 184 | |||
| 185 |
1/2✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
|
10 | GetOutput() = c_total; |
| 186 | |||
| 187 | 10 | return true; | |
| 188 | 10 | } | |
| 189 | |||
| 190 | 10 | bool RomanovACRSProductMPI::PostProcessingImpl() { | |
| 191 | 10 | return true; | |
| 192 | } | ||
| 193 | |||
| 194 | } // namespace romanov_a_crs_product | ||
| 195 |