| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | #include "karpich_i_bitwise_batcher/all/include/ops_all.hpp" | ||
| 2 | |||
| 3 | #include <mpi.h> | ||
| 4 | #include <omp.h> | ||
| 5 | |||
| 6 | #include <algorithm> | ||
| 7 | #include <random> | ||
| 8 | #include <utility> | ||
| 9 | #include <vector> | ||
| 10 | |||
| 11 | #include "karpich_i_bitwise_batcher/common/include/common.hpp" | ||
| 12 | |||
| 13 | namespace karpich_i_bitwise_batcher { | ||
| 14 | |||
| 15 | namespace { | ||
| 16 | |||
// Returns the largest of the first `n` elements of `arr`.
//
// The scan over indices 1..n-1 is shared between OpenMP threads and the
// per-thread results are combined with a `max` reduction. `arr[0]` is read
// unconditionally, so the vector must hold at least one element.
int FindMaxParallel(const std::vector<int> &arr, int n) {
  int largest = arr[0];
#pragma omp parallel for default(none) shared(arr, n) reduction(max : largest)
  for (int idx = 1; idx < n; ++idx) {
    largest = std::max(largest, arr[idx]);
  }
  return largest;
}
| 25 | |||
| 26 | 59 | void CountingPass(std::vector<int> &arr, std::vector<int> &buffer, int n, int shift) { | |
| 27 | 59 | std::vector<int> count(256, 0); | |
| 28 | 59 | int num_threads = omp_get_max_threads(); | |
| 29 |
2/6✓ Branch 1 taken 59 times.
✗ Branch 2 not taken.
✓ Branch 4 taken 59 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
|
59 | std::vector<std::vector<int>> local_counts(num_threads, std::vector<int>(256, 0)); |
| 30 | |||
| 31 | 59 | #pragma omp parallel default(none) shared(arr, shift, local_counts, n) | |
| 32 | { | ||
| 33 | int tid = omp_get_thread_num(); | ||
| 34 | #pragma omp for | ||
| 35 | for (int i = 0; i < n; i++) { | ||
| 36 | local_counts[tid][(arr[i] >> shift) & 0xFF]++; | ||
| 37 | } | ||
| 38 | } | ||
| 39 | |||
| 40 |
2/2✓ Branch 0 taken 118 times.
✓ Branch 1 taken 59 times.
|
177 | for (int ti = 0; ti < num_threads; ti++) { |
| 41 |
2/2✓ Branch 0 taken 30208 times.
✓ Branch 1 taken 118 times.
|
30326 | for (int i = 0; i < 256; i++) { |
| 42 | 30208 | count[i] += local_counts[ti][i]; | |
| 43 | } | ||
| 44 | } | ||
| 45 | |||
| 46 |
2/2✓ Branch 0 taken 15045 times.
✓ Branch 1 taken 59 times.
|
15104 | for (int i = 1; i < 256; i++) { |
| 47 | 15045 | count[i] += count[i - 1]; | |
| 48 | } | ||
| 49 | |||
| 50 |
2/2✓ Branch 0 taken 4942 times.
✓ Branch 1 taken 59 times.
|
5001 | for (int i = n - 1; i >= 0; i--) { |
| 51 | 4942 | buffer[--count[(arr[i] >> shift) & 0xFF]] = arr[i]; | |
| 52 | } | ||
| 53 |
1/2✓ Branch 1 taken 59 times.
✗ Branch 2 not taken.
|
59 | arr = buffer; |
| 54 | 118 | } | |
| 55 | |||
| 56 |
2/2✓ Branch 0 taken 30 times.
✓ Branch 1 taken 10 times.
|
40 | void RadixSortPositive(std::vector<int> &arr) { |
| 57 | 40 | int n = static_cast<int>(arr.size()); | |
| 58 |
2/2✓ Branch 0 taken 30 times.
✓ Branch 1 taken 10 times.
|
40 | if (n <= 1) { |
| 59 | 10 | return; | |
| 60 | } | ||
| 61 | |||
| 62 | int max_val = FindMaxParallel(arr, n); | ||
| 63 |
1/2✓ Branch 0 taken 30 times.
✗ Branch 1 not taken.
|
30 | if (max_val == 0) { |
| 64 | return; | ||
| 65 | } | ||
| 66 | |||
| 67 | 30 | std::vector<int> buffer(n); | |
| 68 |
3/4✓ Branch 0 taken 89 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 59 times.
✓ Branch 3 taken 30 times.
|
89 | for (int shift = 0; shift < 32 && (max_val >> shift) > 0; shift += 8) { |
| 69 |
1/2✓ Branch 1 taken 59 times.
✗ Branch 2 not taken.
|
59 | CountingPass(arr, buffer, n, shift); |
| 70 | } | ||
| 71 | } | ||
| 72 | |||
| 73 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 20 times.
|
22 | void RadixSort(std::vector<int> &arr) { |
| 74 | 22 | int n = static_cast<int>(arr.size()); | |
| 75 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 20 times.
|
22 | if (n <= 1) { |
| 76 | 2 | return; | |
| 77 | } | ||
| 78 | |||
| 79 | 20 | std::vector<int> negative; | |
| 80 | 20 | std::vector<int> positive; | |
| 81 |
2/2✓ Branch 0 taken 2480 times.
✓ Branch 1 taken 20 times.
|
2500 | for (int i = 0; i < n; i++) { |
| 82 |
2/2✓ Branch 0 taken 1210 times.
✓ Branch 1 taken 1270 times.
|
2480 | if (arr[i] < 0) { |
| 83 |
1/4✓ Branch 1 taken 1210 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
1210 | negative.push_back(-arr[i]); |
| 84 | } else { | ||
| 85 | positive.push_back(arr[i]); | ||
| 86 | } | ||
| 87 | } | ||
| 88 | |||
| 89 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | RadixSortPositive(positive); |
| 90 |
1/2✓ Branch 1 taken 20 times.
✗ Branch 2 not taken.
|
20 | RadixSortPositive(negative); |
| 91 | |||
| 92 | int idx = 0; | ||
| 93 |
2/2✓ Branch 0 taken 1210 times.
✓ Branch 1 taken 20 times.
|
1230 | for (int i = static_cast<int>(negative.size()) - 1; i >= 0; i--) { |
| 94 | 1210 | arr[idx++] = -negative[i]; | |
| 95 | } | ||
| 96 |
2/2✓ Branch 0 taken 1270 times.
✓ Branch 1 taken 20 times.
|
1290 | for (int x : positive) { |
| 97 | 1270 | arr[idx++] = x; | |
| 98 | } | ||
| 99 | } | ||
| 100 | |||
// One pending sub-problem while the merge network is being built.
struct MergeTask {
  int lo = 0;  // first index of the sub-sequence
  int hi = 0;  // inclusive upper bound on indices of the sub-sequence
  int r = 0;   // distance between consecutive elements of the sub-sequence
};
| 106 | |||
| 107 | 11 | std::vector<std::vector<std::pair<int, int>>> BuildMergeNetwork(int lo, int hi) { | |
| 108 | 11 | std::vector<std::vector<std::pair<int, int>>> levels; | |
| 109 |
1/2✓ Branch 1 taken 11 times.
✗ Branch 2 not taken.
|
11 | std::vector<MergeTask> current = {{lo, hi, 1}}; |
| 110 | |||
| 111 |
2/2✓ Branch 0 taken 53 times.
✓ Branch 1 taken 11 times.
|
64 | while (!current.empty()) { |
| 112 | 53 | std::vector<MergeTask> next; | |
| 113 | 53 | std::vector<std::pair<int, int>> comps; | |
| 114 | |||
| 115 |
2/2✓ Branch 0 taken 2471 times.
✓ Branch 1 taken 53 times.
|
2524 | for (const auto &[tlo, thi, tr] : current) { |
| 116 | 2471 | int step = tr * 2; | |
| 117 |
2/2✓ Branch 0 taken 1230 times.
✓ Branch 1 taken 1241 times.
|
2471 | if (step < thi - tlo) { |
| 118 |
1/2✓ Branch 1 taken 1230 times.
✗ Branch 2 not taken.
|
1230 | next.push_back({tlo, thi, step}); |
| 119 |
1/2✓ Branch 1 taken 1230 times.
✗ Branch 2 not taken.
|
1230 | next.push_back({tlo + tr, thi, step}); |
| 120 |
2/2✓ Branch 0 taken 9318 times.
✓ Branch 1 taken 1230 times.
|
10548 | for (int i = tlo + tr; i + tr <= thi; i += step) { |
| 121 |
1/2✓ Branch 1 taken 9318 times.
✗ Branch 2 not taken.
|
9318 | comps.emplace_back(i, i + tr); |
| 122 | } | ||
| 123 |
1/2✓ Branch 0 taken 1241 times.
✗ Branch 1 not taken.
|
1241 | } else if (tlo + tr <= thi) { |
| 124 |
1/4✓ Branch 1 taken 1241 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
|
1241 | comps.emplace_back(tlo, tlo + tr); |
| 125 | } | ||
| 126 | } | ||
| 127 | |||
| 128 | levels.push_back(std::move(comps)); | ||
| 129 | current = std::move(next); | ||
| 130 | } | ||
| 131 | |||
| 132 | 11 | return levels; | |
| 133 | ✗ | } | |
| 134 | |||
// Applies a comparator network to `arr` in place.
//
// Levels are processed from the last one to the first. Within a level each
// comparator (first, second) swaps the two elements when arr[first] is greater
// than arr[second]; the comparators of a level are distributed over OpenMP
// threads, so they are expected to touch pairwise disjoint indices.
void ApplyComparatorNetwork(std::vector<int> &arr, const std::vector<std::vector<std::pair<int, int>>> &levels) {
  for (auto level_it = levels.rbegin(); level_it != levels.rend(); ++level_it) {
    const auto &level = *level_it;
    int level_size = static_cast<int>(level.size());
#pragma omp parallel for default(none) shared(arr, level, level_size)
    for (int k = 0; k < level_size; ++k) {
      const std::pair<int, int> &cmp = level[k];
      if (arr[cmp.second] < arr[cmp.first]) {
        std::swap(arr[cmp.first], arr[cmp.second]);
      }
    }
  }
}
| 149 | |||
| 150 | 11 | void BatcherMerge(std::vector<int> &arr, int lo, int hi) { | |
| 151 | 11 | auto levels = BuildMergeNetwork(lo, hi); | |
| 152 | 11 | ApplyComparatorNetwork(arr, levels); | |
| 153 | 11 | } | |
| 154 | |||
| 155 | ✗ | void SortSingleProcess(std::vector<int> &data, int padded, int n) { | |
| 156 | ✗ | int half = padded / 2; | |
| 157 | ✗ | std::vector<int> left(data.begin(), data.begin() + half); | |
| 158 | ✗ | std::vector<int> right(data.begin() + half, data.end()); | |
| 159 | |||
| 160 | ✗ | RadixSort(left); | |
| 161 | ✗ | RadixSort(right); | |
| 162 | |||
| 163 | std::ranges::copy(left, data.begin()); | ||
| 164 | std::ranges::copy(right, data.begin() + half); | ||
| 165 | |||
| 166 | ✗ | BatcherMerge(data, 0, padded - 1); | |
| 167 | ✗ | data.resize(n); | |
| 168 | ✗ | } | |
| 169 | |||
// Grows `data` to the smallest power of two that is not less than `n` and
// returns that size.
//
// New slots are filled with the current maximum element of `data`. For
// `n <= 1` nothing is changed and `n` itself is returned.
int PadToPowerOfTwo(std::vector<int> &data, int n) {
  if (n < 2) {
    return n;
  }

  // n >= 2 here, so the smallest candidate is 2.
  int padded = 2;
  while (padded < n) {
    padded *= 2;
  }

  const int filler = *std::max_element(data.begin(), data.end());
  data.resize(padded, filler);
  return padded;
}
| 182 | |||
// Returns the largest power of two that exceeds neither `num_ranks` nor
// `padded`; the result is never smaller than 1.
int ComputeNumTasks(int num_ranks, int padded) {
  const int limit = std::min(num_ranks, padded);
  int tasks = 1;
  for (int candidate = 2; candidate <= limit; candidate *= 2) {
    tasks = candidate;
  }
  return tasks;
}
| 190 | |||
| 191 | 11 | void MergeChunks(std::vector<int> &arr, int chunk_size, int padded) { | |
| 192 |
2/2✓ Branch 0 taken 11 times.
✓ Branch 1 taken 11 times.
|
22 | for (int step = chunk_size * 2; step <= padded; step *= 2) { |
| 193 | 11 | #pragma omp parallel for default(none) shared(step, padded, arr) | |
| 194 | for (int i = 0; i < padded; i += step) { | ||
| 195 | BatcherMerge(arr, i, i + step - 1); | ||
| 196 | } | ||
| 197 | } | ||
| 198 | 11 | } | |
| 199 | |||
| 200 | } // namespace | ||
| 201 | |||
| 202 | 24 | KarpichIBitwiseBatcherALL::KarpichIBitwiseBatcherALL(const InType &in) { | |
| 203 | SetTypeOfTask(GetStaticTypeOfTask()); | ||
| 204 | 24 | GetInput() = in; | |
| 205 | GetOutput() = 0; | ||
| 206 | 24 | } | |
| 207 | |||
| 208 | 24 | bool KarpichIBitwiseBatcherALL::ValidationImpl() { | |
| 209 | 24 | int rank = 0; | |
| 210 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 211 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 212 | 12 | return GetInput() > 0; | |
| 213 | } | ||
| 214 | return true; | ||
| 215 | } | ||
| 216 | |||
| 217 | 24 | bool KarpichIBitwiseBatcherALL::PreProcessingImpl() { | |
| 218 | 24 | int rank = 0; | |
| 219 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 220 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 221 | 12 | int n = GetInput(); | |
| 222 | 12 | data_.resize(n); | |
| 223 | 12 | std::mt19937 gen(static_cast<unsigned int>(n)); | |
| 224 | std::uniform_int_distribution<int> dist(-1000, 1000); | ||
| 225 |
2/2✓ Branch 0 taken 2426 times.
✓ Branch 1 taken 12 times.
|
2438 | for (int i = 0; i < n; i++) { |
| 226 | 2426 | data_[i] = dist(gen); | |
| 227 | } | ||
| 228 | } | ||
| 229 | 24 | return true; | |
| 230 | } | ||
| 231 | |||
| 232 | 24 | bool KarpichIBitwiseBatcherALL::RunImpl() { | |
| 233 | 24 | int rank = 0; | |
| 234 | 24 | int num_ranks = 1; | |
| 235 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 236 | 24 | MPI_Comm_size(MPI_COMM_WORLD, &num_ranks); | |
| 237 | |||
| 238 | int n = 0; | ||
| 239 | 24 | int padded = 1; | |
| 240 | |||
| 241 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 242 | 12 | n = static_cast<int>(data_.size()); | |
| 243 | 12 | padded = PadToPowerOfTwo(data_, n); | |
| 244 | } | ||
| 245 | |||
| 246 | 24 | MPI_Bcast(&padded, 1, MPI_INT, 0, MPI_COMM_WORLD); | |
| 247 | |||
| 248 |
2/2✓ Branch 0 taken 22 times.
✓ Branch 1 taken 2 times.
|
24 | if (padded <= 1) { |
| 249 | return true; | ||
| 250 | } | ||
| 251 | |||
| 252 | 22 | int num_tasks = ComputeNumTasks(num_ranks, padded); | |
| 253 | |||
| 254 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 22 times.
|
22 | if (num_tasks == 1) { |
| 255 | ✗ | if (rank == 0) { | |
| 256 | ✗ | SortSingleProcess(data_, padded, n); | |
| 257 | } | ||
| 258 | ✗ | return true; | |
| 259 | } | ||
| 260 | |||
| 261 | 22 | MPI_Comm active_comm = MPI_COMM_NULL; | |
| 262 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 22 times.
|
22 | int color = (rank < num_tasks) ? 1 : MPI_UNDEFINED; |
| 263 | 22 | MPI_Comm_split(MPI_COMM_WORLD, color, rank, &active_comm); | |
| 264 | |||
| 265 |
1/2✓ Branch 0 taken 22 times.
✗ Branch 1 not taken.
|
22 | if (active_comm != MPI_COMM_NULL) { |
| 266 | 22 | int chunk_size = padded / num_tasks; | |
| 267 | 22 | std::vector<int> local_data(chunk_size); | |
| 268 | |||
| 269 |
1/2✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
|
22 | MPI_Scatter(data_.data(), chunk_size, MPI_INT, local_data.data(), chunk_size, MPI_INT, 0, active_comm); |
| 270 | |||
| 271 |
1/2✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
|
22 | RadixSort(local_data); |
| 272 | |||
| 273 |
1/2✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
|
22 | MPI_Gather(local_data.data(), chunk_size, MPI_INT, data_.data(), chunk_size, MPI_INT, 0, active_comm); |
| 274 | |||
| 275 |
2/2✓ Branch 0 taken 11 times.
✓ Branch 1 taken 11 times.
|
22 | if (rank == 0) { |
| 276 | 11 | MergeChunks(data_, chunk_size, padded); | |
| 277 |
1/2✓ Branch 1 taken 11 times.
✗ Branch 2 not taken.
|
11 | data_.resize(n); |
| 278 | } | ||
| 279 |
1/2✓ Branch 1 taken 22 times.
✗ Branch 2 not taken.
|
22 | MPI_Comm_free(&active_comm); |
| 280 | } | ||
| 281 | |||
| 282 | return true; | ||
| 283 | } | ||
| 284 | |||
| 285 | 24 | bool KarpichIBitwiseBatcherALL::PostProcessingImpl() { | |
| 286 | 24 | int rank = 0; | |
| 287 | 24 | MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
| 288 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
|
24 | if (rank == 0) { |
| 289 | 2414 | for (int i = 1; std::cmp_less(i, data_.size()); i++) { | |
| 290 |
1/2✓ Branch 0 taken 2414 times.
✗ Branch 1 not taken.
|
2414 | if (data_[i] < data_[i - 1]) { |
| 291 | return false; | ||
| 292 | } | ||
| 293 | } | ||
| 294 | } | ||
| 295 | 24 | GetOutput() = GetInput(); | |
| 296 | 24 | return true; | |
| 297 | } | ||
| 298 | |||
| 299 | } // namespace karpich_i_bitwise_batcher | ||
| 300 |