| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "matrix_band_multiplication/mpi/include/ops_mpi.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | |||
| 5 | #include <algorithm> | ||
| 6 | #include <array> | ||
| 7 | #include <cstddef> | ||
| 8 | #include <utility> | ||
| 9 | #include <vector> | ||
| 10 | |||
| 11 | #include "matrix_band_multiplication/common/include/common.hpp" | ||
| 12 | |||
| 13 | namespace matrix_band_multiplication { | ||
| 14 | |||
| 15 | namespace { | ||
/// Splits `total` items into `parts` contiguous chunks whose sizes differ by at most one.
/// The first `total % parts` chunks receive one extra item.
/// @param total Number of items to distribute (rows or columns of a matrix here).
/// @param parts Number of chunks; callers pass the MPI world size.
/// @return Per-chunk item counts, or an empty vector when `parts <= 0`.
std::vector<int> BuildCounts(int total, int parts) {
  // Guard before allocating: std::vector<int>(parts) with a negative `parts` would
  // convert it to a huge size_t and throw instead of returning an empty result.
  if (parts <= 0) {
    return {};
  }

  std::vector<int> counts(static_cast<std::size_t>(parts), 0);
  const int base = total / parts;
  const int remainder = total % parts;
  for (int i = 0; i < parts; ++i) {
    // Chunks with index below the remainder absorb the leftover items one each.
    counts[static_cast<std::size_t>(i)] = base + (i < remainder ? 1 : 0);
  }
  return counts;
}
| 31 | |||
/// Converts per-rank counts into exclusive prefix offsets (the `displs` argument of MPI *v collectives).
/// @param counts Number of elements owned by each rank.
/// @return `displs[i]` = sum of `counts[0..i-1]`; `displs[0]` is 0. Same length as `counts`.
std::vector<int> BuildDisplacements(const std::vector<int> &counts) {
  std::vector<int> displs(counts.size(), 0);
  int offset = 0;
  // The final count never contributes to any offset, so stop one element early.
  for (std::size_t idx = 0; idx + 1 < counts.size(); ++idx) {
    offset += counts[idx];
    displs[idx + 1] = offset;
  }
  return displs;
}
| 39 | |||
| 40 | bool MatrixIsValid(const Matrix &matrix) { | ||
| 41 |
5/10✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
|
6 | return matrix.rows > 0 && matrix.cols > 0 && matrix.values.size() == matrix.rows * matrix.cols; |
| 42 | } | ||
| 43 | |||
| 44 | } // namespace | ||
| 45 | |||
| 46 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MatrixBandMultiplicationMpi::MatrixBandMultiplicationMpi(const InType &in) { |
| 47 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 48 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | GetInput() = in; |
| 49 | 6 | GetOutput() = Matrix{}; | |
| 50 | 6 | } | |
| 51 | |||
| 52 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | bool MatrixBandMultiplicationMpi::ValidationImpl() { |
| 53 | const auto &matrix_a = GetInput().a; | ||
| 54 | const auto &matrix_b = GetInput().b; | ||
| 55 | if (!MatrixIsValid(matrix_a) || !MatrixIsValid(matrix_b)) { | ||
| 56 | return false; | ||
| 57 | } | ||
| 58 | 6 | return matrix_a.cols == matrix_b.rows; | |
| 59 | } | ||
| 60 | |||
| 61 | 6 | bool MatrixBandMultiplicationMpi::PreProcessingImpl() { | |
| 62 | 6 | MPI_Comm_rank(MPI_COMM_WORLD, &rank_); | |
| 63 | 6 | MPI_Comm_size(MPI_COMM_WORLD, &world_size_); | |
| 64 | |||
| 65 | 6 | const auto &matrix_a = GetInput().a; | |
| 66 | 6 | const auto &matrix_b = GetInput().b; | |
| 67 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | if (!BroadcastDimensions(matrix_a, matrix_b)) { |
| 68 | return false; | ||
| 69 | } | ||
| 70 | |||
| 71 | 6 | PrepareRowDistribution(matrix_a); | |
| 72 | 6 | PrepareColumnDistribution(matrix_b); | |
| 73 | 6 | PrepareResultGatherInfo(); | |
| 74 | 6 | return true; | |
| 75 | } | ||
| 76 | |||
| 77 | 6 | bool MatrixBandMultiplicationMpi::BroadcastDimensions(const Matrix &matrix_a, const Matrix &matrix_b) { | |
| 78 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank_ == 0) { |
| 79 | 3 | rows_a_ = matrix_a.rows; | |
| 80 | 3 | cols_a_ = matrix_a.cols; | |
| 81 | 3 | rows_b_ = matrix_b.rows; | |
| 82 | 3 | cols_b_ = matrix_b.cols; | |
| 83 | } | ||
| 84 | |||
| 85 | 6 | std::array<std::size_t, 4> dims = {rows_a_, cols_a_, rows_b_, cols_b_}; | |
| 86 | 6 | MPI_Bcast(dims.data(), static_cast<int>(dims.size()), MPI_UNSIGNED_LONG_LONG, 0, MPI_COMM_WORLD); | |
| 87 | 6 | rows_a_ = dims[0]; | |
| 88 | 6 | cols_a_ = dims[1]; | |
| 89 | 6 | rows_b_ = dims[2]; | |
| 90 | 6 | cols_b_ = dims[3]; | |
| 91 | |||
| 92 |
4/8✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
|
6 | return rows_a_ > 0 && cols_a_ > 0 && rows_b_ > 0 && cols_b_ > 0; |
| 93 | } | ||
| 94 | |||
| 95 | 6 | void MatrixBandMultiplicationMpi::PrepareRowDistribution(const Matrix &matrix_a) { | |
| 96 | 6 | row_counts_ = BuildCounts(static_cast<int>(rows_a_), world_size_); | |
| 97 | 12 | row_displs_ = BuildDisplacements(row_counts_); | |
| 98 | |||
| 99 | 6 | std::vector<int> send_counts(row_counts_.size()); | |
| 100 | std::ranges::transform(row_counts_, send_counts.begin(), | ||
| 101 | 12 | [this](int rows) { return rows * static_cast<int>(cols_a_); }); | |
| 102 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | std::vector<int> send_displs = BuildDisplacements(send_counts); |
| 103 | |||
| 104 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | const double *a_ptr = rank_ == 0 ? matrix_a.values.data() : nullptr; |
| 105 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | const int local_elems = send_counts[rank_]; |
| 106 |
2/6✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
6 | local_a_.assign(static_cast<std::size_t>(local_elems), 0.0); |
| 107 | |||
| 108 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Scatterv(a_ptr, send_counts.data(), send_displs.data(), MPI_DOUBLE, local_a_.data(), local_elems, MPI_DOUBLE, 0, |
| 109 | MPI_COMM_WORLD); | ||
| 110 | 6 | } | |
| 111 | |||
| 112 | 6 | void MatrixBandMultiplicationMpi::PrepareColumnDistribution(const Matrix &matrix_b) { | |
| 113 | 6 | col_counts_ = BuildCounts(static_cast<int>(cols_b_), world_size_); | |
| 114 | 12 | col_displs_ = BuildDisplacements(col_counts_); | |
| 115 | 6 | max_cols_per_proc_ = ComputeMaxColumns(); | |
| 116 | |||
| 117 | 6 | const std::size_t stripe_capacity = rows_b_ * static_cast<std::size_t>(max_cols_per_proc_); | |
| 118 | 6 | current_b_.assign(stripe_capacity, 0.0); | |
| 119 | 6 | rotation_buffer_.assign(stripe_capacity, 0.0); | |
| 120 | |||
| 121 | 6 | std::vector<double> packed; | |
| 122 |
1/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
6 | std::vector<int> send_counts(world_size_, 0); |
| 123 |
1/4✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
6 | std::vector<int> send_displs(world_size_, 0); |
| 124 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | PreparePackedColumns(matrix_b, packed, send_counts, send_displs); |
| 125 | |||
| 126 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | const int recv_elements = col_counts_[rank_] * static_cast<int>(rows_b_); |
| 127 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | ScatterInitialStripe(packed, send_counts, send_displs, recv_elements); |
| 128 | |||
| 129 | 6 | stripe_owner_ = rank_; | |
| 130 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | current_cols_ = col_counts_[rank_]; |
| 131 | 6 | } | |
| 132 | |||
| 133 | ✗ | int MatrixBandMultiplicationMpi::ComputeMaxColumns() const { | |
| 134 | ✗ | int max_cols = 0; | |
| 135 |
2/4✗ Branch 0 not taken.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 6 times.
|
18 | for (int count : col_counts_) { |
| 136 | ✗ | max_cols = std::max(max_cols, count); | |
| 137 | } | ||
| 138 | ✗ | return max_cols; | |
| 139 | } | ||
| 140 | |||
| 141 | 6 | void MatrixBandMultiplicationMpi::PreparePackedColumns(const Matrix &matrix_b, std::vector<double> &packed, | |
| 142 | std::vector<int> &send_counts, | ||
| 143 | std::vector<int> &send_displs) const { | ||
| 144 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank_ != 0) { |
| 145 | return; | ||
| 146 | } | ||
| 147 | |||
| 148 | 3 | packed.reserve(matrix_b.values.size()); | |
| 149 | int offset = 0; | ||
| 150 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
|
9 | for (int owner = 0; owner < world_size_; ++owner) { |
| 151 | 6 | const int cols = col_counts_[owner]; | |
| 152 | 6 | const int elems = cols * static_cast<int>(rows_b_); | |
| 153 | 6 | send_counts[owner] = elems; | |
| 154 | 6 | send_displs[owner] = offset; | |
| 155 | 6 | const int col_start = col_displs_[owner]; | |
| 156 |
2/2✓ Branch 0 taken 14 times.
✓ Branch 1 taken 6 times.
|
20 | for (std::size_t row = 0; row < rows_b_; ++row) { |
| 157 | 14 | const std::size_t base = row * cols_b_; | |
| 158 |
2/2✓ Branch 0 taken 16 times.
✓ Branch 1 taken 14 times.
|
30 | for (int col = 0; col < cols; ++col) { |
| 159 |
1/2✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
|
16 | const std::size_t src_index = base + static_cast<std::size_t>(col_start + col); |
| 160 | packed.push_back(matrix_b.values[src_index]); | ||
| 161 | } | ||
| 162 | } | ||
| 163 | 6 | offset += elems; | |
| 164 | } | ||
| 165 | } | ||
| 166 | |||
| 167 | 6 | void MatrixBandMultiplicationMpi::ScatterInitialStripe(const std::vector<double> &packed, | |
| 168 | const std::vector<int> &send_counts, | ||
| 169 | const std::vector<int> &send_displs, int recv_elements) { | ||
| 170 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | const double *send_buffer = rank_ == 0 ? packed.data() : nullptr; |
| 171 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | const int *counts_ptr = rank_ == 0 ? send_counts.data() : nullptr; |
| 172 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | const int *displs_ptr = rank_ == 0 ? send_displs.data() : nullptr; |
| 173 | |||
| 174 | 6 | MPI_Scatterv(send_buffer, counts_ptr, displs_ptr, MPI_DOUBLE, current_b_.data(), recv_elements, MPI_DOUBLE, 0, | |
| 175 | MPI_COMM_WORLD); | ||
| 176 | 6 | } | |
| 177 | |||
| 178 | 6 | void MatrixBandMultiplicationMpi::PrepareResultGatherInfo() { | |
| 179 | 6 | const auto local_elements = static_cast<std::size_t>(row_counts_[rank_]) * cols_b_; | |
| 180 | 6 | local_result_.assign(local_elements, 0.0); | |
| 181 | |||
| 182 | 6 | result_counts_ = BuildCounts(static_cast<int>(rows_a_), world_size_); | |
| 183 | 12 | result_displs_ = BuildDisplacements(result_counts_); | |
| 184 | |||
| 185 | std::ranges::transform(result_counts_, result_counts_.begin(), | ||
| 186 | 12 | [this](int rows) { return rows * static_cast<int>(cols_b_); }); | |
| 187 | std::ranges::transform(result_displs_, result_displs_.begin(), | ||
| 188 | 12 | [this](int rows_prefix) { return rows_prefix * static_cast<int>(cols_b_); }); | |
| 189 | 6 | } | |
| 190 | |||
| 191 | 12 | void MatrixBandMultiplicationMpi::MultiplyStripe(const double *stripe_data, int stripe_cols, int stripe_offset, | |
| 192 | int local_rows) { | ||
| 193 |
1/2✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
|
12 | if (stripe_cols == 0 || local_rows == 0) { |
| 194 | return; | ||
| 195 | } | ||
| 196 | |||
| 197 |
2/2✓ Branch 0 taken 14 times.
✓ Branch 1 taken 12 times.
|
26 | for (int row = 0; row < local_rows; ++row) { |
| 198 |
2/2✓ Branch 0 taken 17 times.
✓ Branch 1 taken 14 times.
|
31 | for (int col = 0; col < stripe_cols; ++col) { |
| 199 | double sum = 0.0; | ||
| 200 |
2/2✓ Branch 0 taken 38 times.
✓ Branch 1 taken 17 times.
|
55 | for (std::size_t k = 0; k < cols_a_; ++k) { |
| 201 | 38 | const std::size_t a_idx = (static_cast<std::size_t>(row) * cols_a_) + k; | |
| 202 | 38 | const std::size_t b_idx = (k * static_cast<std::size_t>(stripe_cols)) + static_cast<std::size_t>(col); | |
| 203 | 38 | sum += local_a_[a_idx] * stripe_data[b_idx]; | |
| 204 | } | ||
| 205 | 17 | const std::size_t result_idx = | |
| 206 | 17 | (static_cast<std::size_t>(row) * cols_b_) + static_cast<std::size_t>(stripe_offset + col); | |
| 207 | 17 | local_result_[result_idx] = sum; | |
| 208 | } | ||
| 209 | } | ||
| 210 | } | ||
| 211 | |||
| 212 | 6 | bool MatrixBandMultiplicationMpi::RunImpl() { | |
| 213 | 6 | const int local_rows = row_counts_[rank_]; | |
| 214 | 6 | const int total_steps = world_size_; | |
| 215 | |||
| 216 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 6 times.
|
18 | for (int step = 0; step < total_steps; ++step) { |
| 217 |
1/2✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
|
12 | const int stripe_offset = col_displs_.empty() ? 0 : col_displs_[stripe_owner_]; |
| 218 | 12 | MultiplyStripe(current_b_.data(), current_cols_, stripe_offset, local_rows); | |
| 219 | |||
| 220 |
3/4✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✓ Branch 3 taken 6 times.
|
12 | if (world_size_ == 1 || step == total_steps - 1) { |
| 221 | 6 | continue; | |
| 222 | } | ||
| 223 | |||
| 224 | // Rotate column stripes so every rank multiplies against each subset of B. | ||
| 225 | 6 | const int send_to = (rank_ - 1 + world_size_) % world_size_; | |
| 226 | 6 | const int recv_from = (rank_ + 1) % world_size_; | |
| 227 | 6 | const int send_cols = current_cols_; | |
| 228 | 6 | int recv_cols = 0; | |
| 229 | 6 | MPI_Sendrecv(&send_cols, 1, MPI_INT, send_to, 0, &recv_cols, 1, MPI_INT, recv_from, 0, MPI_COMM_WORLD, | |
| 230 | MPI_STATUS_IGNORE); | ||
| 231 | |||
| 232 | 6 | const int send_elements = send_cols * static_cast<int>(rows_b_); | |
| 233 | 6 | const int recv_elements = recv_cols * static_cast<int>(rows_b_); | |
| 234 | 6 | MPI_Sendrecv(current_b_.data(), send_elements, MPI_DOUBLE, send_to, 1, rotation_buffer_.data(), recv_elements, | |
| 235 | MPI_DOUBLE, recv_from, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE); | ||
| 236 | |||
| 237 | std::swap(current_b_, rotation_buffer_); | ||
| 238 | 6 | current_cols_ = recv_cols; | |
| 239 | 6 | stripe_owner_ = (stripe_owner_ + 1) % world_size_; | |
| 240 | } | ||
| 241 | |||
| 242 | 6 | return true; | |
| 243 | } | ||
| 244 | |||
| 245 | 6 | bool MatrixBandMultiplicationMpi::PostProcessingImpl() { | |
| 246 | 6 | std::vector<double> gathered; | |
| 247 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank_ == 0) { |
| 248 |
1/2✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
|
3 | gathered.resize(rows_a_ * cols_b_); |
| 249 | } | ||
| 250 | |||
| 251 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | const int local_result_elements = row_counts_[rank_] * static_cast<int>(cols_b_); |
| 252 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Gatherv(local_result_.data(), local_result_elements, MPI_DOUBLE, gathered.data(), result_counts_.data(), |
| 253 | result_displs_.data(), MPI_DOUBLE, 0, MPI_COMM_WORLD); | ||
| 254 | |||
| 255 | auto &output = GetOutput(); | ||
| 256 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
|
6 | if (rank_ == 0) { |
| 257 | 3 | output.rows = rows_a_; | |
| 258 | 3 | output.cols = cols_b_; | |
| 259 | 3 | output.values = std::move(gathered); | |
| 260 | } | ||
| 261 | |||
| 262 | 6 | std::array<std::size_t, 2> dims = {output.rows, output.cols}; | |
| 263 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Bcast(dims.data(), static_cast<int>(dims.size()), MPI_UNSIGNED_LONG_LONG, 0, MPI_COMM_WORLD); |
| 264 | 6 | output.rows = dims[0]; | |
| 265 | 6 | output.cols = dims[1]; | |
| 266 | |||
| 267 | 6 | const std::size_t total_size = output.rows * output.cols; | |
| 268 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | output.values.resize(total_size); |
| 269 |
1/2✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
|
6 | MPI_Bcast(output.values.data(), static_cast<int>(total_size), MPI_DOUBLE, 0, MPI_COMM_WORLD); |
| 270 | 6 | return true; | |
| 271 | } | ||
| 272 | |||
| 273 | } // namespace matrix_band_multiplication | ||
| 274 |