| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "makoveeva_s_cannon_algorithm/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <array> | ||
| 6 | #include <cmath> | ||
| 7 | #include <cstddef> | ||
| 8 | #include <tuple> | ||
| 9 | #include <utility> | ||
| 10 | #include <vector> | ||
| 11 | |||
| 12 | #include "makoveeva_s_cannon_algorithm/common/include/common.hpp" | ||
| 13 | |||
| 14 | namespace makoveeva_s_cannon_algorithm { | ||
| 15 | |||
| 16 | namespace { | ||
| 17 | |||
| 18 | /* Returns true only when n is positive and both a and b contain exactly n*n elements. */ bool CheckInputLocal(const std::vector<double> &a, const std::vector<double> &b, int n) { | ||
| 19 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
|
9 | if (n <= 0) { |
| 20 | return false; | ||
| 21 | } | ||
| 22 | 9 | /* Convert to std::size_t first so the square below is computed in std::size_t rather than int. */ const auto n_sz = static_cast<std::size_t>(n); | |
| 23 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
|
9 | const auto expected = n_sz * n_sz; |
| 24 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 9 times.
|
9 | return (a.size() == expected) && (b.size() == expected); |
| 25 | } | ||
| 26 | |||
| 27 | 36 | /* Picks the process-grid side q: starting from floor(sqrt(size)) and counting down, returns the first q with q*q <= size and n divisible by q. Returns 0 when size or n is not positive; the trailing return 0 is the fallback if the loop finds nothing. */ int ChooseGridQ(int size, int n) { | |
| 28 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 36 times.
|
36 | if (size <= 0 || n <= 0) { |
| 29 | return 0; | ||
| 30 | } | ||
| 31 | 36 | /* Initial candidate comes from a floating-point square root truncated to int. */ int q = static_cast<int>(std::floor(std::sqrt(static_cast<double>(size)))); | |
| 32 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 36 times.
|
36 | while (q > 0) { |
| 33 |
2/4✓ Branch 0 taken 36 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 36 times.
|
36 | /* The q*q <= size test re-checks the candidate in integer arithmetic; n % q == 0 requires that n splits into q equal parts. */ if ((q * q) <= size && (n % q == 0)) { |
| 34 | return q; | ||
| 35 | } | ||
| 36 | ✗ | --q; | |
| 37 | } | ||
| 38 | return 0; | ||
| 39 | } | ||
| 40 | |||
| 41 | 54 | /* Returns the third tuple element of input as seen by world rank 0, delivered to every rank through MPI_Bcast on MPI_COMM_WORLD. Other ranks never read their own input here. */ int BroadcastN(const InType &input, int world_rank) { | |
| 42 | 54 | int n = 0; | |
| 43 |
2/2✓ Branch 0 taken 27 times.
✓ Branch 1 taken 27 times.
|
54 | if (world_rank == 0) { |
| 44 | 27 | n = std::get<2>(input); | |
| 45 | } | ||
| 46 | 54 | /* Root is rank 0; the call is made unconditionally, so all ranks of MPI_COMM_WORLD take part. */ MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 47 | 54 | return n; | |
| 48 | } | ||
| 49 | |||
| 50 | 9 | /* Builds and commits an MPI datatype made of bs runs of bs MPI_DOUBLE values with a stride of n doubles between run starts, then resized to lower bound 0 and an extent of one double. The returned type is committed and is not freed here. */ MPI_Datatype MakeBlockType(int n, int bs) { | |
| 51 | 9 | MPI_Datatype block = MPI_DATATYPE_NULL; | |
| 52 | 9 | MPI_Type_vector(bs, bs, n, MPI_DOUBLE, &block); | |
| 53 | |||
| 54 | 9 | MPI_Datatype resized = MPI_DATATYPE_NULL; | |
| 55 | 9 | /* Extent is set to sizeof(double); presumably so that displacements given to scatter/gather are counted in single doubles - confirm against the callers. */ MPI_Type_create_resized(block, 0, static_cast<MPI_Aint>(sizeof(double)), &resized); | |
| 56 | |||
| 57 | 9 | MPI_Type_commit(&resized); | |
| 58 | 9 | /* Only the resized type is returned, so the intermediate vector type handle is released. */ MPI_Type_free(&block); | |
| 59 | 9 | return resized; | |
| 60 | } | ||
| 61 | |||
| 62 | 9 | /* Adds the product a*b into *c. All three buffers are indexed as bs-by-bs matrices with element (r, s) at r*bs + s. The existing contents of *c are kept and accumulated into. c is dereferenced without a null or size check. */ void LocalMatMulAcc(const std::vector<double> &a, const std::vector<double> &b, int bs, std::vector<double> *c) { | |
| 63 | 9 | const auto bs_sz = static_cast<std::size_t>(bs); | |
| 64 |
2/2✓ Branch 0 taken 93 times.
✓ Branch 1 taken 9 times.
|
102 | for (int i = 0; i < bs; ++i) { |
| 65 | 93 | const auto i_sz = static_cast<std::size_t>(i); | |
| 66 |
2/2✓ Branch 0 taken 1823 times.
✓ Branch 1 taken 93 times.
|
1916 | for (int k = 0; k < bs; ++k) { |
| 67 | 1823 | const auto k_sz = static_cast<std::size_t>(k); | |
| 68 | 1823 | /* a(i, k) does not depend on j, so it is read once before the innermost loop. */ const double a_ik = a[(i_sz * bs_sz) + k_sz]; | |
| 69 |
2/2✓ Branch 0 taken 45819 times.
✓ Branch 1 taken 1823 times.
|
47642 | for (int j = 0; j < bs; ++j) { |
| 70 | 45819 | const auto j_sz = static_cast<std::size_t>(j); | |
| 71 | 45819 | (*c)[(i_sz * bs_sz) + j_sz] += a_ik * b[(k_sz * bs_sz) + j_sz]; | |
| 72 | } | ||
| 73 | } | ||
| 74 | } | ||
| 75 | 9 | } | |
| 76 | |||
| 77 | /* Splits MPI_COMM_WORLD so that ranks below active_p share one communicator (color 0, ordered by world_rank). Writes the membership flag to *is_active and returns the resulting communicator handle. Ranks at or above active_p pass MPI_UNDEFINED as color; per the MPI standard they should get MPI_COMM_NULL back. */ MPI_Comm MakeActiveComm(int world_rank, int active_p, bool *is_active) { | ||
| 78 | *is_active = (world_rank < active_p); | ||
| 79 | 18 | MPI_Comm active_comm = MPI_COMM_NULL; | |
| 80 | 18 | const int color = (*is_active) ? 0 : MPI_UNDEFINED; | |
| 81 | 18 | /* Called on every rank regardless of is_active, since the split is over MPI_COMM_WORLD. */ MPI_Comm_split(MPI_COMM_WORLD, color, world_rank, &active_comm); | |
| 82 | 18 | return active_comm; | |
| 83 | } | ||
| 84 | |||
| 85 | 9 | /* Creates a 2-D q-by-q Cartesian communicator on top of active_comm, periodic in both dimensions and with rank reordering disabled. The returned communicator is not freed here. */ MPI_Comm MakeCartComm(MPI_Comm active_comm, int q) { | |
| 86 | 9 | const std::array<int, 2> dims = {q, q}; | |
| 87 | 9 | /* 1 in both entries marks both dimensions as periodic (wrap-around). */ const std::array<int, 2> periods = {1, 1}; | |
| 88 | |||
| 89 | 9 | MPI_Comm cart_comm = MPI_COMM_NULL; | |
| 90 | 9 | MPI_Cart_create(active_comm, 2, dims.data(), periods.data(), 0 /*reorder*/, &cart_comm); | |
| 91 | 9 | return cart_comm; | |
| 92 | } | ||
| 93 | |||
| 94 | 9 | /* Returns the two Cartesian coordinates of the calling rank in cart_comm, as reported by MPI_Cart_coords for the rank obtained from MPI_Comm_rank. */ std::array<int, 2> GetCoords(MPI_Comm cart_comm) { | |
| 95 | 9 | int cart_rank = 0; | |
| 96 | 9 | MPI_Comm_rank(cart_comm, &cart_rank); | |
| 97 | |||
| 98 | 9 | std::array<int, 2> coords = {0, 0}; | |
| 99 | 9 | MPI_Cart_coords(cart_comm, cart_rank, 2, coords.data()); | |
| 100 | 9 | return coords; | |
| 101 | } | ||
| 102 | |||
| 103 | 9 | /* Fills the per-process count and displacement arrays on rank 0 of cart_comm only; on every other rank both vectors are left untouched. Each count is 1 and the displacement for grid cell (row_idx, col_idx) is row_idx*bs*n + col_idx*bs, stored at index row_idx*q + col_idx. */ void BuildScatterMeta(int q, int bs, int n, int active_p, std::vector<int> *sendcounts, std::vector<int> *displs, | |
| 104 | MPI_Comm cart_comm) { | ||
| 105 | 9 | int cart_rank = 0; | |
| 106 | 9 | MPI_Comm_rank(cart_comm, &cart_rank); | |
| 107 | |||
| 108 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
|
9 | if (cart_rank != 0) { |
| 109 | ✗ | return; | |
| 110 | } | ||
| 111 | |||
| 112 | 9 | sendcounts->assign(static_cast<std::size_t>(active_p), 1); | |
| 113 | 9 | displs->assign(static_cast<std::size_t>(active_p), 0); | |
| 114 | |||
| 115 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | for (int row_idx = 0; row_idx < q; ++row_idx) { |
| 116 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | for (int col_idx = 0; col_idx < q; ++col_idx) { |
| 117 | 9 | /* NOTE(review): assumes the process at grid cell (row_idx, col_idx) has rank row_idx*q + col_idx in the communicator used for scatter/gather - confirm. */ const int proc = (row_idx * q) + col_idx; | |
| 118 | 9 | /* Offset of the first element of the block inside an n-wide row-major matrix, computed in int. */ (*displs)[static_cast<std::size_t>(proc)] = ((row_idx * bs) * n) + (col_idx * bs); | |
| 119 | } | ||
| 120 | } | ||
| 121 | } | ||
| 122 | |||
| 123 | 9 | /* Runs two MPI_Scatterv calls rooted at rank 0 of cart_comm: one distributing from a_full_ptr into a_block, one from b_full_ptr into b_block. The send side uses block_type with the given counts and displacements; each rank receives bs*bs MPI_DOUBLE values. The receive vectors are not resized here, so they must already hold at least bs*bs elements. */ void ScatterBlocks(const double *a_full_ptr, const double *b_full_ptr, const int *sendcounts_ptr, const int *displs_ptr, | |
| 124 | MPI_Datatype block_type, int bs, MPI_Comm cart_comm, std::vector<double> *a_block, | ||
| 125 | std::vector<double> *b_block) { | ||
| 126 | 9 | MPI_Scatterv(a_full_ptr, sendcounts_ptr, displs_ptr, block_type, a_block->data(), bs * bs, MPI_DOUBLE, 0, cart_comm); | |
| 127 | 9 | MPI_Scatterv(b_full_ptr, sendcounts_ptr, displs_ptr, block_type, b_block->data(), bs * bs, MPI_DOUBLE, 0, cart_comm); | |
| 128 | 9 | } | |
| 129 | |||
| 130 | 9 | /* Performs the initial alignment exchange: a_block is exchanged in place along Cartesian dimension 1 with displacement -row, and b_block along dimension 0 with displacement -col. Each exchange moves bs*bs doubles via MPI_Sendrecv_replace. */ void InitialShift(MPI_Comm cart_comm, int row, int col, int bs, std::vector<double> *a_block, | |
| 131 | std::vector<double> *b_block) { | ||
| 132 | 9 | int src = 0; | |
| 133 | 9 | int dst = 0; | |
| 134 | |||
| 135 | 9 | /* The shift distance for the A block is this process's own row coordinate, negated. */ MPI_Cart_shift(cart_comm, 1, -row, &src, &dst); | |
| 136 | 9 | MPI_Sendrecv_replace(a_block->data(), bs * bs, MPI_DOUBLE, dst, 0, src, 0, cart_comm, MPI_STATUS_IGNORE); | |
| 137 | |||
| 138 | 9 | /* The shift distance for the B block is this process's own column coordinate, negated. */ MPI_Cart_shift(cart_comm, 0, -col, &src, &dst); | |
| 139 | 9 | /* Tag 1 here versus tag 0 above keeps the A and B exchanges on distinct message tags. */ MPI_Sendrecv_replace(b_block->data(), bs * bs, MPI_DOUBLE, dst, 1, src, 1, cart_comm, MPI_STATUS_IGNORE); | |
| 140 | 9 | } | |
| 141 | |||
| 142 | 9 | /* Runs q multiply-and-shift steps. Each step accumulates a_block*b_block into c_block with LocalMatMulAcc; after every step except the last, a_block is exchanged by -1 along Cartesian dimension 1 and b_block by -1 along dimension 0. */ void CannonLoop(MPI_Comm cart_comm, int q, int bs, std::vector<double> *a_block, std::vector<double> *b_block, | |
| 143 | std::vector<double> *c_block) { | ||
| 144 | 9 | int src = 0; | |
| 145 | 9 | int dst = 0; | |
| 146 | |||
| 147 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | for (int step = 0; step < q; ++step) { |
| 148 | 9 | LocalMatMulAcc(*a_block, *b_block, bs, c_block); | |
| 149 | |||
| 150 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
|
9 | /* No exchange after the final step: the shifted blocks would not be used again. */ if (step + 1 < q) { |
| 151 | ✗ | MPI_Cart_shift(cart_comm, 1, -1, &src, &dst); | |
| 152 | ✗ | MPI_Sendrecv_replace(a_block->data(), bs * bs, MPI_DOUBLE, dst, 2, src, 2, cart_comm, MPI_STATUS_IGNORE); | |
| 153 | |||
| 154 | ✗ | MPI_Cart_shift(cart_comm, 0, -1, &src, &dst); | |
| 155 | ✗ | MPI_Sendrecv_replace(b_block->data(), bs * bs, MPI_DOUBLE, dst, 3, src, 3, cart_comm, MPI_STATUS_IGNORE); | |
| 156 | } | ||
| 157 | } | ||
| 158 | 9 | } | |
| 159 | |||
| 160 | 9 | /* Collects every rank's c_block (bs*bs doubles) with MPI_Gatherv rooted at rank 0 of cart_comm, using block_type with the given counts and displacements on the receive side. Returns an n*n vector on cart rank 0 and an empty vector on all other ranks. */ std::vector<double> GatherResult(MPI_Comm cart_comm, int n, int bs, MPI_Datatype block_type, const int *sendcounts_ptr, | |
| 161 | const int *displs_ptr, const std::vector<double> &c_block) { | ||
| 162 | 9 | int cart_rank = 0; | |
| 163 | 9 | MPI_Comm_rank(cart_comm, &cart_rank); | |
| 164 | |||
| 165 | 9 | std::vector<double> c_full; | |
| 166 | /* Stays null on non-root ranks, which allocate no receive storage. */ double *recv_ptr = nullptr; | ||
| 167 | |||
| 168 |
1/2✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
|
9 | if (cart_rank == 0) { |
| 169 |
1/4✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
9 | c_full.assign(static_cast<std::size_t>(n) * static_cast<std::size_t>(n), 0.0); |
| 170 | recv_ptr = c_full.data(); | ||
| 171 | } | ||
| 172 | |||
| 173 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | MPI_Gatherv(c_block.data(), bs * bs, MPI_DOUBLE, recv_ptr, sendcounts_ptr, displs_ptr, block_type, 0, cart_comm); |
| 174 | 9 | return c_full; | |
| 175 | } | ||
| 176 | |||
| 177 | 18 | /* Replicates world rank 0's c_full on every rank of MPI_COMM_WORLD and returns it. The element count is broadcast first, non-root ranks size their buffer to match, and the data is broadcast only when the count is positive. */ std::vector<double> BroadcastFullToAll(std::vector<double> c_full, int world_rank) { | |
| 178 | 18 | int out_size = 0; | |
| 179 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | if (world_rank == 0) { |
| 180 | 9 | /* NOTE(review): the size is narrowed to int here; a vector longer than INT_MAX elements would not be represented correctly. */ out_size = static_cast<int>(c_full.size()); | |
| 181 | } | ||
| 182 | 18 | MPI_Bcast(&out_size, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 183 | |||
| 184 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | if (world_rank != 0) { |
| 185 | 9 | /* Whatever a non-root rank passed in is discarded and replaced by a zero-filled buffer of the broadcast size. */ c_full.assign(static_cast<std::size_t>(out_size), 0.0); | |
| 186 | } | ||
| 187 | |||
| 188 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | if (out_size > 0) { |
| 189 | 18 | MPI_Bcast(c_full.data(), out_size, MPI_DOUBLE, 0, MPI_COMM_WORLD); | |
| 190 | } | ||
| 191 | 18 | return c_full; | |
| 192 | } | ||
| 193 | |||
| 194 | } // namespace | ||
| 195 | |||
| 196 |
1/2✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
|
18 | /* Constructor: sets the task type from GetStaticTypeOfTask(), stores a copy of in as the task input and resets the output to a default-constructed OutType. */ MakoveevaSCannonAlgorithmMPI::MakoveevaSCannonAlgorithmMPI(const InType &in) { |
| 197 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 198 | GetInput() = in; | ||
| 199 | 18 | GetOutput() = OutType{}; | |
| 200 | 18 | } | |
| 201 | |||
| 202 | 18 | /* Validates the task on all ranks. World rank 0 checks that the output is empty and that both input matrices hold n*n elements; that verdict is broadcast. Every rank additionally requires ChooseGridQ(world_size, n) to return a positive grid side. Returns true only when both checks pass. */ bool MakoveevaSCannonAlgorithmMPI::ValidationImpl() { | |
| 203 | 18 | int world_rank = 0; | |
| 204 | 18 | int world_size = 0; | |
| 205 | 18 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); | |
| 206 | 18 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); | |
| 207 | |||
| 208 | 18 | const int n = BroadcastN(GetInput(), world_rank); | |
| 209 | |||
| 210 | 18 | int ok_input = 0; | |
| 211 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | /* Only rank 0 inspects the matrices; the other ranks rely on the broadcast result below. */ if (world_rank == 0) { |
| 212 | const auto &a = std::get<0>(GetInput()); | ||
| 213 | const auto &b = std::get<1>(GetInput()); | ||
| 214 |
1/2✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
|
18 | ok_input = (GetOutput().empty() && CheckInputLocal(a, b, n)) ? 1 : 0; |
| 215 | } | ||
| 216 | 18 | MPI_Bcast(&ok_input, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 217 | |||
| 218 | 18 | /* Computed locally on each rank from world_size and the broadcast n, so no extra communication is done for the grid check. */ const int q = ChooseGridQ(world_size, n); | |
| 219 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
|
18 | const int ok_grid = (q > 0) ? 1 : 0; |
| 220 | |||
| 221 |
2/4✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
|
18 | return (ok_input != 0) && (ok_grid != 0); |
| 222 | } | ||
| 223 | |||
| 224 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
|
18 | /* Clears the output container if it is not already empty. Always returns true. */ bool MakoveevaSCannonAlgorithmMPI::PreProcessingImpl() { |
| 225 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
|
18 | if (!GetOutput().empty()) { |
| 226 | GetOutput().clear(); | ||
| 227 | } | ||
| 228 | 18 | return true; | |
| 229 | } | ||
| 230 | |||
| 231 | 18 | /* Runs the distributed multiplication. Broadcasts n, chooses a grid side q, and lets the first q*q world ranks form a q-by-q periodic Cartesian grid. Those ranks scatter n/q-sized blocks of both inputs, perform the initial shift and the q-step multiply-and-shift loop, and gather the result on cart rank 0. The full result is then broadcast to every world rank and stored as the output. Returns false when n or q is not positive, true otherwise. */ bool MakoveevaSCannonAlgorithmMPI::RunImpl() { | |
| 232 | 18 | int world_rank = 0; | |
| 233 | 18 | int world_size = 0; | |
| 234 | 18 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); | |
| 235 | 18 | MPI_Comm_size(MPI_COMM_WORLD, &world_size); | |
| 236 | |||
| 237 | 18 | const int n = BroadcastN(GetInput(), world_rank); | |
| 238 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
|
18 | /* n is the broadcast value, identical on all ranks, so every rank takes the same branch here. */ if (n <= 0) { |
| 239 | return false; | ||
| 240 | } | ||
| 241 | |||
| 242 | 18 | const int q = ChooseGridQ(world_size, n); | |
| 243 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
|
18 | if (q <= 0) { |
| 244 | return false; | ||
| 245 | } | ||
| 246 | |||
| 247 | 18 | /* Number of ranks that take part in the grid; any remaining world ranks only join the final broadcast. */ const int active_p = q * q; | |
| 248 | bool is_active = false; | ||
| 249 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | MPI_Comm active_comm = MakeActiveComm(world_rank, active_p, &is_active); |
| 250 | |||
| 251 | 18 | std::vector<double> c_full; | |
| 252 | |||
| 253 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
|
18 | if (is_active) { |
| 254 | 9 | /* Block side length; exact because ChooseGridQ only returns a q that divides n. */ const int bs = n / q; | |
| 255 | 9 | const auto bs_sz = static_cast<std::size_t>(bs); | |
| 256 | |||
| 257 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | MPI_Comm cart_comm = MakeCartComm(active_comm, q); |
| 258 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | const std::array<int, 2> coords = GetCoords(cart_comm); |
| 259 | 9 | const int row = coords[0]; | |
| 260 | 9 | const int col = coords[1]; | |
| 261 | |||
| 262 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | std::vector<double> a_block(bs_sz * bs_sz, 0.0); |
| 263 |
1/4✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
9 | std::vector<double> b_block(bs_sz * bs_sz, 0.0); |
| 264 |
1/4✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
9 | /* Starts at zero because the loop only ever adds partial products into it. */ std::vector<double> c_block(bs_sz * bs_sz, 0.0); |
| 265 | |||
| 266 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | MPI_Datatype block_type = MakeBlockType(n, bs); |
| 267 | |||
| 268 | 9 | std::vector<int> sendcounts; | |
| 269 | 9 | std::vector<int> displs; | |
| 270 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | BuildScatterMeta(q, bs, n, active_p, &sendcounts, &displs, cart_comm); |
| 271 | |||
| 272 | 9 | int cart_rank = 0; | |
| 273 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | MPI_Comm_rank(cart_comm, &cart_rank); |
| 274 |
1/2✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
|
9 | /* The count and displacement arrays are only filled on cart rank 0, so other ranks pass null. */ const int *sendcounts_ptr = (cart_rank == 0) ? sendcounts.data() : nullptr; |
| 275 |
1/2✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
|
9 | const int *displs_ptr = (cart_rank == 0) ? displs.data() : nullptr; |
| 276 | |||
| 277 | const double *a_full_ptr = nullptr; | ||
| 278 | const double *b_full_ptr = nullptr; | ||
| 279 |
1/2✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
|
9 | /* NOTE(review): the full matrices are taken from world rank 0 while the scatter root is cart rank 0; presumably the same process because the split key is world_rank and reorder is 0 - confirm. */ if (world_rank == 0) { |
| 280 | a_full_ptr = std::get<0>(GetInput()).data(); | ||
| 281 | b_full_ptr = std::get<1>(GetInput()).data(); | ||
| 282 | } | ||
| 283 | |||
| 284 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | ScatterBlocks(a_full_ptr, b_full_ptr, sendcounts_ptr, displs_ptr, block_type, bs, cart_comm, &a_block, &b_block); |
| 285 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | InitialShift(cart_comm, row, col, bs, &a_block, &b_block); |
| 286 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | CannonLoop(cart_comm, q, bs, &a_block, &b_block, &c_block); |
| 287 | |||
| 288 |
1/4✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
9 | c_full = GatherResult(cart_comm, n, bs, block_type, sendcounts_ptr, displs_ptr, c_block); |
| 289 | |||
| 290 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | /* Release the MPI handles created in this branch once the gather has finished. */ MPI_Type_free(&block_type); |
| 291 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | MPI_Comm_free(&cart_comm); |
| 292 |
1/2✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
|
9 | MPI_Comm_free(&active_comm); |
| 293 | } | ||
| 294 | |||
| 295 |
1/4✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
36 | /* Executed by every world rank, active or not, because the broadcast is over MPI_COMM_WORLD. */ c_full = BroadcastFullToAll(std::move(c_full), world_rank); |
| 296 | |||
| 297 | GetOutput() = std::move(c_full); | ||
| 298 | return true; | ||
| 299 | } | ||
| 300 | |||
| 301 | 18 | /* Checks the result size: obtains n via BroadcastN (a broadcast over MPI_COMM_WORLD) and returns true only when n is positive and the output holds exactly n*n elements. */ bool MakoveevaSCannonAlgorithmMPI::PostProcessingImpl() { | |
| 302 | 18 | int world_rank = 0; | |
| 303 | 18 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); | |
| 304 | |||
| 305 | 18 | const int n = BroadcastN(GetInput(), world_rank); | |
| 306 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | if (n <= 0) { |
| 307 | return false; | ||
| 308 | } | ||
| 309 | |||
| 310 | 18 | /* Square in std::size_t so the comparison with size() uses matching unsigned types. */ const auto n_sz = static_cast<std::size_t>(n); | |
| 311 | 18 | return GetOutput().size() == (n_sz * n_sz); | |
| 312 | } | ||
| 313 | |||
| 314 | } // namespace makoveeva_s_cannon_algorithm | ||
| 315 |