| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "bortsova_a_transmission_gather/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <cstddef> | ||
| 6 | #include <utility> | ||
| 7 | #include <vector> | ||
| 8 | |||
| 9 | #include "bortsova_a_transmission_gather/common/include/common.hpp" | ||
| 10 | |||
| 11 | namespace bortsova_a_transmission_gather { | ||
| 12 | |||
| 13 | namespace { | ||
| 14 | |||
| 15 | 3 | void CopyReceivedData(std::vector<double> &gather_buffer, const std::vector<double> &recv_buffer, | |
| 16 | const std::vector<int> &flags_int, std::vector<bool> &received, int world_size, int local_count) { | ||
| 17 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
|
9 | for (int rank = 0; rank < world_size; ++rank) { |
| 18 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (flags_int[static_cast<std::size_t>(rank)] != 0) { |
| 19 | 3 | std::size_t start_idx = static_cast<std::size_t>(rank) * static_cast<std::size_t>(local_count); | |
| 20 |
2/2✓ Branch 0 taken 15 times.
✓ Branch 1 taken 3 times.
|
18 | for (int jj = 0; jj < local_count; ++jj) { |
| 21 | 15 | gather_buffer[start_idx + static_cast<std::size_t>(jj)] = recv_buffer[start_idx + static_cast<std::size_t>(jj)]; | |
| 22 | } | ||
| 23 | received[static_cast<std::size_t>(rank)] = true; | ||
| 24 | } | ||
| 25 | } | ||
| 26 | 3 | } | |
| 27 | |||
| 28 | void PrepareFlags(std::vector<int> &flags_int, const std::vector<bool> &received, int world_size) { | ||
| 29 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
|
9 | for (int rank = 0; rank < world_size; ++rank) { |
| 30 | 6 | flags_int[static_cast<std::size_t>(rank)] = static_cast<int>(received[static_cast<std::size_t>(rank)]); | |
| 31 | } | ||
| 32 | } | ||
| 33 | |||
| 34 | } // namespace | ||
| 35 | |||
| 36 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | BortsovaATransmissionGatherMPI::BortsovaATransmissionGatherMPI(const InType &in) { |
| 37 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 38 | GetInput() = in; | ||
| 39 | 6 | GetOutput() = OutType{}; | |
| 40 | 6 | } | |
| 41 | |||
| 42 | 6 | bool BortsovaATransmissionGatherMPI::ValidationImpl() { | |
| 43 | 6 | MPI_Comm_rank(MPI_COMM_WORLD, &world_rank_); | |
| 44 | 6 | MPI_Comm_size(MPI_COMM_WORLD, &world_size_); | |
| 45 | |||
| 46 | 6 | int root = GetInput().root; | |
| 47 |
2/4✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
|
6 | return root >= 0 && root < world_size_; |
| 48 | } | ||
| 49 | |||
| 50 | 6 | bool BortsovaATransmissionGatherMPI::PreProcessingImpl() { | |
| 51 | 6 | send_count_ = static_cast<int>(GetInput().send_data.size()); | |
| 52 | 6 | return true; | |
| 53 | } | ||
| 54 | |||
| 55 | 6 | bool BortsovaATransmissionGatherMPI::RunImpl() { | |
| 56 | 6 | int root = GetInput().root; | |
| 57 | const std::vector<double> &send_data = GetInput().send_data; | ||
| 58 | |||
| 59 | 6 | int local_count = send_count_; | |
| 60 | 6 | MPI_Bcast(&local_count, 1, MPI_INT, root, MPI_COMM_WORLD); | |
| 61 | |||
| 62 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
|
6 | if (local_count == 0) { |
| 63 | GetOutput().recv_data.clear(); | ||
| 64 | ✗ | return true; | |
| 65 | } | ||
| 66 | |||
| 67 | 6 | std::size_t total_size = static_cast<std::size_t>(world_size_) * static_cast<std::size_t>(local_count); | |
| 68 | 6 | std::vector<double> gather_buffer(total_size, 0.0); | |
| 69 |
1/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
6 | std::vector<bool> received(static_cast<std::size_t>(world_size_), false); |
| 70 | |||
| 71 | 6 | std::size_t offset = static_cast<std::size_t>(world_rank_) * static_cast<std::size_t>(local_count); | |
| 72 |
2/2✓ Branch 0 taken 30 times.
✓ Branch 1 taken 6 times.
|
36 | for (std::size_t idx = 0; idx < send_data.size(); ++idx) { |
| 73 | 30 | gather_buffer[offset + idx] = send_data[idx]; | |
| 74 | } | ||
| 75 | received[static_cast<std::size_t>(world_rank_)] = true; | ||
| 76 | |||
| 77 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | TreeGather(gather_buffer, received, local_count, static_cast<int>(total_size)); |
| 78 | |||
| 79 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | TransferToRoot(gather_buffer, root, static_cast<int>(total_size)); |
| 80 | |||
| 81 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Bcast(gather_buffer.data(), static_cast<int>(total_size), MPI_DOUBLE, root, MPI_COMM_WORLD); |
| 82 | |||
| 83 | 6 | GetOutput().recv_data = std::move(gather_buffer); | |
| 84 | |||
| 85 | return true; | ||
| 86 | } | ||
| 87 | |||
| 88 | 6 | void BortsovaATransmissionGatherMPI::TreeGather(std::vector<double> &gather_buffer, std::vector<bool> &received, | |
| 89 | int local_count, int total_size) { | ||
| 90 | int step = 1; | ||
| 91 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
|
9 | while (step < world_size_) { |
| 92 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if ((world_rank_ % (2 * step)) == 0) { |
| 93 | 3 | int source = world_rank_ + step; | |
| 94 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | if (source < world_size_) { |
| 95 | 3 | ReceiveFromChild(gather_buffer, received, source, local_count, total_size); | |
| 96 | } | ||
| 97 |
1/2✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
|
3 | } else if ((world_rank_ % step) == 0) { |
| 98 | 3 | SendToParent(gather_buffer, received, step, total_size); | |
| 99 | 3 | break; | |
| 100 | } | ||
| 101 | step *= 2; | ||
| 102 | } | ||
| 103 | 6 | } | |
| 104 | |||
| 105 | 3 | void BortsovaATransmissionGatherMPI::ReceiveFromChild(std::vector<double> &gather_buffer, std::vector<bool> &received, | |
| 106 | int source, int local_count, int total_size) const { | ||
| 107 | 3 | std::vector<double> recv_buffer(static_cast<std::size_t>(total_size), 0.0); | |
| 108 |
2/6✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
3 | std::vector<int> flags_int(static_cast<std::size_t>(world_size_), 0); |
| 109 | |||
| 110 | MPI_Status status; | ||
| 111 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | MPI_Recv(recv_buffer.data(), total_size, MPI_DOUBLE, source, 0, MPI_COMM_WORLD, &status); |
| 112 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | MPI_Recv(flags_int.data(), world_size_, MPI_INT, source, 1, MPI_COMM_WORLD, &status); |
| 113 | |||
| 114 | 3 | CopyReceivedData(gather_buffer, recv_buffer, flags_int, received, world_size_, local_count); | |
| 115 | 3 | } | |
| 116 | |||
| 117 | 3 | void BortsovaATransmissionGatherMPI::SendToParent(std::vector<double> &gather_buffer, std::vector<bool> &received, | |
| 118 | int step, int total_size) const { | ||
| 119 | 3 | int dest = world_rank_ - step; | |
| 120 | 3 | MPI_Send(gather_buffer.data(), total_size, MPI_DOUBLE, dest, 0, MPI_COMM_WORLD); | |
| 121 | |||
| 122 | 3 | std::vector<int> flags_int(static_cast<std::size_t>(world_size_), 0); | |
| 123 | 3 | PrepareFlags(flags_int, received, world_size_); | |
| 124 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | MPI_Send(flags_int.data(), world_size_, MPI_INT, dest, 1, MPI_COMM_WORLD); |
| 125 | 3 | } | |
| 126 | |||
| 127 | 6 | void BortsovaATransmissionGatherMPI::TransferToRoot(std::vector<double> &gather_buffer, int root, | |
| 128 | int total_size) const { | ||
| 129 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
6 | if (world_rank_ == 0 && root != 0) { |
| 130 | ✗ | MPI_Send(gather_buffer.data(), total_size, MPI_DOUBLE, root, 2, MPI_COMM_WORLD); | |
| 131 |
3/4✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
|
6 | } else if (world_rank_ == root && root != 0) { |
| 132 | MPI_Status status; | ||
| 133 | ✗ | MPI_Recv(gather_buffer.data(), total_size, MPI_DOUBLE, 0, 2, MPI_COMM_WORLD, &status); | |
| 134 | } | ||
| 135 | 6 | } | |
| 136 | |||
| 137 | 6 | bool BortsovaATransmissionGatherMPI::PostProcessingImpl() { | |
| 138 | 6 | return true; | |
| 139 | } | ||
| 140 | |||
| 141 | } // namespace bortsova_a_transmission_gather | ||
| 142 |